Gumbel-Softmax : Guide complet — Échantillonnage Différentiable Discret
Résumé — Le Gumbel-Softmax est une astuce qui permet d’entraîner des modèles de machine learning qui impliquent des échantillonnages discrets via un mécanisme différentiable. Proposé par Jang, Gu et Poole en 2016 et Maddison et al. la même année, il résoud un problème fondamental: on ne peut pas backpropager à travers une opération d’échantillonnage catégoriel. Le Gumbel-Softmax remplace cette opération par une approximation continue et différentiable, permettant l’entraînement de bout en bout de modèles avec des variables latentes discrètes.
Principe mathématique
1. Le problème: échantillonnage non différentiable
Pour échantillonner d’une distribution catégorielle avec des probabilités π = [π₁, π₂, …, π_K], on utilise traditionnellement:
k = argmax_i (log(π_i) + g_i) où g_i ~ Gumbel(0, 1)
Où g_i = -log(-log(u_i)), u_i ~ Uniform(0, 1). C’est le Gumbel-Max trick, qui donne des échantillons exacts. Le problème: argmax n’a pas de gradient.
2. Relaxation par softmax
Au lieu de argmax, on utilise softmax avec un paramètre de température τ (tau) :
y_i = exp((log(π_i) + g_i) / τ) / somme_j exp((log(π_j) + g_j) / τ)
- Quand τ → 0 : y tend vers one-hot (comme argmax)
- Quand τ → ∞ : y tend vers la distribution uniforme
- Pour τ intermédiaire : y est une approximation douce du one-hot
Et le plus important : softmax est différentiable. Les gradients de L par rapport à log(π) passent à travers le softmax:
dL/d(log(π_i)) = somme_j dL/dy_j · dy_j/d(log(π_i))
3. Hard Gumbel-Softmax
Pour obtenir des échantillons one-hot tout en passant les gradients en forward, on utilise le straight-through estimator :
# Forward: one-hot
y_hard = one_hot(argmax(y))
# Backward: gradients de y (soft)
y = y_hard - y.detach() + y
Cela donne des sorties discrètes en forward mode (utile pour la génération de texte) tout en ayant les gradients corrects en backward.
4. Annealing de la température
En pratique, on démarre avec une température élevée (τ = 1.0, distribution douce) et on la réduit progressivement pendant l’entraînement:
τ(t) = max(τ_min, τ_0 · exp(-rate · t))
Cela permet une exploration initiale (softmax doux) puis une convergence vers des décisions discrètes (one-hot).
Intuition
Imaginez que vous devez choisir entre plusieurs chemins dans une forêt, mais vous voulez pouvoir revenir en arrière et explorer d’autres options.
Le Gumbel-Softmax, c’est comme un GPS intelligent qui choisit un chemin mais laisse toujours une petite porte ouverte pour revenir en arrière. La température τ contrôle ce compromis :
- Haute température (τ = 5.0) : le GPS est indécis, il prend tous les chemins à la fois avec des poids similaires. C’est l’exploration maximale.
- Basse température (τ = 0.1) : le GPS est fermement engagé sur un seul chemin. C’est l’exploitation.
En démarrant doux et en durcissant progressivement, l’explorateur (le modèle) explore d’abord toutes les options puis se spécialise sur les meilleures.
Implémentation Python
1. Gumbel-Softmax from scratch
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def sample_gumbel(shape, eps=1e-20):
"""Échantillonner de Gumbel(0, 1)."""
U = torch.rand(shape)
return -torch.log(-torch.log(U + eps) + eps)
def gumbel_softmax(logits, temperature=1.0, hard=False, dim=-1):
"""Échantillonnage Gumbel-Softmax différentiable."""
gumbel_noise = sample_gumbel(logits.shape)
y = logits + gumbel_noise
y_soft = F.softmax(y / temperature, dim=dim)
if hard:
index = y_soft.argmax(dim=dim, keepdim=True)
y_hard = F.one_hot(index, num_classes=y_soft.shape[dim]).float().squeeze(dim)
y = y_hard - y_soft.detach() + y_soft
return y
return y_soft
# Démonstration: échantillonnage catégoriel différentiable
logits = torch.tensor([[2.0, 1.0, 0.1],
[0.5, 3.0, 1.0]], requires_grad=True)
for tau in [5.0, 1.0, 0.5, 0.1]:
sample = gumbel_softmax(logits, temperature=tau, hard=False)
print(f'tau={tau}: {sample[0].detach().numpy().round(3)}')
loss = sample.sum()
loss.backward()
print(f' Gradient on logits[0]: {logits.grad[0].numpy().round(3)}')
logits.grad.zero_()
2. VAE avec variables latentes discrètes via Gumbel-Softmax
class DiscreteVAE(nn.Module):
"""VAE avec espace latent catégoriel via Gumbel-Softmax."""
def __init__(self, input_dim=784, latent_categories=32, latent_dim=10):
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(input_dim, 256), nn.ReLU(),
nn.Linear(256, latent_categories * latent_dim)
)
self.decoder = nn.Sequential(
nn.Linear(latent_categories * latent_dim, 256), nn.ReLU(),
nn.Linear(256, input_dim), nn.Sigmoid()
)
self.latent_categories = latent_categories
self.latent_dim = latent_dim
def encode(self, x):
h = self.encoder(x)
h = h.view(-1, self.latent_dim, self.latent_categories)
return F.log_softmax(h, dim=2)
def sample_latent(self, log_probs, temperature, hard=False):
samples = []
for i in range(self.latent_dim):
log_p = log_probs[:, i, :]
s = gumbel_softmax(log_p, temperature=temperature, hard=hard, dim=1)
samples.append(s)
return torch.stack(samples, dim=1)
def forward(self, x, temperature=1.0, hard=True):
log_probs = self.encode(x)
z = self.sample_latent(log_probs, temperature, hard=hard)
x_recon = self.decode(z.view(z.size(0), -1))
return x_recon, log_probs
# Entraînement avec annealing
model = DiscreteVAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(50):
temperature = max(0.1, 1.0 * np.exp(-0.05 * epoch))
model.train()
for batch in dataloader:
x_recon, log_probs = model(batch, temperature=temperature, hard=True)
recon_loss = F.binary_cross_entropy(x_recon, batch.view(batch.size(0), -1), reduction='sum')
q = F.softmax(log_probs, dim=2)
kl = (q * (log_probs - np.log(1.0 / model.latent_categories))).sum()
loss = recon_loss + kl
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f'Epoch {epoch} | τ: {temperature:.3f}')
Hyperparamètres
| Hyperparamètre | Valeur typique | Description |
|---|---|---|
| temperature_init (τ₀) | 1.0 | Température initiale (distribution douce) |
| temperature_min (τ_min) | 0.1 | Température minimale (approximation du one-hot) |
| temperature_anneal_rate | 0.01-0.1 | Vitesse de réduction de la température |
| hard | True/False | Si True, sorties one-hot en forward (straight-through) |
Avantages
- Différentiable : Permet l’entraînement par gradient descent de modèles avec variables latentes discrètes, impossible auparavant sans REINFORCE.
- Plus stable que REINFORCE : Variance beaucoup plus faible que les estimateurs de score function (REINFORCE/RELUAX).
- Contrôle exploration/exploitation : La température offre un mécanisme naturel pour passer de l’exploration à l’exploitation.
- Simple à implémenter : PyTorch fournit
F.gumbel_softmax()nativement. Quelques lignes suffisent.
Limites
- Biais du gradient : Le gradient est biaisé (surtout avec hard=True). Pour une faible température, le biais est important.
- Sensibilité à la température : Un annealing trop rapide peut bloquer le modèle dans des solutions sous-optimales.
- Qualité vs REINFORCE : Dans certains cas, REINFORCE avec baseline peut donner de meilleurs résultats finals malgré une variance plus élevée.
4 cas d’usage concrets
1. VAE avec espace latent discret
Contrairement aux VAEs classiques avec latents continus (Gaussiens), un VAE avec latents discrets apprend un espace latent structuré en catégories. Chaque dimension latente peut représenter un attribut discret comme “le chiffre est écrit à la main ou tapé” sur MNIST.
2. Génération de texte avec sélection de mots différentiable
Dans les modèles de génération de texte, Gumbel-Softmax permet de sélectionner des mots de manière différentiable pendant l’entraînement. Cela permet d’optimiser directement des métriques non différentiables comme BLEU ou ROUGE.
3. Réseau neuronal avec attention sélective
Plutôt que d’attention soft sur tous les tokens, on peut utiliser Gumbel-Softmax pour sélectionner de manière différentiable un sous-ensemble de tokens. Cela crée des modèles sparse qui ne lisent que les parties pertinentes du texte, économisant du calcul.
4. Quantisation de neurones
Dans la compression de modèles, on peut utiliser Gumbel-Softmax pour apprendre une quantification optimale des poids: au lieu de forcer les poids à des valeurs discrètes, on apprend les centroids de quantification de manière différentiable.
Conclusion
Le Gumbel-Softmax est une contribution fondamentale au deep learning: il a ouvert la porte à l’entraînement de modèles avec des décisions discrètes, un domaine précédemment dominé par des méthodes à haute variance comme REINFORCE.
Bien que le biais du gradient reste un défi, l’annealing progressif de la température et les variantes améliorées ont rendu cette technique suffisamment fiable pour être utilisée dans des applications réelles de génération de texte, de quantification et d’attention sélective.
Voir aussi
- Calculer la Longueur de l’Union de Segments en Python: Guide Complet et Code Optimisé
- Découvrez les Secrets du Pandigital Prime avec Python : Guide Complet et Astuces de Programmation
- #135 VAE — Le VAE classique avec latents continus, précurseur du VAE discret avec Gumbel-Softmax.
- #073 MLP — L’architecture de base utilisée dans les encodeurs/décodeurs.
- #098 Transformer — L’architecture où Gumbel-Softmax peut être utilisé pour l’attention sélective.
Comparaison avec les alternatives
Gumbel-Softmax vs REINFORCE
| Critère | Gumbel-Softmax | REINFORCE (RELUAX) |
|---|---|---|
| Variance du gradient | Faible | Très élevée |
| Biais | Oui (approximation) | Non (estimateur non biaisé) |
| Support one-hot | Oui (straight-through) | Non nativement |
| Stabilité entraînement | Bonne (annealing contrôle) | Délicate (baseline nécessaire) |
| Performance finale | Très bonne sur grands espaces | Parfois meilleure sur petits espaces |
Gumbel-Softmax vs Relaxed Bernoulli
La version binaire du Gumbel-Softmax (deux catégories) est appelée Relaxed Bernoulli:
y = sigmoid((log(π/(1-π)) + g₁ - g₂) / τ)
C’est mathématiquement équivalent au Gumbel-Softmax avec K=2, mais plus efficace car on n’a besoin que d’un seul échantillon Gumbel au lieu de deux.
Extensions et variantes
1. Gumbel-Softmax avec biais corrigé
Pour réduire le biais du gradient, on peut combiner Gumbel-Softmax avec un terme de correction:
∇L_corrected = ∇L_gs + β · (∇L_reinforce - ∇L_gs_baseline)
Où β augmente progressivement pour compenser le biais restant.
2. Categorical Reparameterization avec Gompertz
Une alternative au Gumbel est d’utiliser une relaxation basée sur la distribution de Gompertz, qui offre des gradients plus stables pour certaines architectures.
3. Straight-Through amélioré
Le straight-through estimator standard (y_hard – y_soft.detach() + y_soft) peut être amélioré avec une version scale-aware:
y = scale_factor · (y_hard - y_soft).detach() + y_soft
Où scale_factor est adapté dynamiquement pour compenser la différence de magnitude entre les gradients soft et hard.

