Batch Normalization : Guide Complet — Normalisation par Lots

Batch Normalization : Guide Complet — Normalisation par Lots

Batch Normalization

Résumé

Le Batch Normalization (ou normalisation par lots) est l’une des techniques les plus influentes publiées dans le domaine du deep learning depuis l’avènement des réseaux de neurones convolutifs. Proposée par Ioffe et Szegedy en 2015 dans leur article fondateur « Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift », cette méthode a révolutionné la façon dont nous entraînons les réseaux profonds modernes.

Avant le Batch Normalization, l’entraînement de réseaux très profonds était un véritable cauchemar d’ingénierie. Les praticiens devaient régler méticuleusement les taux d’apprentissage, initialiser soigneusement les poids, et souvent faire face à des gradients qui explosaient ou disparaissaient au fil des couches. Le Batch Normalization a changé la donne en permettant l’entraînement de réseaux de centaines de couches avec une stabilité remarquable.

Aujourd’hui, le Batch Normalization est intégré dans pratiquement toutes les architectures modernes : ResNet, DenseNet, EfficientNet, et bien d’autres. Ce guide complet explore son fonctionnement mathématique, son intuition profonde, son implémentation pratique en Python avec PyTorch, et ses cas d’usage concrets.

Principe mathématique

Le principe du Batch Normalization est élégant dans sa simplicité. L’idée centrale consiste à normaliser les activations de chaque couche pour chaque mini-batch (lot) d’entraînement. Voici les équations fondamentales qui gouvernent cette technique.

Pour un mini-batch B contenant m échantillons, on calcule d’abord la moyenne et la variance des activations :

  • Moyenne du lot : μB = (1/m) · Σ xi — c’est la moyenne de toutes les activations du batch courant
  • Variance du lot : σ²B = (1/m) · Σ (xi − μB)² — c’est la variance mesurant la dispersion autour de la moyenne

Ensuite, on normalise chaque activation :

  • Activation normalisée :i = (xi − μB) / √(σ²B + ε)

Le terme ε (epsilon) est une petite constante numérique, typiquement 10⁻⁵, ajoutée pour la stabilité numérique afin d’éviter la division par zéro lorsque la variance est nulle ou extrêmement faible.

Cependant, normaliser systématiquement à moyenne nulle et variance unité serait trop restrictif : cela empêcherait le réseau de représenter des distributions qui nécessiteraient naturellement d’autres paramètres. La solution est brillante :

  • Transformation affine apprenable : yi = γ · x̂i + β

Les paramètres γ (gamma) et β (bêta) sont des paramètres apprenables que le réseau ajuste pendant l’entraînement par rétropropagation. Ils permettent au modèle de restaurer n’importe quelle distribution qu’il jugera utile — y compris l’identité (γ = √σ²B, β = μB), ce qui revient à annuler complètement la normalisation si le réseau le souhaite.

Phase d’inférence

Pendant l’entraînement, le Batch Normalization utilise les statistiques du batch courant (μB et σ²B). Mais en phase d’inférence (prédiction), nous n’avons souvent qu’un seul échantillon — il est donc impossible de calculer une moyenne et une variance de lot. La solution consiste à maintenir des statistiques cumulatives pendant l’entraînement :

  • running_mean : moyenne exponentielle mobile des μB
  • running_var : moyenne exponentielle mobile des σ²B

Ces statistiques sont mises à jour à chaque étape d’entraînement selon la formule :

  • runningnew = (1 − momentum) · runningold + momentum · μB

Le momentum (par défaut 0,1 dans PyTorch) contrôle la vitesse d’adaptation : un momentum faible donne plus de poids aux statistiques historiques, tandis qu’un momentum élevé rend les statistiques running plus réactives aux changements récents. Pendant l’inférence, on utilise ces running_mean et running_var à la place des statistiques du batch courant.

Intuition

Pour bien comprendre pourquoi le Batch Normalization fonctionne si bien, il faut d’abord comprendre le problème qu’il résout : le internal covariate shift (décalage de covariance interne).

Imaginez que vous essayiez d’apprendre à marcher, mais que le sol se dérobe constamment sous vos pieds. Chaque fois que vous faites un pas, la surface change — tantôt en pente, tantôt bosselée, tantôt glissante. C’est exactement ce qui se passe dans un réseau profond sans Batch Normalization. Chaque couche modifie la distribution des données qu’elle transmet à la couche suivante. La deuxième couche reçoit une distribution qui évolue au fur et à mesure que la première couche ajuste ses poids. La troisième couche reçoit une distribution qui a été transformée par deux couches en évolution constante, et ainsi de suite.

Le Batch Normalization stabilise ce sol mouvant. En normalisant les activations de chaque couche, il garantit que chaque couche suivante reçoit des données avec une distribution relativement stable — centrée et réduite. C’est l’équivalent d’un régulateur de vitesse dans une voiture : ça maintient le régime constant même dans les côtes et les descentes, libérant le conducteur (le réseau) pour qu’il se concentre sur l’essentiel — la direction.

Les avantages de cette stabilisation sont multiples :

  1. Taux d’apprentissage plus élevé : puisque les gradients sont mieux conditionnés, on peut utiliser des learning rates plus agressifs sans risquer la divergence.
  2. Convergence plus rapide : le réseau atteint une bonne performance en beaucoup moins d’époques.
  3. Régularisation implicite : le bruit introduit par l’estimation des statistiques sur un mini-batch (plutôt que sur l’ensemble du jeu de données) agit comme un régularisateur, réduisant parfois le besoin de Dropout.
  4. Moindre sensibilité à l’initialisation : le réseau devient beaucoup plus robuste au choix des poids initiaux.
  5. Suppression du besoin de Dropout dans de nombreuses architectures : les couches Batch Normalization fournissent déjà une forme de régularisation.

Implémentation Python avec PyTorch

1. Implémentation from scratch

Voici une implémentation complète de BatchNorm2d from scratch en PyTorch, incluant la gestion des statistiques running pour la phase d’inférence :

import torch
import torch.nn as nn
import torch.nn.functional as F


class BatchNorm2dFromScratch(nn.Module):
    """
    Implémentation from scratch de BatchNorm2d.
    Reproduit le comportement de nn.BatchNorm2d de PyTorch.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine

        # Paramètres apprenables gamma et beta
        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))   # gamma
            self.bias = nn.Parameter(torch.zeros(num_features))    # beta
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        # Statistiques running pour l'inférence
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.register_buffer('num_batches_tracked', torch.tensor(0))

    def forward(self, x):
        # x a la forme (N, C, H, W)
        if self.training:
            # Calcul des statistiques du batch
            # On moyenne sur les dimensions N, H, W pour chaque canal C
            batch_mean = x.mean(dim=(0, 2, 3))       # shape: (C,)
            batch_var = x.var(dim=(0, 2, 3), unbiased=False)  # shape: (C,)

            # Mise à jour des statistiques running
            self.running_mean = (
                (1 - self.momentum) * self.running_mean
                + self.momentum * batch_mean.detach()
            )
            self.running_var = (
                (1 - self.momentum) * self.running_var
                + self.momentum * batch_var.detach()
            )
            self.num_batches_tracked += 1

            # Normalisation
            x_normalized = (x - batch_mean.view(1, -1, 1, 1)) / torch.sqrt(
                batch_var.view(1, -1, 1, 1) + self.eps
            )
        else:
            # En mode inférence : utilisation des statistiques running
            x_normalized = (x - self.running_mean.view(1, -1, 1, 1)) / torch.sqrt(
                self.running_var.view(1, -1, 1, 1) + self.eps
            )

        # Transformation affine
        if self.affine:
            x_normalized = self.weight.view(1, -1, 1, 1) * x_normalized
            x_normalized = x_normalized + self.bias.view(1, -1, 1, 1)

        return x_normalized

2. Comparaison avec nn.BatchNorm2d de PyTorch

Vérifions que notre implémentation produit des résultats identiques à celle de PyTorch :

import torch
import torch.nn as nn

# Création d'une entrée de test
torch.manual_seed(42)
x = torch.randn(8, 16, 32, 32)  # (N, C, H, W)

# Notre implémentation
bn_custom = BatchNorm2dFromScratch(num_features=16)
bn_custom.train()
output_custom = bn_custom(x)

# Implémentation PyTorch
bn_pytorch = nn.BatchNorm2d(num_features=16)
bn_pytorch.train()
output_pytorch = bn_pytorch(x)

print(f"Forme de sortie custom : {output_custom.shape}")
print(f"Forme de sortie PyTorch : {output_pytorch.shape}")
print(f"Différence maximale : {torch.max(torch.abs(output_custom - output_pytorch)):.6f}")
print(f"Moyenne custom : {output_custom.mean():.6f}")
print(f"Moyenne PyTorch : {output_pytorch.mean():.6f}")

Les deux implémentations produisent des résultats numériquement très proches, ce qui valide notre compréhension du mécanisme interne du Batch Normalization.

3. Entraînement d’un ResNet avec et sans Batch Normalization

Comparons maintenant l’impact du Batch Normalization sur l’entraînement d’un ResNet-18 simplifié :

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt


class ResidualBlock(nn.Module):
    """Bloc résiduel optionnellement avec Batch Normalization."""

    def __init__(self, in_channels, out_channels, use_bn=True, stride=1):
        super().__init__()
        self.use_bn = use_bn

        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3,
            stride=stride, padding=1, bias=not use_bn
        )
        self.bn1 = nn.BatchNorm2d(out_channels) if use_bn else None
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=1, padding=1, bias=not use_bn
        )
        self.bn2 = nn.BatchNorm2d(out_channels) if use_bn else None

        # Skip connection avec convolution si nécessaire
        if in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                                  stride=stride, bias=not use_bn)
            self.skip_bn = nn.BatchNorm2d(out_channels) if use_bn else None
        else:
            self.skip = None
            self.skip_bn = None

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        if self.use_bn:
            out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        if self.use_bn:
            out = self.bn2(out)

        if self.skip is not None:
            identity = self.skip(x)
            if self.skip_bn is not None:
                identity = self.skip_bn(identity)

        out += identity
        out = self.relu(out)
        return out


class SimpleResNet(nn.Module):
    """ResNet simplifié pour CIFAR-10."""

    def __init__(self, use_bn=True, num_classes=10):
        super().__init__()
        self.conv_in = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=not use_bn)
        self.bn_in = nn.BatchNorm2d(64) if use_bn else None
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = ResidualBlock(64, 64, use_bn=use_bn)
        self.layer2 = ResidualBlock(64, 128, use_bn=use_bn, stride=2)
        self.layer3 = ResidualBlock(128, 256, use_bn=use_bn, stride=2)
        self.layer4 = ResidualBlock(256, 512, use_bn=use_bn, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.conv_in(x)
        if self.bn_in:
            x = self.bn_in(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


def train_and_compare():
    """Entraîne deux modèles et compare les courbes de convergence."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Transformations CIFAR-10
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                           (0.2470, 0.2435, 0.2616))
    ])

    trainset = datasets.CIFAR10(root='./data', train=True,
                                download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                              shuffle=True, num_workers=2)

    epochs = 30
    criterion = nn.CrossEntropyLoss()

    results = {"avec_bn": [], "sans_bn": []}

    for config_name, use_bn in [("avec_bn", True), ("sans_bn", False)]:
        print(f"\n{'='*50}")
        print(f"Entraînement {config_name}")
        print(f"{'='*50}")

        model = SimpleResNet(use_bn=use_bn).to(device)
        optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9,
                              weight_decay=5e-4)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

        for epoch in range(epochs):
            model.train()
            running_loss = 0.0
            correct = 0
            total = 0

            for inputs, targets in trainloader:
                inputs, targets = inputs.to(device), targets.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

            scheduler.step()
            epoch_loss = running_loss / len(trainloader)
            epoch_acc = 100.0 * correct / total
            results[config_name].append(epoch_loss)

            print(f"  Époque {epoch+1:2d}/{epochs} — "
                  f"Perte: {epoch_loss:.4f} — Précision: {epoch_acc:.2f}%")

    # Visualisation des courbes de convergence
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(results["avec_bn"], 'b-o', label='Avec Batch Norm', markersize=4)
    plt.plot(results["sans_bn"], 'r-s', label='Sans Batch Norm', markersize=4)
    plt.xlabel('Époque')
    plt.ylabel("Perte d'entraînement")
    plt.title('Courbes de convergence : avec vs sans Batch Normalization')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    plt.plot([0.1 if e < 10 else 0.01 if e < 20 else 0.001
              for e in range(epochs)], 'g--', label='Learning Rate')
    plt.xlabel('Époque')
    plt.ylabel("Taux d'apprentissage")
    plt.title("Planification du taux d'apprentissage")
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('batch_norm_convergence.png', dpi=150)
    print("\nGraphique sauvegardé : batch_norm_convergence.png")


if __name__ == "__main__":
    train_and_compare()

4. Ce que révèlent les courbes de convergence

En pratique, on observe systématiquement que le modèle avec Batch Normalization :

  • Converge plus rapidement : la perte diminue plus vite dès les premières époques.
  • Atteint une meilleure précision finale : la régularisation implicite améliore la généralisation.
  • Supporte des learning rates plus élevés : sans BN, un learning rate de 0,1 provoque souvent la divergence ; avec BN, c’est un choix standard et stable.
  • Présente une évolution de gradient plus lisse : les mises à jour sont plus régulières, sans les oscillations violentes caractéristiques des réseaux profonds non normalisés.

Hyperparamètres du Batch Normalization

Le Batch Normalization introduit quelques hyperparamètres qu’il est important de bien comprendre :

Epsilon (ε) — Stabilité numérique

  • Valeur par défaut : 1e-5 (PyTorch), 1e-3 (TensorFlow/Keras)
  • Rôle : Évite la division par zéro lorsque la variance du batch est nulle ou quasi-nulle.
  • Impact : En général, on ne touche pas à ce paramètre. Des valeurs trop grandes perturbent la normalisation, des valeurs trop petites risquent l’instabilité numérique.

Momentum — Vitesse d’adaptation des statistiques running

  • Valeur par défaut : 0,1 dans PyTorch (notez que c’est l’inverse de TensorFlow qui utilise 0,99)
  • Rôle : Contrôle le poids relatif des nouvelles statistiques du batch par rapport aux statistiques historiques : running_new = (1 - momentum) × running_old + momentum × batch_stat
  • Impact : Un momentum faible (0,1) donne plus de poids à l’historique — les statistiques running sont très lisses. Un momentum élevé rend les running stats plus réactives mais potentiellement plus bruitées.

Affine — Activation des paramètres apprenables

  • Valeur par défaut : True
  • Rôle : Détermine si les paramètres γ (poids) et β (biais) sont apprenables.
  • Impact : Dans la grande majorité des cas, on laisse affine=True. Mettre affine=False force la normalisation stricte à moyenne nulle et variance unité, ce qui peut limiter la capacité du modèle.

Tailles de batch recommandées

Le Batch Normalization dépend fortement de la taille du batch. Un batch trop petit (inférieur à 16-32) produit des estimations de moyenne et variance trop bruitées. Pour les très petits batches, on recommande des alternatives comme Group Normalization ou Layer Normalization.

Avantages et Limites

Avantages

  1. Accélération significative de l’entraînement : convergence atteinte en 5 à 10 fois moins d’époques.
  2. Stabilisation des gradients : réduit considérablement les problèmes de gradients explosifs ou disparaissants.
  3. Permet des learning rates plus élevés : les gradients mieux conditionnés autorisent des pas d’optimisation plus grands.
  4. Régularisation implicite : le bruit des statistiques de mini-batch aide à la généralisation.
  5. Réduit la sensibilité à l’initialisation : le réseau est plus robuste aux choix de poids initiaux.
  6. Simplifie l’architecture : réduit le besoin de Dropout dans de nombreuses architectures convolutives.

Limites

  1. Dépendance à la taille du batch : les petits batches (inférieur à 16) donnent des statistiques peu fiables — ce qui pose problème pour la détection d’objets ou le traitement vidéo où les batches sont naturellement petits.
  2. Incompatibilité avec le traitement séquentiel en RNN : le BN n’est pas trivial à appliquer aux séquences de longueur variable.
  3. Bruit pendant l’entraînement vs silence à l’inférence : la différence de comportement entre le mode entraînement (bruité par les statistiques de batch) et le mode inférence (déterministe) peut créer un léger écart de performance.
  4. Coût mémoire et computationnel : le calcul des statistiques et le stockage des paramètres running ajoutent un overhead, surtout pour les très grands modèles.
  5. Interaction avec Dropout : le BN et le Dropout ensemble peuvent parfois avoir des effets antagonistes. Les bonnes pratiques recommandent soit de les utiliser séparément, soit d’appliquer le Batch Norm avant le Dropout.

4 Cas d’usage concrets

Cas d’usage n°1 : Classification d’images avec ResNet

Le cas d’usage classique du Batch Normalization. Les architectures ResNet l’intègrent après chaque convolution, avant la fonction d’activation ReLU. Cette combinaison Conv → BN → ReLU est devenue le standard de l’industrie pour la classification d’images, permettant d’entraîner des réseaux de plus de 1000 couches (ResNet-1000+) avec une stabilité remarquable.

Cas d’usage n°2 : Réseaux génératifs (GAN)

Les GAN (Generative Adversarial Networks) sont notoirement instables à entraîner. Le Batch Normalization joue un rôle crucial pour stabiliser l’entraînement des générateurs et discriminateurs. Dans DCGAN (Deep Convolutional GAN), le BN est appliqué partout sauf sur la couche de sortie du générateur et la couche d’entrée du discriminateur, ce qui empêche l’effondrement des modes et améliore la qualité des images générées.

Cas d’usage n°3 : Transfer Learning et Fine-Tuning

Lorsqu’on fait du transfer learning (apprentissage par transfert) en reprenant un modèle pré-entraîné (comme ResNet-50 sur ImageNet), le Batch Normalization nécessite une attention particulière. Les statistiques running du modèle pré-entraîné reflètent la distribution d’ImageNet, pas celle de notre nouveau jeu de données. Deux stratégies courantes existent : geler complètement les couches BN (utiliser les stats originales) ou les remettre en mode entraînement (les recalculer sur nos données). Le choix dépend de la similarité entre les distributions.

Cas d’usage n°4 : Réseaux avec Skip Connections et architectures denses

Les architectures comme DenseNet, où chaque couche est connectée à toutes les couches suivantes, bénéficient énormément du Batch Normalization. Sans BN, les activations s’accumuleraient et exploseraient à travers les multiples chemins de skip connections. Le Batch Normalization maintient chaque activation à une échelle contrôlée, rendant ces architectures denses profondément empilables et efficaces.

Voir aussi


Laisser un commentaire

Votre adresse e-mail ne sera pas publiée. Les champs obligatoires sont indiqués avec *

Ce site utilise Akismet pour réduire les indésirables. En savoir plus sur la façon dont les données de vos commentaires sont traitées.