Knowledge Distillation : Guide Complet — Distillation de Connaissances

Knowledge Distillation : Guide Complet — Distillation de Connaissances

Knowledge Distillation : Compression de Modèles par Distillation de Connaissances

Résumé

La Knowledge Distillation (distillation de connaissances) est une technique fondamentale de compression de modèles en apprentissage profond. Introduite par Hinton et ses collaborateurs en 2015, cette approche permet de transférer les connaissances accumulées par un grand modèle (appelé teacher ou enseignant) vers un modèle nettement plus compact (appelé student ou étudiant). Le mécanisme central repose sur l’utilisation de soft labels — des distributions de probabilité adoucies par un paramètre de température — qui encodent bien plus d’informations que les étiquettes classiques. Grâce à cette méthode, un modèle compressé peut atteindre des performances remarquables, souvent supérieures à celles d’un modèle de même taille entraîné directement sur les données brutes. Ce guide complet explore le principe mathématique, l’intuition pédagogique, l’implémentation pratique en PyTorch, ainsi que les cas d’usage concrets de la Knowledge Distillation.

Principe Mathématique de la Knowledge Distillation

Le fonctionnement de la Knowledge Distillation repose sur des bases mathématiques solides et élégantes. Comprendre ces fondements est essentiel pour maîtriser pleinement cette technique de compression.

Les Soft Labels et la Température

Dans un réseau de neurones classique, la dernière couche produit un vecteur de logits z = (z₁, z₂, …, z_K) pour K classes. La fonction softmax standard convertit ces logits en probabilités :

p_i = exp(z_i) / Σ_j exp(z_j)

Le problème de cette formulation est qu’elle tend à produire des distributions très tranchées : une classe domine largement, et toutes les autres reçoivent des probabilités proches de zéro. L’information sur les similarités entre classes est ainsi perdue.

La Knowledge Distillation introduit un paramètre crucial : la température T. Le teacher génère des soft labels selon la formule suivante :

q_i = softmax(z_i / T) = exp(z_i / T) / Σ_j exp(z_j / T)

où T est la température. Plus T est élevé, plus la distribution de probabilité est douce et uniforme. À l’inverse, quand T se rapproche de 1, on retrouve la softmax classique. Par exemple, avec une température de T = 5, les probabilités deviennent beaucoup plus nuancées, révélant des similarités subtiles entre classes qui seraient autrement masquées.

La Fonction de Loss Combinée

L’entraînement du student utilise une fonction de coût composite qui combine deux sources de supervision :

L = α · T² · KL(softmax(teacher/T) || softmax(student/T)) + (1 – α) · CrossEntropy(student, hard_labels)

où :

  • α est un coefficient de pondération (généralement compris entre 0,5 et 0,9) qui détermine l’importance relative de la distillation par rapport à l’entraînement classique ;
  • est un facteur de mise à l’échelle indispensable : la divergence de Kullback-Leibler produite par des softmax à température élevée génère des gradients d’amplitude réduite. Le facteur T² compense cette diminution et maintient des gradients d’échelle comparable entre les deux termes de la loss ;
  • KL désigne la divergence de Kullback-Leibler, qui mesure l’écart entre la distribution du teacher et celle du student ;
  • CrossEntropy est l’entropie croisée classique entre les prédictions du student et les étiquettes véritables (hard labels).

Le Résultat Fondamental de Hinton et al. (2015)

Dans leur article fondateur « Distilling the Knowledge in a Neural Network », Geoffrey Hinton, Oriol Vinyals et Jeff Dean ont démontré un résultat contre-intuitif mais fondamental : le student distillé surpasse systématiquement un modèle de même architecture entraîné directement sur les hard labels. Cette supériorité s’explique par le fait que les soft labels véhiculent une richesse informationnelle considérable — le teacher encode non seulement la bonne réponse, mais aussi toute la structure relationnelle entre les classes, acquise au fil de son entraînement sur un vaste corpus de données.

Intuition : Le Professeur et l’Étudiant

Pour saisir véritablement l’essence de la Knowledge Distillation, l’analogie pédagogique est particulièrement éclairante.

Imaginez un professeur brillant (le teacher, un grand modèle profondément entraîné) face à un étudiant prometteur (le student, un petit modèle compact). Dans un scénario d’entraînement classique, le professeur se contenterait de donner la réponse correcte à chaque question : « Ceci est un chat. » Point final. L’étudiant apprend, mais de manière superficielle.

Dans le cadre de la Knowledge Distillation, le professeur adopte une approche bien plus riche et nuancée. Il ne donne pas seulement la réponse, il partage son raisonnement :

« C’est clairement un chat, à environ 85 %. Mais observez bien : on peut aussi distinguer une légère ressemblance avec le lynx, peut-être à 10 %, et une pointe de similarités avec le léopard, disons 5 %. Les oreilles sont un peu plus pointues que celles d’un chat domestique, et le pelage présente des motifs caractéristiques. »

Ces soft labels — ces probabilités nuancées entre toutes les classes — contiennent infiniment plus d’information que le seul label « chat ». Ils transmettent au student une compréhension profonde des similarités et des différences entre catégories. L’étudiant apprend ainsi à capturer la structure latente de l’espace des classes, bien au-delà de la simple reconnaissance de la catégorie majoritaire.

Cette approche présente un avantage décisif : même si le student n’a jamais rencontré un lynx durant son entraînement, les similarités apprises via les soft labels lui permettront de mieux généraliser lorsqu’il sera confronté à des exemples ambigus ou à des classes proches. C’est précisément cette capacité de généralisation améliorée qui distingue la Knowledge Distillation d’un simple entraînement sur des étiquettes binaires.

Implémentation Python avec PyTorch

Voici une implémentation complète et fonctionnelle de la Knowledge Distillation en PyTorch, illustrant le transfert de connaissances d’un ResNet18 pré-entraîné vers un petit réseau convolutionnel personnalisé.

Code Complet

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision import models

# ============================================================
# 1. Chargement des données (CIFAR-10)
# ============================================================
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                             download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                            download=True, transform=transform_test)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

# ============================================================
# 2. Modèle Teacher : ResNet18 pré-entraîné (figé)
# ============================================================
teacher = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# Adapter la dernière couche pour CIFAR-10 (10 classes)
teacher.fc = nn.Linear(512, 10)
teacher.eval()  # Mode évaluation — le teacher est déjà entraîné

# ============================================================
# 3. Modèle Student : petit CNN personnalisé
# ============================================================
class StudentCNN(nn.Module):
    def __init__(self):
        super(StudentCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

student = StudentCNN()

# ============================================================
# 4. Fonction de Loss de Distillation
# ============================================================
def distillation_loss(student_logits, teacher_logits, hard_labels,
                      temperature=5.0, alpha=0.7):
    """
    Fonction de loss combinee pour la Knowledge Distillation.

    Args:
        student_logits : sorties brutes du modele etudiant
        teacher_logits : sorties brutes du modele enseignant
        hard_labels    : etiquettes reelles (ground truth)
        temperature    : parametre T pour adoucir les distributions
        alpha          : poids du terme de distillation (0 a 1)

    Returns:
        loss totale combinee
    """
    # Soft targets : softmax a temperature elevee
    soft_student = torch.log_softmax(student_logits / temperature, dim=1)
    soft_teacher = torch.softmax(teacher_logits / temperature, dim=1)

    # Divergence KL mise a l'echelle par T^2
    kl_loss = nn.KLDivLoss(reduction='batchmean')(soft_student, soft_teacher)
    kl_loss = kl_loss * (temperature ** 2)

    # Cross-entropy classique avec les hard labels
    ce_loss = nn.CrossEntropyLoss()(student_logits, hard_labels)

    # Combinaison ponderee
    total_loss = alpha * kl_loss + (1 - alpha) * ce_loss
    return total_loss

# ============================================================
# 5. Boucle d'Entraînement
# ============================================================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
teacher = teacher.to(device)
student = student.to(device)

optimizer = optim.Adam(student.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

num_epochs = 50
best_accuracy = 0.0

for epoch in range(num_epochs):
    # ----- Phase d'entrainement -----
    student.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Inference du teacher (pas de gradient necessaire)
        with torch.no_grad():
            teacher_logits = teacher(images)

        # Forward pass du student
        student_logits = student(images)

        # Calcul de la loss de distillation
        loss = distillation_loss(
            student_logits, teacher_logits, labels,
            temperature=5.0, alpha=0.7
        )

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = student_logits.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    train_acc = 100.0 * correct / total

    # ----- Evaluation -----
    student.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = student(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    test_acc = 100.0 * correct / total
    scheduler.step()

    if test_acc > best_accuracy:
        best_accuracy = test_acc
        torch.save(student.state_dict(), 'best_student_distilled.pth')

    if (epoch + 1) % 10 == 0:
        print(f'Epoch {epoch+1}/{num_epochs} - '
              f'Loss: {running_loss/len(train_loader):.4f} - '
              f'Train acc: {train_acc:.2f}% - '
              f'Test acc: {test_acc:.2f}% - '
              f'Best: {best_accuracy:.2f}%')

print(f'\nEntrainement termine! Meilleure precision: {best_accuracy:.2f}%')

Comparaison des Résultats Attendus

Un étudiant typique entraîné par distillation sur CIFAR-10 avec un ResNet18 comme teacher atteint fréquemment des précisions de l’ordre de 78 à 82 %, contre 72 à 76 % pour un modèle de même architecture entraîné directement sur les hard labels. Cet écart de plusieurs points de pourcentage illustre concrètement l’avantage informationnel procuré par les soft labels.

Hyperparamètres Clés

Le succès de la Knowledge Distillation dépend fortement du réglage minutieux de plusieurs hyperparamètres essentiels :

Température (T)

La température est de loin le paramètre le plus influent. Elle contrôle le degré d’adoucissement des distributions de probabilité :

  • T = 1 : équivalent à la softmax standard, aucune distillation effective ;
  • T = 3 à 5 : valeurs typiquement recommandées pour la plupart des tâches de classification d’images ;
  • T = 10 à 20 : utiles lorsque les classes sont très nombreuses (ImageNet avec 1000 classes) ;
  • T > 20 : risque de rendre la distribution trop uniforme, perdant ainsi toute information discriminante.

Coefficient Alpha (α)

Le coefficient α détermine l’équilibre entre l’apprentissage par distillation et l’apprentissage supervisé classique :

  • α = 0,7 à 0,9 : accorde la priorité au transfert de connaissances du teacher ;
  • α = 0,5 : équilibre parfait entre les deux sources de supervision ;
  • α < 0,5 : approche conservatrice, utile quand le teacher n’est pas parfaitement fiable.

Architecture du Student

Le choix de l’architecture étudiante est crucial et dépend directement des contraintes de déploiement :

  • Petit CNN : idéal pour les appareils embarqués et les applications mobiles ;
  • Réseau shallow (peu de couches) : adapté aux environnements à faible puissance de calcul ;
  • MobileNet ou EfficientNet réduit : excellent compromis entre précision et efficacité ;
  • Réduction du nombre de filtres : diviser par 2 ou 4 le nombre de canaux chaque couche constitue une stratégie efficace.

Architecture du Teacher

Un teacher plus performant et plus vaste génère des soft labels de meilleure qualité :

  • ResNet50, ResNet101 : enseignants robustes pour la classification d’images ;
  • EfficientNet-B7 : excellent rapport précision-capacité ;
  • Ensemble de modèles : la moyenne des prédictions de plusieurs teachers produit des soft labels encore plus riches et plus stables.

Avantages et Limites de la Knowledge Distillation

Avantages

  1. Compression spectaculaire : réduction de la taille du modèle d’un facteur 10 à 50 tout en conservant l’essentiel des performances originales.
  2. Inférence accélérée : le student, nettement plus compact, exécute les prédictions de manière significativement plus rapide, avec une consommation énergétique réduite.
  3. Généralisation améliorée : le student distillé surpasse souvent un modèle entraîné directement sur les données, grâce à la richesse informationnelle des soft labels.
  4. Compatibilité universelle : fonctionne avec pratiquement toutes les architectures de réseaux de neurones, de la classification d’images au traitement du langage naturel.
  5. Déploiement facilité : un modèle compressé est beaucoup plus simple à déployer sur des appareils aux ressources limitées.

Limites

  1. Coût initial élevé : l’entraînement et le stockage du teacher représentent un investissement computationnel conséquent.
  2. Dépendance au teacher : la qualité du student est intrinsèquement liée à celle du teacher. Un teacher médiocre produira inévitablement un student tout aussi médiocre.
  3. Réglage délicat des hyperparamètres : la température et le coefficient α nécessitent une validation croisée rigoureuse pour chaque tâche spécifique.
  4. Architecture contrainte : le student doit être capable de traiter les mêmes types d’entrées que le teacher, ce qui limite certaines optimisations architecturales.
  5. Surcharge computationnelle à l’entraînement : chaque étape d’entraînement nécessite un forward pass complet du teacher, doublant pratiquement le coût computationnel.

4 Cas d’Usage Concrets de la Knowledge Distillation

1. Déploiement Mobile et Embarqué

Les applications mobiles exigent des modèles légers, rapides et économes en énergie. Un modèle de vision par ordinateur de 200 Mo est totalement inutilisable sur un smartphone. Grâce à la Knowledge Distillation, on compresse ce modèle en un student de 5 à 10 Mo, parfaitement adapté au déploiement sur appareil mobile, tout en conservant une précision très proche du modèle original.

2. Traitement du Langage Naturel (NLP)

Les grands modèles de langage contemporains comme BERT Large (340 millions de paramètres) ou GPT-3 (175 milliards de paramètres) sont impossibles à déployer en production directe à grande échelle. La Knowledge Distillation permet de créer des versions compactes — comme DistilBERT (66 millions de paramètres, soit une réduction de 40 %, tout en conservant 97 % des performances de BERT) — qui tournent efficacement sur des serveurs standards avec une latence minimale.

3. Détection d’Anomalies en Temps Réel

Dans les environnements industriels, la détection d’anomalies doit s’exécuter en temps réel sur du matériel embarqué aux ressources extrêmement limitées. Un modèle teacher massivement entraîné sur des données historiques produit des connaissances riches sur les patterns normaux et anormaux. La distillation permet de transférer cette expertise vers un petit modèle capable de fonctionner en temps réel directement sur les capteurs industriels.

4. Systèmes de Recommandation à Grande Échelle

Les plateformes de recommandation modernes (e-commerce, streaming vidéo, réseaux sociaux) traitent des milliards d’interactions quotidiennes. Un teacher profond, entraîné sur l’ensemble des données historiques, capture des patterns complexes de préférences utilisateur. La Knowledge Distillation permet de créer des modèles students légers capables de scoring en temps réel avec une latence inférieure à la milliseconde, tout en préservant la qualité des recommandations.

Voir Aussi


Laisser un commentaire

Votre adresse e-mail ne sera pas publiée. Les champs obligatoires sont indiqués avec *

Ce site utilise Akismet pour réduire les indésirables. En savoir plus sur la façon dont les données de vos commentaires sont traitées.