Continual Learning : Guide Complet — Apprentissage Continu sans Oubli

Continual Learning : Guide Complet — Apprentissage Continu sans Oubli

Continual Learning : Guide Complet pour l’Apprentissage Continu sans Oubli Catastrophique

Résumé

Le Continual Learning (ou apprentissage continu) est un paradigme fondamental du machine learning où un modèle apprend séquentiellement une série de tâches, sans avoir accès aux données des tâches précédentes, tout en préservant ses connaissances antérieures. Contrairement à l’entraînement classique où toutes les données sont disponibles simultanément, le Continual Learning reflète la réalité du monde : les informations arrivent progressivement, et un système intelligent doit s’adapter sans effacer ce qu’il a déjà appris.

Le défi principal est le catastrophic forgetting (oubli catastrophique) : lorsqu’un réseau de neurones est réentraîné sur une nouvelle tâche, ses poids se modifient de manière destructive pour les compétences acquises précédemment. Ce phénomène rend les modèles classiques inadaptés aux environnements dynamiques où les données évoluent continuellement.

Trois grandes familles de solutions existent : la régularisation (comme Elastic Weight Consolidation), le replay (mémorisation d’échantillons passés), et l’architecture dynamique (expansion du modèle). Ce guide explore chacune de ces approches en détail, avec une implémentation Python complète sous PyTorch.


Principe Mathématique du Continual Learning

Le Problème du Catastrophic Forgetting

Imaginez un réseau de neurones entraîné sur une tâche A avec une fonction de perte ℒ_A(θ). Après convergence, les paramètres θ_A sont optimaux pour la tâche A. Si l’on entraîne ensuite exactement le même réseau sur une tâche B avec une perte ℒ_B(θ), les poids convergent vers θ_B. Problème : θ_B n’a aucun mécanisme pour préserver ce qui rendait θ_A performant sur la tâche A.

Mathématiquement, la perte sur la tâche A après entraînement sur B explose :

ℒ_A(θ_B) ≫ ℒ_A(θ_A)

Cela se produit car les gradients de ℒ_B poussent les poids dans des directions qui sont arbitraires par rapport au minimum de ℒ_A. Les neurones qui codaient des caractéristiques essentielles pour la tâche A sont réaffectés à la tâche B. C’est le catastrophic forgetting — le modèle « oublie » presque instantanément la tâche précédente.

Elastic Weight Consolidation (EWC)

Kirkpatrick et al. (2017) ont proposé une solution élégante inspirée de la consolidation synaptique dans le cerveau biologique. L’idée est de pénaliser les modifications des poids importants pour la tâche précédente, tout en permettant aux poids moins importants de s’adapter librement.

La fonction de perte combinée s’écrit :

ℒ(θ) = ℒ_B(θ) + λ · Σ_i F_i · (θ_i − θ_A_i)²

Où :
ℒ_B(θ) est la perte sur la nouvelle tâche B.
λ (lambda) est un hyperparamètre qui contrôle la force de la régularisation.
F_i est l’importance du paramètre θ_i pour la tâche A, estimée par la diagonale de la matrice de Fisher.
θ_A_i est la valeur optimale du paramètre i après entraînement sur la tâche A.
θ_i est la valeur courante du paramètre i pendant l’entraînement sur B.

La matrice de Fisher mesure la sensibilité de la distribution de sortie du modèle à chaque paramètre. Intuitivement, si modifier légèrement θ_i change beaucoup la sortie sur la tâche A, alors F_i est grande, et le paramètre est « important » — il doit être protégé.

En pratique, on utilise l’approximation diagonale de la matrice de Fisher, calculée sur un échantillon de données de la tâche A :

F_i ≈ 𝔼_{x~D_A}[(∂log p(y|x, θ)/∂θ_i)²]

Cette approximation rend EWC calculable : on n’a pas besoin de stocker toute la matrice de Fisher (qui serait de taille n² pour n paramètres), seulement sa diagonale (taille n).

Approche par Replay

Le replay consiste à mélanger un petit sous-ensemble de données des anciennes tâches avec les données de la nouvelle tâche pendant l’entraînement. Formellement, si D_old est le buffer de replay et D_new les données courantes :

ℒ(θ) = 𝔼_{(x,y)~D_new}[ℒ_new(θ; x, y)] + α · 𝔼_{(x,y)~D_old}[ℒ_old(θ; x, y)]

Cette approche est simple mais efficace. Le principal défi est de sélectionner quels échantillons garder dans le buffer : les plus représentatifs, les plus difficiles, ou une combinaison des deux ?

Distillation Progressive

La distillation utilise l’ancien modèle comme enseignant pour guider le nouveau. Pendant l’entraînement sur la tâche B, on ajoute un terme de régularisation basé sur la divergence de Kullback-Leibler entre les sorties de l’ancien modèle (figé) et du modèle courant :

ℒ(θ) = ℒ_new(θ) + β · KL(p_ancien(x) || p_courant(x))

La distillation est particulièrement utile en Class-Incremental Learning, où de nouvelles classes apparaissent séquentiellement, car l’ancien modèle contient déjà une connaissance de toutes les classes précédentes.


Intuition : Le Musicien et ses Instruments

Le Continual Learning, c’est comme un musicien professionnel qui apprend un nouvel instrument.

Imaginez un pianiste de concert qui décide d’apprendre le violon. S’il arrête complètement de pratiquer le piano pendant deux ans pour se concentrer uniquement sur le violon, ses compétences pianistiques vont se détériorer — c’est le catastrophic forgetting. Ses doigts perdent la mémoire musculaire, sa lecture de partitions au piano se ralentit, et ses interprétations deviennent moins nuancées.

Mais dans la réalité, un musicien ne fonctionne pas ainsi. Avec EWC, c’est comme s’il protégeait les compétences fondamentales qui servent aux deux instruments — la lecture de notes, le sens du rythme, l’oreille musicale — tout en adaptant les compétences spécifiques : le doigté pour le piano, l’embouchure pour le violon, la position des mains.

Les poids « importants » du réseau de neurones sont les compétences fondamentales : on les consolide, on les protège, on ne les laisse pas être écrasés par le nouvel apprentissage. Les poids « moins importants » sont les détails spécifiques : on les laisse évoluer librement pour s’adapter au nouvel instrument.

Avec le replay, le musicien pratique occasionnellement son ancien instrument pour maintenir le niveau. Avec la distillation, il écoute ses anciens enregistrements et essaie de reproduire la même qualité d’interprétation tout en intégrant sa nouvelle technique.

L’essence du Continual Learning, c’est protéger ce qui est essentiel tout en restant ouvert au nouveau. C’est l’équilibre entre stabilité et plasticité — le même équilibre que notre propre cerveau maintient chaque jour.


Implémentation Python Complète avec PyTorch

Voici une implémentation complète d’EWC appliquée à une séquence de tâches de classification d’images : MNIST → SVHN → MNIST-M. Le code utilise l’approximation diagonale de la matrice de Fisher et compare les performances avec un fine-tuning naïf.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import copy


class SimpleCNN(nn.Module):
    """
    Reseau convolutionnel simple pour la classification d'images.
    Architecture volontairement legere pour illustrer le continual learning.
    """
    def __init__(self, n_classes=10):
        super(SimpleCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, n_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x


class EWCRegularizer:
    """
    Regularisation Elastic Weight Consolidation (EWC).

    Calcule l'importance de chaque parametre via la diagonale
    de la matrice de Fisher apres entrainement sur une tache,
    puis penalise les modifications importantes lors des taches suivantes.
    """
    def __init__(self, model, device="cpu"):
        self.model = model
        self.device = device
        self.fisher = {}
        self.optimal_params = {}

    def compute_fisher(self, dataloader, n_samples=None):
        """
        Estime la diagonale de la matrice de Fisher.

        Pour chaque mini-batch, on calcule le gradient du loss
        par rapport a chaque parametre, puis on prend le carre
        moyen de ces gradients sur tous les echantillons.
        """
        self.model.eval()
        fisher_accumulator = {}

        for name, param in self.model.named_parameters():
            fisher_accumulator[name] = torch.zeros_like(param)

        total_samples = 0
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            if n_samples is not None and total_samples >= n_samples:
                break

            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            self.model.zero_grad()
            outputs = self.model(inputs)
            loss = nn.CrossEntropyLoss()(outputs, targets)
            loss.backward()

            for name, param in self.model.named_parameters():
                if param.grad is not None:
                    fisher_accumulator[name] += param.grad.data ** 2

            total_samples += inputs.size(0)

        n_processed = min(total_samples, n_samples if n_samples else total_samples)
        for name in fisher_accumulator:
            self.fisher[name] = fisher_accumulator[name] / n_processed

        self.optimal_params = {
            name: param.clone().detach()
            for name, param in self.model.named_parameters()
        }

    def penalization(self, lam):
        """
        Calcule le terme de penalite EWC :
        lbd * sum_i F_i * (theta_i - theta_A_i)^2

        Args:
            lam (float): Coefficient lambda de regularisation.

        Returns:
            torch.Tensor: Terme de penalite a ajouter a la loss.
        """
        penalty = torch.tensor(0.0, device=self.device)

        for name, param in self.model.named_parameters():
            if name in self.fisher and name in self.optimal_params:
                importance = self.fisher[name]
                optimal = self.optimal_params[name].to(self.device)
                penalty += (importance * (param - optimal) ** 2).sum()

        return lam * penalty

    def merge_fisher(self, new_fisher, blending=0.5):
        """
        Fusionne une nouvelle estimation de Fisher avec les precedentes.
        Utile quand on enchaine plusieurs taches successivement.
        """
        for name in self.fisher:
            if name in new_fisher:
                self.fisher[name] = blending * self.fisher[name] + \
                                    (1 - blending) * new_fisher[name]


def train_epoch(model, dataloader, optimizer, ewc=None, lambda_ewc=0.0, device="cpu"):
    """
    Entraine le modele pour une epoque complete.

    Si EWC est active, ajoute le terme de penalite a la perte.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    criterion = nn.CrossEntropyLoss()

    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        if ewc is not None and lambda_ewc > 0:
            loss += ewc.penalization(lambda_ewc)

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        correct += predicted.eq(targets).sum().item()
        total += targets.size(0)

    avg_loss = total_loss / total
    accuracy = 100.0 * correct / total
    return avg_loss, accuracy


@torch.no_grad()
def evaluate(model, dataloader, device="cpu"):
    """Evalue le modele sur un jeu de donnees."""
    model.eval()
    correct = 0
    total = 0

    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        _, predicted = outputs.max(1)
        correct += predicted.eq(targets).sum().item()
        total += targets.size(0)

    return 100.0 * correct / total


def run_continual_learning(device="cuda" if torch.cuda.is_available() else "cpu"):
    """
    Execute la sequence experimentale avec et sans EWC.
    Compare les performances pour demontrer l'efficacite
    du Continual Learning contre le fine-tuning naif.
    """
    print(f"== Continual Learning : EWC sur MNIST -> SVHN -> MNIST-M ==")
    print(f"Device : {device}\n")

    n_classes = 10
    lambda_ewc = 5000.0
    fisher_samples = 2000
    n_epochs = 5

    n_train = 1000
    n_test = 200

    train_loader = DataLoader(
        TensorDataset(torch.randn(n_train, 3, 28, 28),
                      torch.randint(0, n_classes, (n_train,))),
        batch_size=128, shuffle=True
    )
    test_loader = DataLoader(
        TensorDataset(torch.randn(n_test, 3, 28, 28),
                      torch.randint(0, n_classes, (n_test,))),
        batch_size=128
    )

    tasks = ["Tache 1 (MNIST-like)", "Tache 2 (SVHN-like)", "Tache 3 (MNIST-M-like)"]

    print("--- Protocole 1 : Fine-tuning naif (baseline) ---")
    model_naive = SimpleCNN(n_classes).to(device)
    results_naive = {}

    for t, task_name in enumerate(tasks):
        print(f"\n  Entrainement sur {task_name}...")
        optimizer = optim.Adam(model_naive.parameters(), lr=1e-3)

        for epoch in range(n_epochs):
            loss, acc = train_epoch(
                model_naive, train_loader, optimizer, device=device
            )
            print(f"    Epoque {epoch+1}/{n_epochs} -- Loss: {loss:.4f}, Acc: {acc:.1f}%")

        acc_current = evaluate(model_naive, test_loader, device)
        results_naive[task_name] = acc_current
        print(f"  -> Precision sur {task_name} : {acc_current:.1f}%")

    print(f"\n  Resume fine-tuning naif : {results_naive}")

    print("\n--- Protocole 2 : EWC (Continual Learning) ---")
    model_ewc = SimpleCNN(n_classes).to(device)
    ewc = EWCRegularizer(model_ewc, device=device)
    results_ewc = {}

    for t, task_name in enumerate(tasks):
        print(f"\n  Entrainement sur {task_name}...")
        optimizer = optim.Adam(model_ewc.parameters(), lr=1e-3)

        for epoch in range(n_epochs):
            loss, acc = train_epoch(
                model_ewc, train_loader, optimizer,
                ewc=ewc if t > 0 else None,
                lambda_ewc=lambda_ewc if t > 0 else 0.0,
                device=device
            )
            print(f"    Epoque {epoch+1}/{n_epochs} -- Loss: {loss:.4f}, Acc: {acc:.1f}%")

        if t < len(tasks) - 1:
            ewc.compute_fisher(test_loader, n_samples=fisher_samples)

        acc_current = evaluate(model_ewc, test_loader, device)
        results_ewc[task_name] = acc_current
        print(f"  -> Precision sur {task_name} : {acc_current:.1f}%")

    print(f"\n  Resume EWC : {results_ewc}")

    print("\n" + "=" * 60)
    print("COMPARAISON DES PERFORMANCES")
    print("=" * 60)
    for task in tasks:
        naive_acc = results_naive.get(task, 0.0)
        ewc_acc = results_ewc.get(task, 0.0)
        diff = ewc_acc - naive_acc
        marker = "OK" if diff > 0 else "X "
        print(f"{task:<30}  Naif: {naive_acc:>6.1f}%  EWC: {ewc_acc:>6.1f}%  [{marker}]")

    print("\nNote: Avec de vraies donnees, EWC montre typiquement")
    print("une superiorite claire sur le fine-tuning naif,")
    print("particulierement sur les taches plus anciennes.")


if __name__ == "__main__":
    run_continual_learning()

Ce code démontre les composants essentiels d’un pipeline de Continual Learning :

  1. Calcul de la diagonale de Fisher : accumulation des gradients au carré sur des échantillons de la tâche précédente, normalisée par le nombre d’échantillons.
  2. Pénalisation EWC : pour chaque paramètre, on mesure l’écart par rapport à la valeur optimale précédente, pondéré par l’importance Fisher.
  3. Comparaison directe : le protocole naïf entraîne séquentiellement sans protection, tandis que le protocole EWC ajoute la régularisation après la première tâche.

Pour une utilisation en production avec de vraies données, il suffit de remplacer les données simulées par torchvision.datasets.MNIST, SVHN, et des variantes transformées comme MNIST-M (MNIST avec fond de Berkeley Segmentation Dataset).


Hyperparamètres Clés du Continual Learning

Le choix des hyperparamètres est crucial en Continual Learning. Voici les plus importants :

λ (ewc_lambda) — Force de régularisation EWC

  • Typique : 1 000 à 50 000 selon l’échelle de la perte.
  • Trop bas : le modèle oublie les tâches précédentes (oubli catastrophique non maîtrisé).
  • Trop haut : le modèle ne peut pas apprendre de nouvelles tâches (rigidité excessive).
  • Règle pratique : λ devrait être de l’ordre de grandeur du ratio entre la variance des gradients de la nouvelle tâche et l’importance Fisher. Un balayage logarithmique (100, 1 000, 10 000, 50 000) est recommandé.

fisher_samples — Échantillons pour estimer Fisher

  • Typique : 500 à 5 000 échantillons par tâche.
  • Plus d’échantillons donnent une estimation plus précise de l’importance des paramètres.
  • Cependant, au-delà d’un certain seuil, le gain marginal diminue.
  • Dans un scénario réel, on ne veut pas passer trop de temps à estimer Fisher au détriment de l’entraînement.

memory_size — Taille du buffer de replay

  • Typique : 100 à 5 000 échantillons par tâche ancienne.
  • Plus le buffer est grand, meilleure est la rétention des anciennes tâches.
  • Stratégies de sélection avancées : reservoir sampling, échantillons les plus difficilement classifiables, ou sélection par gradient.
  • Le compromis mémoire/performance est un axe de recherche actif.

β (beta_distillation) — Coefficient de distillation

  • Typique : 0.1 à 10.0.
  • Utilisé en combinaison avec EWC ou replay pour renforcer la préservation des connaissances.
  • En Class-Incremental Learning, la distillation est souvent plus efficace que EWC seule, car elle préserve les relations entre classes, pas seulement les poids individuels.

Avantages et Limites

Avantages

  1. Adaptabilité continue : Le modèle évolue avec de nouvelles données sans nécessiter de réentraînement complet depuis zéro. C’est indispensable pour les systèmes déployés en production qui doivent s’adapter en temps réel.
  2. Économie de stockage : Contrairement au réentraînement batch qui nécessite de conserver toutes les données historiques, le Continual Learning ne garde qu’un petit buffer de replay (quelques centaines d’échantillons par tâche) ou même aucun échantillon (approches par régularisation pure comme EWC).
  3. Respect de la vie privée : Les données anciennes n’ont pas besoin d’être stockées indéfiniment, ce qui réduit les risques de fuite et facilite la conformité RGPD. L’apprentissage se fait en flux continu, les données peuvent être supprimées après utilisation.
  4. Apprentissage plus naturel : Le pattern d’apprentissage séquentiel reflète la manière dont les humains acquièrent des connaissances — progressivement, en construisant sur ce qui est déjà connu.

Limites

  1. Compromis stabilité-plasticité : Trop de régularisation empêche l’apprentissage de nouvelles tâches ; trop peu provoque l’oubli catastrophique. Trouver le bon équilibre est difficile et dépend du domaine.
  2. Complexité algorithmique : Les méthodes avancées comme EWC nécessitent le calcul de la matrice de Fisher, ce qui ajoute un overhead computationnel significatif. Les méthodes par replay nécessitent une gestion complexe du buffer de mémoire.
  3. Dégradation cumulative : Même avec EWC, les performances sur les toutes premières tâches tendent à diminuer progressivement au fil des séquences, car chaque nouvelle tâche introduit une petite perturbation qui s’accumule.
  4. Évaluation difficile : Les benchmarks standards (Split-MNIST, Permuted-MNIST, CORe50) ne reflètent pas toujours la réalité des applications pratiques où la distribution des données change de manière plus subtile et non stationnaire.

4 Cas d’Usage Concrets du Continual Learning

1. Véhicules Autonomes — Adaptation aux Conditions Routières

Un véhicule autonome déployé à Paris doit s’adapter aux panneaux de signalisation spécifiques d’autres villes ou pays (États-Unis, Japon) sans oublier comment conduire à Paris. Le Continual Learning permet au modèle de perception d’apprendre progressivement les nouvelles régulations, conditions météorologiques extrêmes, et styles de conduite locaux, tout en maintenant ses compétences de base.

Les implications de sécurité rendent l’oubli catastrophique inacceptable : un véhicule qui « oublierait » comment détecter les piétons après avoir appris à reconnaître les nouveaux panneaux serait catastrophique. EWC et le replay sont des candidats naturels pour ce scénario.

2. Assistants Vocaux — Apprentissage des Préférences Utilisateurs

Un assistant vocal comme Alexa ou Siri doit apprendre les préférences, habitudes et patterns de langage de chaque utilisateur individuellement, sans que l’apprentissage sur un utilisateur ne détériore les performances sur les autres. Le Continual Learning permet une personnalisation incrémentale : le modèle de base reste stable (grâce à la régularisation EWC), tandis qu’une petite partie des poids s’adapte au profil de l’utilisateur.

C’est aussi un enjeu de confidentialité : les données vocales des utilisateurs n’ont pas besoin d’être centralisées sur un serveur pour un réentraînement global. L’apprentissage se fait localement, de manière continue.

3. Diagnostic Médical — Intégration de Nouvelles Pathologies

Un système d’aide au diagnostic radiologique doit intégrer la détection de nouvelles pathologies (comme une épidémie émergente) sans perdre sa capacité à diagnostiquer les maladies pour lesquelles il a été initialement entraîné. Le Continual Learning permet d’ajouter des classes progressivement : d’abord les tumeurs courantes, puis les rares, puis les émergentes.

La distillation progressive est particulièrement pertinente ici : le modèle précédent sert de référence pour s’assurer que les prédictions sur les classes existantes ne dérivent pas pendant l’ajout de nouvelles classes.

4. Robots Industriels — Apprentissage de Nouvelles Tâches

Un bras robotique dans une usine apprend successivement à assembler différents produits au fil des changements de ligne de production. Plutôt que de réentraîner le modèle de contrôle de zéro pour chaque nouveau produit (ce qui prendrait des jours), le Continual Learning permet un ajustement rapide : le robot conserve sa compréhension fondamentale de la manipulation d’objets tout en apprenant les gestes spécifiques au nouveau produit.

Le replay est particulièrement adapté ici : quelques exemples de chaque tâche précédente suffisent à maintenir les compétences pendant l’acquisition des nouvelles.


Voir Aussi

Laisser un commentaire

Votre adresse e-mail ne sera pas publiée. Les champs obligatoires sont indiqués avec *

Ce site utilise Akismet pour réduire les indésirables. En savoir plus sur la façon dont les données de vos commentaires sont traitées.