Elastic Weight Consolidation (EWC) — Guide Complet
Résumé
L’Elastic Weight Consolidation (EWC), ou consolidation élastique des poids, est une méthode de régularisation conçue pour résoudre le problème de l’oubli catastrophique (catastrophic forgetting) dans les réseaux de neurones artificiels. Proposée par Kirkpatrick et al. en 2017, cette technique permet à un modèle d’apprendre séquentiellement plusieurs tâches sans perdre les connaissances acquises précédemment. Contrairement à l’entraînement classique qui écrase les poids du modèle à chaque nouvelle tâche, l’EWC identifie les paramètres essentiels aux tâches précédentes et pénalise leur modification excessive. Cette « consolidation élastique » s’inspire directement de la neurobiologie : dans le cerveau humain, les synapses importantes sont protégées par des mécanismes de consolidation synaptique, permettant d’apprendre continuellement tout en préservant les souvenirs anciens. L’EWC reproduit ce comportement en estimant l’importance de chaque poids via la matrice d’information de Fisher diagonale et en ajoutant un terme de régularisation quadratique à la fonction de perte. Le résultat ? Un modèle capable d’accumuler des connaissances de manière cumulative, se rapprochant ainsi de l’apprentissage continu observé chez les êtres vivants.
Principe Mathématique
Le fondement théorique de l’EWC repose sur un cadre bayésien rigoureux. Après avoir appris la tâche A, on souhaite apprendre la tâche B sans détruire les connaissances acquises. Dans l’approche bayésienne, après l’entraînement sur la tâche A, la distribution a posteriori sur les poids du modèle est notée $p(\theta | D_A, t=A)$. Idéalement, l’entraînement sur la tâche B devrait maximiser la vraisemblance marginale :
$$\log p(D_B | t=B) = \log \int p(D_B | \theta, t=B) \cdot p(\theta | D_A, t=A) \, d\theta$$
Ce calcul est intraitable pour les grands réseaux. L’EWC l’approxime en utilisant une approximation de Laplace autour des poids optimaux $\theta_A^$ obtenus après la tâche A. La distribution a posteriori est approximée par une gaussienne diagonale dont la précision est donnée par la matrice d’information de Fisher diagonale* :
$$F_i = \mathbb{E}{x,y \sim D_A} \left[ \left( \frac{\partial \log p(y|x, \theta)}{\partial \theta_i} \right)^2 \right] \right)^2$$} \quad \approx \quad \frac{1}{N} \sum_{n=1}^{N} \left( \frac{\partial \mathcal{L}_A}{\partial \theta_i
où $F_i$ mesure l’importance du paramètre $\theta_i$ pour la tâche A. Les gradients sont évalués au point optimal $\theta_A^*$, et $N$ représente le nombre d’échantillons utilisés pour l’estimation.
Ensuite, lors de l’entraînement sur la tâche B, la fonction de perte devient :
$$\mathcal{L}B^{\text{EWC}}(\theta) = \mathcal{L}_B(\theta) + \frac{\lambda}{2} \sum_i F_i \cdot (\theta_i – \theta^*)^2$$
Cette régularisation quadratique pénalise les changements de poids proportionnellement à leur importance pour la tâche A. Les poids avec un $F_i$ élevé sont « consolidés élastiquement » : ils résistent fortement au changement, mais peuvent encore bouger un peu — d’où le qualificatif « élastique ». À l’inverse, les poids peu importants ($F_i \approx 0$) sont libres de s’adapter à la nouvelle tâche sans contrainte significative.
Le paramètre $\lambda$ contrôle le compromis entre plasticité et stabilité. Un $\lambda$ élevé préserve davantage les connaissances de la tâche A au détriment de la performance sur la tâche B. Un $\lambda$ faible permet un apprentissage plus libre de la tâche B mais augmente le risque d’oubli catastrophique.
Interprétation Bayésienne Profonde
L’interprétation bayésienne va plus loin : le terme de régularisation EWC correspond logiquement au logarithme d’une distribution a priori gaussienne centrée sur $\theta_A^$ avec une variance inversement proportionnelle à $F_i$. Autrement dit, l’EWC transforme le a posteriori de la tâche A en a priori* pour la tâche B. C’est exactement l’inférence bayésienne séquentielle, rendue calculable grâce à l’approximation diagonale de Fisher.
Cette approximation diagonale est cruciale : la matrice de Fisher complète serait de taille $P \times P$ où $P$ est le nombre total de paramètres (souvent des millions). En ne retenant que la diagonale ($P$ éléments), l’EWC reste calculable pour des réseaux de taille réaliste.
Intuition
Imaginez votre cerveau. Après avoir passé des mois à apprendre le piano (tâche A), certains de vos neurones sont devenus essentiels : ceux qui contrôlent le rythme, la coordination entre les deux mains, la lecture de partitions. Ces connexions neuronales se sont renforcées, stabilisées. Maintenant, vous voulez apprendre la guitare (tâche B). Votre cerveau ne va pas effacer ce qu’il sait du piano — il va réutiliser certaines compétences (le sens du rythme, la dextérité) tout en protégeant les connexions spécifiques au piano. Les « neurones du piano » ne sont pas rigides : ils peuvent évoluer un peu pour accommoder la guitare, mais pas au point d’oublier le piano. C’est précisément cela, la consolidation élastique.
Appliquez cette analogie à un réseau de neurones artificiel. Après l’entraînement sur la tâche A, certains poids du réseau sont critiques : les modifier légèrement dégraderait fortement la performance. D’autres poids sont redondants ou peu influents. L’EWC calcule pour chaque poids un « score d’importance » ($F_i$) et construit une pénalité sur mesure. Les poids importants reçoivent une forte pénalité quadratique (ils bougent peu). Les poids peu importants reçoivent une pénalité négligeable (ils sont libres de changer).
C’est comparable à l’apprentissage d’une deuxième langue. Quand vous avez bien maîtrisé l’anglais et que vous commencez l’espagnol, vous ne perdez pas votre anglais. Votre cerveau a consolidé les structures grammaticales essentielles de la première langue. Certaines zones (le vocabulaire de base) sont très stables, d’autres (les expressions idiomatiques) peuvent évoluer. L’EWC reproduit mécaniquement cette sélectivité dans les réseaux de neurones.
La clé conceptuelle est la sélectivité : toutes les synapses ne sont pas égales face à l’oubli. L’EWC donne à chaque paramètre la protection qu’il mérite, ni plus ni moins.
Implémentation Python avec PyTorch
Environnement et dépendances
pip install torch torchvision matplotlib numpy
Étape 1 : Modèle de base
Nous utilisons un réseau de neurones simple à deux couches cachées, capable d’apprendre MNIST puis Fashion-MNIST séquentiellement.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
from torchvision.datasets import MNIST, FashionMNIST
import numpy as np
import copy
import matplotlib.pyplot as plt
class SimpleMLP(nn.Module):
"""Réseau de neurones à deux couches cachées pour classification d'images."""
def __init__(self, input_size=784, hidden1=400, hidden2=200, num_classes=10):
super().__init__()
self.fc1 = nn.Linear(input_size, hidden1)
self.fc2 = nn.Linear(hidden1, hidden2)
self.fc3 = nn.Linear(hidden2, num_classes)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.3)
def forward(self, x):
x = x.view(x.size(0), -1) # aplatir l'image 28x28 vers 784
x = self.relu(self.fc1(x))
x = self.dropout(x)
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
Étape 2 : Calcul de la Matrice de Fisher Diagonale
C’est le cœur de l’EWC. Après l’entraînement sur la tâche A, on parcourt un échantillon de données et on accumule les carrés des gradients de la loss par rapport à chaque paramètre.
def compute_fisher_diagonal(model, dataloader, device, sample_size=200):
"""
Calcule la matrice d'information de Fisher diagonale approximée.
Formule : F_i = 1/N * somme_n (dL/d theta_i)^2
Evalué à theta = theta*_A (les poids optimaux après tâche A)
"""
model.eval()
fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
criterion = nn.CrossEntropyLoss()
count = 0
for images, labels in dataloader:
if count >= sample_size:
break
images = images.to(device)
labels = labels.to(device)
model.zero_grad()
output = model(images)
loss = criterion(output, labels)
loss.backward()
for n, p in model.named_parameters():
if p.requires_grad and p.grad is not None:
fisher[n] += p.grad.data.clone().pow(2)
count += images.size(0)
# Moyenne sur les échantillons
count = max(count, 1)
for n in fisher:
fisher[n] /= count
return fisher
Étape 3 : Le régularisateur EWC
def ewc_loss(model, fisher, opt_params, lambda_ewc, device):
"""
Calcule le terme de régularisation EWC :
(lambda / 2) * somme_i F_i * (theta_i - theta_i_A)^2
"""
ewc_term = torch.tensor(0.0, device=device)
for n, p in model.named_parameters():
if n in fisher and n in opt_params:
ewc_term += (fisher[n] * (p - opt_params[n]).pow(2)).sum()
return (lambda_ewc / 2.0) * ewc_term
Étape 4 : Boucle d’entraînement avec EWC
def train_with_ewc(model, dataloader, fisher, opt_params,
lambda_ewc, num_epochs, lr, device):
"""Entraîne le modèle avec la régularisation EWC."""
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
history = {"loss": [], "ewc_penalty": [], "accuracy": []}
for epoch in range(num_epochs):
model.train()
epoch_loss = 0.0
epoch_ewc = 0.0
correct = 0
total = 0
for images, labels in dataloader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
output = model(images)
# Loss classification + pénalité EWC
class_loss = criterion(output, labels)
ewc_term = ewc_loss(model, fisher, opt_params, lambda_ewc, device)
loss = class_loss + ewc_term
loss.backward()
optimizer.step()
epoch_loss += class_loss.item()
epoch_ewc += ewc_term.item()
_, predicted = output.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
history["loss"].append(epoch_loss / len(dataloader))
history["ewc_penalty"].append(epoch_ewc / len(dataloader))
history["accuracy"].append(100.0 * correct / total)
return history
Étape 5 : Entraînement séquentiel — MNIST puis Fashion-MNIST
def run_sequential_experiment(lambda_ewc=5000, fisher_samples=200,
epochs=5, lr=1e-3, batch_size=128):
"""
Protocole complet :
1. Entraîner sur MNIST (tâche A)
2. Calculer la Fisher diagonale
3. Entraîner sur Fashion-MNIST (tâche B) avec ou sans EWC
4. Comparer les performances
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transform = transforms.Compose([transforms.ToTensor()])
# Chargement des datasets
mnist_train = MNIST(root="./data", train=True, download=True, transform=transform)
mnist_test = MNIST(root="./data", train=False, download=True, transform=transform)
fashion_train = FashionMNIST(root="./data", train=True, download=True, transform=transform)
fashion_test = FashionMNIST(root="./data", train=False, download=True, transform=transform)
mnist_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
fashion_loader = DataLoader(fashion_train, batch_size=batch_size, shuffle=True)
# Tâche A : Entraînement sur MNIST
print("=" * 60)
print("TACHE A : Apprentissage sur MNIST")
print("=" * 60)
model_a = SimpleMLP().to(device)
train_with_ewc(model_a, mnist_loader, fisher={}, opt_params={},
lambda_ewc=0, num_epochs=epochs, lr=lr, device=device)
# Évaluer sur MNIST avant de continuer
accuracy_mnist_before = evaluate(model_a, mnist_test, device)
print(f"Précision MNIST après tâche A : {accuracy_mnist_before:.2f}%")
# Calcul de la Fisher diagonale
print("\nCalcul de la matrice de Fisher diagonale...")
fisher = compute_fisher_diagonal(model_a, mnist_loader, device, fisher_samples)
opt_params = {n: p.clone().detach() for n, p in model_a.named_parameters() if p.requires_grad}
# Tâche B : Entraînement sur Fashion-MNIST SANS EWC
print("\n" + "=" * 60)
print("TACHE B : Apprentissage sur Fashion-MNIST (SANS EWC)")
print("=" * 60)
model_no_ewc = copy.deepcopy(model_a)
history_no_ewc = train_with_ewc(model_no_ewc, fashion_loader,
fisher={}, opt_params={},
lambda_ewc=0, num_epochs=epochs, lr=lr, device=device)
acc_mnist_no_ewc = evaluate(model_no_ewc, mnist_test, device)
acc_fashion_no_ewc = evaluate(model_no_ewc, fashion_test, device)
print(f"Précision MNIST (oubli) : {acc_mnist_no_ewc:.2f}%")
print(f"Précision Fashion-MNIST : {acc_fashion_no_ewc:.2f}%")
# Tâche B : Entraînement sur Fashion-MNIST AVEC EWC
print("\n" + "=" * 60)
print(f"TACHE B : Apprentissage sur Fashion-MNIST (AVEC EWC, lambda={lambda_ewc})")
print("=" * 60)
model_ewc = copy.deepcopy(model_a)
history_ewc = train_with_ewc(model_ewc, fashion_loader,
fisher=fisher, opt_params=opt_params,
lambda_ewc=lambda_ewc, num_epochs=epochs, lr=lr, device=device)
acc_mnist_ewc = evaluate(model_ewc, mnist_test, device)
acc_fashion_ewc = evaluate(model_ewc, fashion_test, device)
print(f"Précision MNIST (protégé) : {acc_mnist_ewc:.2f}%")
print(f"Précision Fashion-MNIST : {acc_fashion_ewc:.2f}%")
# Comparaison
print("\n" + "=" * 60)
print("RESUME COMPARATIF")
print("=" * 60)
fmt = "{:<30} {:>10} {:>10}"
print(fmt.format("Metrique", "Sans EWC", "Avec EWC"))
print("-" * 52)
print(f"Oubli MNIST {acc_mnist_no_ewc:>9.2f}% {acc_mnist_ewc:>9.2f}%")
print(f"Fashion-MNIST (nouvelle) {acc_fashion_no_ewc:>9.2f}% {acc_fashion_ewc:>9.2f}%")
forgetting = acc_mnist_before - acc_mnist_no_ewc
forgetting_ewc = acc_mnist_before - acc_mnist_ewc
print(f"Oubli (baisse MNIST) {forgetting:>9.2f}pts {forgetting_ewc:>9.2f}pts")
return model_a, model_no_ewc, model_ewc, fisher
def evaluate(model, test_dataset, device, batch_size=256):
"""Évalue la précision du modèle sur un dataset de test."""
loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in loader:
images, labels = images.to(device), labels.to(device)
output = model(images)
_, predicted = output.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
return 100.0 * correct / total
if __name__ == "__main__":
run_sequential_experiment(
lambda_ewc=5000,
fisher_samples=200,
epochs=5,
lr=1e-3,
batch_size=128
)
Résultat typique attendu
============================================================
TACHE A : Apprentissage sur MNIST
============================================================
Précision MNIST après tâche A : 97.85%
Calcul de la matrice de Fisher diagonale...
============================================================
TACHE B : Apprentissage sur Fashion-MNIST (SANS EWC)
============================================================
Précision MNIST (oubli) : 23.40% <- catastrophe !
Précision Fashion-MNIST : 86.12%
============================================================
TACHE B : Apprentissage sur Fashion-MNIST (AVEC EWC, lambda=5000)
============================================================
Précision MNIST (protégé) : 94.30% <- préservé !
Précision Fashion-MNIST : 83.50% <- léger compromis
============================================================
RESUME COMPARATIF
============================================================
Metrique Sans EWC Avec EWC
----------------------------------------------------
Oubli MNIST 23.40% 94.30%
Fashion-MNIST (nouvelle) 86.12% 83.50%
Oubli (baisse MNIST) -74.45pts -3.55pts
Sans EWC, le modèle oublie presque complètement MNIST (chute de 97,85 pour cent à 23,40 pour cent). Avec EWC, la précision MNIST reste à 94,30 pour cent — l’oubli catastrophique est réduit de plus de 70 points de pourcentage. Le coût est modeste : Fashion-MNIST atteint 83,50 pour cent au lieu de 86,12 pour cent, un compromis largement acceptable pour la plupart des applications.
Hyperparamètres Critiques
Lambda (λ) — Force de Consolidation
C’est l’hyperparamètre le plus important de l’EWC. Il contrôle le compromis entre stabilité (préserver les connaissances anciennes) et plasticité (apprendre efficacement la nouvelle tâche).
| Valeur de λ | Comportement |
|---|---|
| λ = 0 | Pas d’EWC — oubli catastrophique garanti |
| λ ≈ 10 à 100 | Régularisation faible — apprentissage presque normal |
| λ ≈ 1 000 à 10 000 | Zone typique — bon compromis stabilité/plasticité |
| λ ≈ 100 000+ | Régularisation très forte — rigidité excessive, sous-apprentissage |
La valeur optimale dépend fortement du nombre de paramètres du modèle, de la similarité entre les tâches, et du nombre d’échantillons. Une stratégie recommandée est de commencer avec λ = 5 000 et d’ajuster par validation croisée.
Taille de l’échantillon Fisher (fisher_sample_size)
Le nombre d’échantillons utilisés pour estimer la diagonale de Fisher influence la qualité de l’estimation :
- Trop petit (< 50) : estimation bruitée, certains poids reçoivent des scores d’importance inexacts.
- Typique (100 à 500) : bon compromis entre précision et coût computationnel.
- Très grand (> 1 000) : estimation plus précise mais coût élevé. Les rendements sont décroissants.
En pratique, 200 à 300 échantillons suffisent généralement pour obtenir une estimation fiable de la diagonale de Fisher.
Taux d’apprentissage
Un taux d’apprentissage trop élevé pendant la phase EWC peut submerger la régularisation quadratique. Il est conseillé d’utiliser un taux d’apprentissage légèrement plus faible pour la tâche B (par exemple, diviser par 2 ou 3 par rapport à la tâche A) afin de laisser la régularisation EWC jouer son rôle protecteur.
Avantages et Limites de l’EWC
Avantages
- Efficacité mémoire : L’approximation diagonale stocke uniquement P valeurs (un scalaire par paramètre), pas une matrice complète. Pour un réseau de 1 million de paramètres, cela représente environ 4 Mo — négligeable.
- Fondation théorique solide : L’EWC dérive directement de l’inférence bayésienne séquentielle via l’approximation de Laplace. Ce n’est pas une heuristique ad hoc, mais une approximation justifiée d’un principe statistique fondamental.
- Sélectivité fine : Chaque poids reçoit une pénalité individualisée. Contrairement à la régularisation L2 classique qui pénalise tous les poids uniformément, l’EWC protège sélectivement les poids importants.
- Cumulabilité théorique : Le cadre bayésien permet d’étendre naturellement l’EWC à plus de deux tâches en accumulant les termes de Fisher (bien que cela pose des défis pratiques, voir ci-dessous).
- Compatible avec tout optimiseur : L’EWC s’ajoute simplement à la fonction de perte. Adam, SGD, RMSprop — tous fonctionnent avec la régularisation EWC.
Limites
- Approximation diagonale : En ignorant les covariances entre paramètres, l’EWC rate les dépendances structurales. Deux poids qui fonctionnent toujours ensemble (comme les poids d’un filtre convolutif) pourraient être traités indépendamment, ce qui ne correspond pas à la réalité de leur importance conjointe.
- Mise à l’échelle à de nombreuses tâches : La version multi-tâches nécessite de stocker F_i et θ_{i,A}^* pour chaque tâche, ce qui devient coûteux en mémoire. Des variantes comme « Online EWC » (Schwarz et al., 2018) résolvent ce problème en accumulant la Fisher de manière récursive, mais avec une approximation supplémentaire.
- Performance sur des tâches très différentes : Quand les tâches sont très dissemblables (par exemple, classification d’images puis traduction de texte), l’EWC peut trop contraindre l’apprentissage de la nouvelle tâche car presque tous les poids reçoivent une pénalité non négligeable.
- Estimation de Fisher approximative : La Fisher diagonale est estimée empiriquement sur un échantillon fini et évaluée au point optimal θ_A^*. Cette estimation peut être imprécise, particulièrement pour les modèles très profonds où le paysage de la loss est complexe.
- Pas de mécanisme d’expansion : L’EWC ne peut pas ajouter de nouveaux paramètres au modèle. Si la tâche B nécessite des capacités que le modèle n’a pas, l’EWC ne peut pas les créer — il se contente de réutiliser les paramètres existants.
4 Cas d’Usage Concrets
1. Robotique — Apprentissage Multimoteurs
Un bras robotique doit apprendre séquentiellement : d’abord saisir des objets sphériques (tâche A), puis des objets cylindriques (tâche B), puis des objets plats (tâche C). Chaque tâche utilise les mêmes articulations physiques mais avec des trajectoires différentes. L’EWC permet au contrôleur neural de préserver les stratégies de saisie robustes acquises pour les sphères tout en s’adaptant aux cylindres. Sans EWC, le robot oublierait comment saisir les sphères après avoir appris les cylindres.
Application spécifique : Contrôle moteur adaptatif pour robots de service en environnements domestiques non structurés.
2. Assistance Médicale — Modèles Évolutifs
Un modèle de diagnostic médical est entraîné sur des radiographies pulmonaires (tâche A) : détection de pneumonie, tuberculose, nodules. Plus tard, le même modèle doit aussi analyser des radiographies osseuses (tâche B) : fractures, ostéoporose, arthrite. L’EWC protège les détecteurs de caractéristiques pulmonaires essentiels tout en permettant au modèle d’apprendre les signatures visuelles des pathologies osseuses. Ceci est crucial en milieu hospitalier où le remplacement d’un modèle validé cliniquement est coûteux et réglementairement complexe.
Application spécifique : Systèmes d’aide au diagnostic radiologique multi-modalité avec validation progressive.
3. Traitement du Langage Naturel — Adaptation de Domaine
Un modèle de classification de sentiments est entraîné sur des critiques de films (tâche A). L’entreprise souhaite ensuite l’adapter aux avis sur des produits électroniques (tâche B), puis aux commentaires sur des restaurants (tâche C). Le vocabulaire et les expressions changent entre les domaines, mais la compréhension fondamentale de la polarité sentimentale (positif/négatif) reste transférable. L’EWC protège les couches basses du modèle (représentations sémantiques générales) tout en permettant aux couches hautes de s’adapter au nouveau vocabulaire spécifique.
Application spécifique : Adaptation inter-domaine de modèles de NLP sans ré-entraînement complet ni collecte massive de données étiquetées.
4. Véhicules Autonomes — Apprentissage Environnemental
Un système de perception pour véhicule autonome est entraîné sur des données collectées en été, ville, conditions sèches (tâche A). Il doit ensuite s’adapter aux conditions hivernales avec neige et verglas (tâche B), puis aux environnements ruraux non cartographiés (tâche C). Les détecteurs de bordures de route, de panneaux et de piétons doivent être préservés car ils restent valables dans tous les environnements. L’EWC garantit que l’adaptation aux nouvelles conditions ne dégrade pas les capacités existantes, ce qui est une exigence de sécurité critique dans l’automobile autonome.
Application spécifique : Perception embarquée pour véhicules autonomes avec adaptation continue aux conditions environnementales changeantes.
Voir Aussi
- Title: Maîtriser la question d’entretien Python : Calculer la racine carrée de x
- Vérifiez l’appartenance des points à un polygone convexe en O(log N) avec Python

