Capsule Networks
Résumé
Les Capsule Networks (réseaux de capsules) constituent une architecture d’apprentissage profond introduite par Geoffrey Hinton et son équipe en 2017. Contrairement aux réseaux de neurones convolutifs classiques (CNN) qui produisent des activations scalaires, les Capsule Networks manipulent des vecteurs d’activation appelés capsules. Chaque capsule encode non seulement la présence d’une entité visuelle, mais aussi ses propriétés spatiales — orientation, échelle, position et déformation — au sein d’un seul vecteur continu.
Cette innovation conceptuelle résout une faiblesse fondamentale des CNN : leur incapacité à modéliser explicitement les relations spatiales hiérarchiques entre les caractéristiques détectées. Grâce au mécanisme de routing par accord (dynamic routing by agreement), les Capsule Networks établissent des connexions dynamiques entre les capsules de niveaux différents, permettant au réseau d’apprendre une représentation hiérarchique robuste et cohérente de la scène visuelle.
Les Capsule Networks offrent ainsi une robustesse exceptionnelle aux transformations géométriques (rotations, translations, changements d’échelle) sans nécessiter l’augmentation massive de données d’entraînement requise par les approches conventionnelles.
Principe mathématique
Une capsule est un groupe de neurones dont la sortie est un vecteur plutôt qu’un scalaire. La longueur (norme) de ce vecteur représente la probabilité de présence d’une entité particulière, tandis que son orientation encode les paramètres de posture (pose parameters) de cette entité.
La fonction Squash
Pour s’assurer que la norme du vecteur de sortie soit comprise entre 0 et 1 (interprétable comme une probabilité), on applique la fonction de squash :
v_j = (||s_j||² / (1 + ||s_j||²)) · (s_j / ||s_j||)
où s_j est la somme pondérée des prédictions d’entrée pour la capsule j, et v_j est le vecteur de sortie après squash. Cette fonction possède deux propriétés essentielles :
- Pour les vecteurs de petite norme, le résultat est quasi-nul — la capsule est inactive.
- Pour les vecteurs de grande norme, le résultat tend vers le vecteur unitaire dans la même direction — la capsule est pleinement active.
Cette non-linéarité préserve la direction du vecteur (qui contient l’information de pose) tout en normalisant sa longueur (qui encode la probabilité de présence).
Prédiction et couplage
Chaque capsule de niveau inférieur i génère une prédiction pour chaque capsule de niveau supérieur j :
û_j|i = W_ij · u_i
où u_i est le vecteur de sortie de la capsule i et W_ij est une matrice de poids apprise qui transforme la pose de la capsule i dans le référentiel de la capsule j.
Routing par accord (Dynamic Routing)
Le mécanisme central des Capsule Networks est le routing par accord. L’idée est élégante : une capsule de niveau supérieur ne devrait recevoir d’information que des capsules de niveau inférieur dont les prédictions sont cohérentes entre elles.
Le processus itératif fonctionne ainsi :
- Initialisation : les coefficients de couplage b_ij sont initialisés à zéro.
- Calcul des poids : c_ij = softmax(b_ij) — les coefficients sont normalisés pour chaque capsule i à travers toutes les capsules j.
- Agrégation : s_j = Σ_i c_ij · û_j|i — chaque capsule reçoit une combinaison pondérée des prédictions.
- Squash : v_j = squash(s_j) — application de la fonction de squash.
- Mise à jour des accords : b_ij ← b_ij + û_j|i · v_j — on augmente le couplage lorsque la prédiction et la sortie finale sont alignées.
Ce processus est répété un nombre fixe d’itérations (typiquement 3). Résultat : les capsules dont les prédictions sont cohérentes avec la sortie réelle voient leur influence augmenter, tandis que les prédictions discordantes sont progressivement ignorées. C’est un mécanisme d’attention émergente, entièrement différentiable et appris de manière end-to-end.
Intuition géométrique
Pour comprendre pourquoi les Capsule Networks représentent une avancée conceptuelle majeure, il faut examiner la faiblesse intrinsèque des réseaux convolutifs classiques.
Un CNN traditionnel applique des opérations de pooling (max-pooling ou average-pooling) qui réduisent la résolution spatiale de l’information. Cette réduction permet d’atteindre une invariance : un objet détecté dans un coin de l’image ou au centre produira activement la même réponse. Cependant, cette invariance a un coût énorme — la perte des relations spatiales entre les caractéristiques détectées.
Imaginez un visage humain. Un CNN classique apprendra à détecter les yeux, le nez et la bouche comme des entités indépendantes. Si toutes ces caractéristiques sont présentes dans l’image, le réseau conclura qu’il s’agit d’un visage. Mais considérons l’absurde suivant : un visage où les yeux sont placés en bas et la bouche en haut. Un CNN classique reconnaîtra probablement toujours un visage, car les éléments individuels sont présents ! Il ne vérifie pas l’arrangement spatial.
Les Capsule Networks résolvent ce problème fondamental de manière élégante. Au lieu de produire un simple scalaire indiquant « un œil est présent ici », chaque capsule produit un vecteur de posture qui encode la position, l’orientation, l’échelle et la déformation de l’entité détectée. Les capsules de niveau supérieur reçoivent ces vecteurs de posture et, grâce au routing par accord, vérifient si les prédictions des capsules inférieures sont spatialement cohérentes.
C’est la différence philosophique entre :
- Reconnaître un visage par ses éléments individuels — approche CNN classique (présence des pièces, sans vérification de l’assemblage).
- Reconnaître un visage par l’arrangement spatial de ses éléments — approche Capsule (présence ET cohérence de la configuration spatiale).
Cette approche est analogue à la perception humaine. Notre système visuel ne reconnaît pas un objet simplement par la liste de ses composants — il évalue la configuration de ces composants. Un triangle reste un triangle même si on le tourne, car notre cerveau encode les relations angulaires entre les sommets, pas les positions absolues de chaque point. Les Capsule Networks tentent de reproduire ce principe fondamental de perception.
Implémentation Python avec PyTorch
Voici une implémentation complète des Capsule Networks, inspirée de l’architecture originale de Sabour, Frosst et Hinton (2017), adaptée pour l’entraînement sur MNIST.
Fonction Squash
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
class Squash(nn.Module):
"""Fonction de squash pour les Capsule Networks.
v_j = (||s_j||² / (1 + ||s_j||²)) * (s_j / ||s_j||)
Cette fonction normalise la longueur du vecteur entre 0 et 1
tout en préservant sa direction.
"""
def __init__(self, dim=-1, epsilon=1e-7):
super().__init__()
self.dim = dim
self.epsilon = epsilon
def forward(self, s):
sq_norm = (s ** 2).sum(dim=self.dim, keepdim=True)
scale = sq_norm / (1.0 + sq_norm)
norm = torch.sqrt(sq_norm + self.epsilon)
v = scale * (s / norm)
return v
Couche de Routing Dynamique
class DynamicRouting(nn.Module):
"""Mécanisme de routing par accord (dynamic routing by agreement).
Implémente l'algorithme de routing itératif où les coefficients
de couplage b_ij sont mis à jour par : b_ij ← b_ij + û_j|i · v_j
"""
def __init__(self, num_iterations=3):
super().__init__()
self.num_iterations = num_iterations
def forward(self, u_hat, num_input_caps, num_output_caps):
"""
Args:
u_hat: tenseur des prédictions [batch, num_input, num_output, dim_out]
num_input_caps: nombre de capsules d'entrée
num_output_caps: nombre de capsules de sortie
Returns:
v: vecteurs de sortie des capsules [batch, num_output, dim_out]
"""
batch_size = u_hat.size(0)
dim_out = u_hat.size(-1)
# Initialisation des logits de couplage
b = torch.zeros(batch_size, num_input_caps, num_output_caps,
device=u_hat.device)
for iteration in range(self.num_iterations):
# Calcul des coefficients de couplage par softmax
c = F.softmax(b, dim=2)
# Agrégation pondérée des prédictions
c_expand = c.unsqueeze(-1)
s = (c_expand * u_hat).sum(dim=1)
# Application de la fonction squash
sq_norm = (s ** 2).sum(dim=-1, keepdim=True)
scale = sq_norm / (1.0 + sq_norm)
norm = torch.sqrt(sq_norm + 1e-7)
v = scale * (s / norm)
# Mise à jour des accords (sauf à la dernière itération)
if iteration < self.num_iterations - 1:
b += (u_hat * v.unsqueeze(1)).sum(dim=-1)
return v
Architecture CapsNet complète
class CapsNet(nn.Module):
"""Capsule Network complète pour la classification MNIST.
Architecture inspirée de Sabour, Frosst & Hinton (2017).
Comprend les capsules primaires, le routing dynamique et
un décodeur de reconstruction pour la régularisation.
"""
def __init__(self, num_primary_caps=1152, num_digit_caps=10,
capsule_dim=16, routing_iterations=3, recon_hidden=512):
super().__init__()
self.num_primary_caps = num_primary_caps
self.num_digit_caps = num_digit_caps
self.capsule_dim = capsule_dim
# Couche convolutive initiale (extraction de features basiques)
self.conv1 = nn.Conv2d(1, 256, kernel_size=9, stride=1)
self.relu = nn.ReLU()
# Couche PrimaryCaps — transforme les features scalaires
# en vecteurs de capsules (32 canaux × 6×6 positions = 1152 capsules)
self.primary_caps = nn.Conv2d(256, 32 * capsule_dim,
kernel_size=9, stride=2)
self.squash = Squash()
self.routing = DynamicRouting(num_iterations=routing_iterations)
# Matrices de poids pour les prédictions
self.W = nn.Parameter(torch.randn(1, num_primary_caps,
num_digit_caps, capsule_dim,
capsule_dim) * 0.1)
# Décodeur de reconstruction pour la régularisation
self.decoder = nn.Sequential(
nn.Linear(num_digit_caps * capsule_dim, recon_hidden),
nn.ReLU(inplace=True),
nn.Linear(recon_hidden, recon_hidden * 2),
nn.ReLU(inplace=True),
nn.Linear(recon_hidden * 2, 28 * 28),
nn.Sigmoid()
)
def forward(self, x):
# Extraction de features convolutives
x = self.relu(self.conv1(x)) # [batch, 256, 20, 20]
# Génération des capsules primaires
pc = self.primary_caps(x) # [batch, 32*16, 6, 6]
batch_size = pc.size(0)
# Reshape en capsules individuelles
pc = pc.view(batch_size, -1, self.capsule_dim) # [batch, 1152, 16]
pc = self.squash(pc)
# Prédiction : û = W · u
u = pc.unsqueeze(2).unsqueeze(3) # [batch, 1152, 1, 1, 16]
u_hat = torch.matmul(self.W, u).squeeze(3) # [batch, 1152, 10, 16]
# Routing dynamique
v = self.routing(u_hat, self.num_primary_caps, self.num_digit_caps)
# Classification basée sur la norme des capsules
classes = (v ** 2).sum(dim=-1) # [batch, 10]
# Reconstruction
max_idx = classes.argmax(dim=1)
mask = torch.zeros_like(v)
mask[range(batch_size), max_idx, :] = v[range(batch_size), max_idx, :]
reconstruction = self.decoder(mask.view(batch_size, -1))
return classes, v, reconstruction
Fonction de perte marginale et entraînement
def margin_loss(classes, targets, lambda_loss, m_plus=0.9, m_minus=0.1):
"""Fonction de perte marginale pour les Capsule Networks.
Encourage la capsule cible à avoir une norme >= m_plus
et les capsules non-cibles à avoir une norme <= m_minus.
"""
left = F.relu(m_plus - classes).pow(2)
right = F.relu(classes - m_minus).pow(2)
mask = F.one_hot(targets, num_classes=10).float()
loss = mask * left + lambda_loss * (1 - mask) * right
return loss.sum(dim=1).mean()
def train_epoch(model, dataloader, optimizer, device, lambda_loss=0.5,
recon_weight=0.0005):
"""Entraînement d'une époque sur MNIST."""
model.train()
total_loss = 0.0
correct = 0
total = 0
for images, targets in dataloader:
images, targets = images.to(device), targets.to(device)
optimizer.zero_grad()
classes, v, reconstruction = model(images)
# Perte marginale
m_loss = margin_loss(classes, targets, lambda_loss)
# Perte de reconstruction (régularisation)
recon_loss = F.mse_loss(reconstruction,
images.view(images.size(0), -1))
# Perte totale
loss = m_loss + recon_weight * recon_loss
loss.backward()
optimizer.step()
total_loss += loss.item()
predictions = classes.argmax(dim=1)
correct += (predictions == targets).sum().item()
total += targets.size(0)
avg_loss = total_loss / len(dataloader)
accuracy = correct / total * 100
return avg_loss, accuracy
def evaluate_robustness(model, dataloader, device, angle_degrees=45):
"""Évalue la robustesse du modèle aux rotations.
Les Capsule Networks devraient maintenir de bonnes performances
même avec des images tournées, contrairement aux CNN classiques.
"""
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, targets in dataloader:
images, targets = images.to(device), targets.to(device)
# Rotation des images
rotated = torchvision.transforms.functional.rotate(
images, angle_degrees
)
classes, _, _ = model(rotated)
predictions = classes.argmax(dim=1)
correct += (predictions == targets).sum().item()
total += targets.size(0)
return correct / total * 100
Hyperparamètres clés
| Hyperparamètre | Valeur typique | Rôle |
|---|---|---|
num_primary_caps |
1152 | Nombre de capsules primaires (32 × 6 × 6 sur MNIST) |
num_digit_caps |
10 | Nombre de capsules de classe (une par chiffre sur MNIST) |
capsule_dim |
8 ou 16 | Dimension du vecteur de chaque capsule |
routing_iterations |
3 | Nombre d’itérations du routing par accord |
lambda_loss |
0.5 | Pondération de la pénalité pour les fausses classes |
recon_weight |
0.0005 | Poids de la perte de reconstruction |
m_plus |
0.9 | Norme cible minimale pour la classe correcte |
m_minus |
0.1 | Norme cible maximale pour les classes incorrectes |
Influence du nombre d’itérations de routing
Le nombre d’itérations de routing représente un compromis fondamental entre précision et coût computationnel. Avec une seule itération, le routing équivaut à un softmax standard — aucune information de feedback entre les capsules. À partir de trois itérations, le mécanisme converge généralement vers des couplages stables. Au-delà de cinq itérations, les gains de précision deviennent marginaux tandis que le temps de calcul augmente linéairement.
Dimension des capsules
Une dimension typique de 8 à 16 est recommandée pour les datasets simples comme MNIST. Pour des tâches visuelles plus complexes, des dimensions supérieures (32 à 64) peuvent être nécessaires pour encoder suffisamment de paramètres de posture. Cependant, augmenter la dimension multiplie également la taille des matrices de poids W_ij, ce qui impacte directement la mémoire GPU.
Avantages et limitations
Avantages
Préservation des relations spatiales : C’est l’avantage fondamental. Les Capsule Networks encodent explicitement les relations hiérarchiques entre les entités visuelles. Un visage avec des yeux mal placés ne sera pas reconnu comme un visage, ce qui correspond beaucoup mieux à la perception humaine.
Robustesse aux transformations : Grâce aux vecteurs de posture, les Capsule Networks sont naturellement robustes aux rotations, translations et changements d’échelle. Contrairement aux CNN qui nécessitent une augmentation massive de données (data augmentation) pour apprendre cette invariance, les Capsule Networks la généralisent grâce à leur représentation vectorielle.
Moins de données d’entraînement : L’architecture intrinsèquement robuste aux transformations nécessite moins d’exemples d’entraînement pour atteindre des performances comparables aux CNN, particulièrement sur des tâches où les transformations géométriques sont fréquentes.
Reconstruction explicite : Le décodeur de reconstruction force les capsules à encoder des informations sémantiquement riches. Les visualisations des reconstructions montrent que les capsules apprennent effectivement des représentations significatives des entités visuelles.
Routing adaptatif : Contrairement aux connexions statiques des CNN, le routing dynamique adapte les connexions entre les couches en fonction du contenu de l’entrée. Chaque image reçoit un routage personnalisé.
Limitations
Coût computationnel élevé : Le routing par ajout itératif est significativement plus coûteux qu’une convolution standard. Chaque itération nécessite des opérations de matrice sur l’ensemble des capsules, ce qui rend l’entraînement et l’inférence considérablement plus lents.
Difficulté de mise à l’échelle : Les Capsule Networks montrent d’excellents résultats sur MNIST mais peinent à surpasser les CNN de pointe sur des datasets complexes comme CIFAR-10 et ImageNet. La matrice de poids W_ij croît quadratiquement avec le nombre de capsules, rendant l’application à grande échelle quasi impraticable avec l’architecture originale.
Convergence lente : Le routing dynamique introduit des non-linéarités complexes qui rendent l’optimisation plus difficile. L’entraînement nécessite souvent plus d’époques et une initialisation soigneuse des poids.
Sensibilité aux hyperparamètres : Le nombre d’itérations de routing, la dimension des capsules et les marges m+ / m- sont des hyperparamètres sensibles qui nécessitent un réglage minutieux pour de bonnes performances.
Recherche encore émergente : L’architecture originale de 2017 a depuis été améliorée (Matrix Capsules, EM Routing), mais aucune variante n’a atteint le niveau de dominance des CNN dans le domaine de la vision par ordinateur.
4 cas d’usage concrets
1. Détection médicale avec relations anatomiques
En imagerie médicale (IRM, radiographie), la position relative des structures anatomiques est cruciale pour le diagnostic. Une tumeur située à côté d’un organe spécifique n’a pas la même signification qu’une tumeur de même apparence située ailleurs. Les Capsule Networks, en encodant explicitement les relations spatiales, offrent un avantage naturel pour ces tâches où la configuration spatiale porte une information diagnostique essentielle.
2. Reconnaissance d’objets avec variations de point de vue
Dans les applications de réalité augmentée ou de robotique mobile, les objets sont observés sous des angles variés et changeants. Les Capsule Networks, grâce à leurs vecteurs de posture, peuvent reconnaître un même objet sous différentes perspectives sans nécessiter un entraînement exhaustif sur chaque angle possible. Cette capacité de généralisation viewpoint est particulièrement précieuse dans les environnements dynamiques et imprévisibles.
3. Détection de fraude et analyse de graphes
Le concept de routing par accord s’applique au-delà de la vision par ordinateur. Dans la détection de fraude financière, on peut modéliser chaque transaction comme une capsule et utiliser un mécanisme similaire de routing pour identifier des schémas suspects basés sur la cohérence entre les entités connectées. Les transactions frauduleuses forment des configurations qui divergent des patterns normaux détectés par le routing.
4. Analyse de scènes complexes avec occlusion
Dans les scènes visuelles où certains objets sont partiellement cachés (occlusion), les Capsule Networks peuvent inférer la présence et la pose d’objets occlus grâce aux relations spatiales apprises entre les capsules. Si les capsules détectent un torse, des bras et une tête partiellement visible, le routing par accord peut renforcer l’hypothèse d’une personne complète même si certaines parties sont masquées. Cette capacité d’inférence contextuelle est difficile à obtenir avec des CNN standards.
Voir aussi
- Quoi ? Où ? Quand ? Maîtriser la Programmation Python : Guide Complet pour Débutants
- Maîtrisez le Tri par Paquets en Python : Guide Complet pour Débutants

