Consistency Models : Guide complet — Génération en un Seul Pas
Résumé — Les Consistency Models, introduits par Song, Dhariwal et Chen (OpenAI, 2023), sont une approche de génération ultra-rapide qui apprend une fonction capable de mapper n’importe quel point d’une trajectoire ODE de diffusion directement vers le point final (l’échantillon généré) en un seul passage. Contrairement aux modèles de diffusion qui nécessitent 50 à 1000 étapes d’intégration, les Consistency Models peuvent générer en 1 step (instantanément) ou en quelques steps pour une qualité supérieure. C’est une forme de distillation : on entraîne le modèle à être cohérent avec lui-même le long des trajectoires de diffusion.
Principe mathématique
1. La fonction de consistance
Un Consistency Model apprend une fonction fθ(x_t, t) qui, pour n’importe quel point x_t sur une trajectoire ODE de diffusion (quelle que soit l’étape t), retourne toujours le même résultat final x_0. C’est la contrainte de consistance :
fθ(x_t, t) = fθ(x_{t'}, t') pour tous t, t' sur la même trajectoire
En particulier, fθ(x_T, T) = x_0 où T est le temps final (bruit pur) et x_0 est l’échantillon généré.
2. Entraînement par consistance
La clé de l’entraînement est de forcer la consistance entre deux points adjacents sur la même trajectoire ODE. La loss de consistance est :
L(θ) = E[||fθ(x_{t_{n+1}}, t_{n+1}) - fθ^{µ}(x_{t_n}, t_n)||²]
Où fθ^{µ} est une moyenne exponentielle des paramètres (EMA) qui sert de cible stable. On discrétise la trajectoire ODE en N temps t_1 < t_2 < … < t_N et on entraîne le modèle à être cohérent entre temps adjacents.
3. Paramétrisation
Pour satisfaire la condition de bord (f(x_0, 0) = x_0), on paramétrise :
f_θ(x_t, t) = c_{skip}(t) · x_t + c_{out}(t) · F_θ(x_t, t)
Où c_{skip}(0) = 1 et c_{out}(0) = 0 assurent que f_θ(x_0, 0) = x_0 automatiquement.
4. Génération
La génération est triviale :
– 1 step : fθ(x_T, T) où x_T ~ N(0,I) — instantané
– Few steps : on peut itérer pour améliorer la qualité
5. Distillation de modèles de diffusion
Une approche alternative est la distillation directe d’un modèle de diffusion pré-entraîné : on utilise le modèle de diffusion pour générer des trajectoires et on entraîne le consistency model à reproduire le résultat final depuis n’importe quel point intermédiaire. Cette approche donne de meilleurs résultats car elle bénéficie de la qualité du modèle de diffusion source.
Intuition
La descente de gradient dans un modèle de diffusion, c’est comme descendre une montagne de 1000 marches, une par une. C’est lent mais ça marche. Un Consistency Model, c’est comme avoir un téléphérique : peu importe où tu es sur la montagne (en haut dans le brouillard ou à mi-pente), le téléphérique te ramène toujours au même endroit en bas — en un seul voyage.
Le réseau apprend une fonction qui dit « peu importe l’étape de débruitage où tu te trouves, voici à quoi ressemble l’image finale ». En échangeant un tout petit peu de qualité contre un gain de vitesse énorme (1000 steps → 1 step), on obtient une génération quasi instantanée.
C’est comme la différence entre un GPS qui te dit « tourne à droite dans 200m, puis à gauche dans 500m » (diffusion, step par step) et un pilote automatique qui dit « la destination est là, allons-y directement » (consistency model, 1 step).
Implémentation Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class ConsistencyModel(nn.Module):
def __init__(self, data_dim=2, hidden_dim=256):
super().__init__()
self.data_dim = data_dim
self.register_buffer('ema_params', None)
# Feature extraction
self.t_embed = nn.Sequential(
nn.Linear(1, hidden_dim), nn.SiLU(),
nn.Linear(hidden_dim, hidden_dim), nn.SiLU()
)
self.x_embed = nn.Linear(data_dim, hidden_dim)
# Backbone
self.net = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim), nn.SiLU(),
nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),
nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),
nn.Linear(hidden_dim, data_dim)
)
def c_skip(self, t):
return torch.ones_like(t)
def c_skip_at_zero(self, t):
"""Returns c_skip(t) with c_skip(0) = 1"""
sigma = t
return 0.25 / (sigma + 0.25)
def c_out(self, t):
"""Returns c_out(t) with c_out(0) = 0"""
sigma = t
return 0.25 * sigma / (sigma + 0.25)
def forward(self, x, t):
t = t.view(-1, 1)
c_skip = self.c_skip_at_zero(t)
c_out = self.c_out(t)
F_theta = self.net(
torch.cat([self.x_embed(x), self.t_embed(t)], dim=1)
)
return c_skip * x + c_out * F_theta
class ConsistencyTrainer:
def __init__(self, model, lr=1e-3, timesteps=40):
self.model = model
self.ema_model = None
self.lr = lr
self.timesteps = timesteps
self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
self.ts = torch.linspace(0.002, 1.0, timesteps)
def consistency_loss(self, x_0):
"""Loss de consistance entre temps adjacents"""
batch = x_0.size(0)
# Choisir un temps aléatoire
n = torch.randint(0, self.timesteps - 1, (batch,))
t_n = self.ts[n]
t_np1 = self.ts[n + 1]
# Ajouter du bruit pour simuler la forward process
sigma_n = t_n.view(-1, 1)
sigma_np1 = t_np1.view(-1, 1)
z = torch.randn_like(x_0)
x_tn = x_0 + sigma_n * z
x_tnp1 = x_0 + sigma_np1 * z
# Prédictions
f_n = self.model(x_tn, t_n)
f_np1 = self.model(x_tnp1, t_np1)
# Loss de consistance : les deux doivent être égaux
loss = F.mse_loss(f_n, f_np1)
return loss
def train_step(self, x_0):
self.optimizer.zero_grad()
loss = self.consistency_loss(x_0)
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.optimizer.step()
# EMA update
if self.ema_model is None:
self.ema_model = self.deepcopy_model(self.model)
else:
self.update_ema(0.995)
return loss.item()
@torch.no_grad()
def sample(self, n, steps=1):
"""Generation en 1 step ou few steps"""
x = torch.randn(n, self.model.data_dim)
t = torch.ones(n)
x_gen = self.model(x, t)
return x_gen
def deepcopy_model(self, src):
dst = type(src)()
dst.load_state_dict(src.state_dict())
return dst
def update_ema(self, decay):
for p, ema_p in zip(self.model.parameters(),
self.ema_model.parameters()):
ema_p.data.mul_(decay).add_(p.data, alpha=1 - decay)
# Entraînement sur des données 2D (spirales)
def make_spiral(n=5000):
r = torch.rand(n) * 2
t = torch.rand(n) * 2 * math.pi
x = r * torch.cos(t + r)
y = r * torch.sin(t + r)
return torch.stack([x, y], dim=1)
data = make_spiral()
model = ConsistencyModel(data_dim=2, hidden_dim=256)
trainer = ConsistencyTrainer(model, lr=1e-3, timesteps=40)
for epoch in range(500):
idx = torch.randint(0, data.size(0), (512,))
batch = data[idx]
loss = trainer.train_step(batch)
if epoch % 50 == 0:
samples = trainer.sample(100)
print(f'Epoch {epoch} | Loss: {loss:.4f}')
# Generation instantanee
samples_1step = trainer.sample(100, steps=1)
print(f'Generated {samples_1step.size(0)} samples in 1 step')
Hyperparamètres
| Hyperparamètre | Valeur typique | Description |
|---|---|---|
| timesteps | 18-404 | Nombre de temps discrétisés (plus = qualité mais plus lent) |
| ema_decay | 0.995 | Taux de décroissance EMA pour la cible stable |
| lr | 1e-3 | Learning rate Adam |
| sigma_min | 0.002 | Bruit minimum (condition de bord à x_0) |
Avantages
- Génération ultra-rapide : 1 step contre 50-1000 pour la diffusion conventionnelle, soit un accélération de 50x à 1000x.
- Qualité préservée : Avec les Consistency Models avancés (CD), la qualité atteint 95% de celle de la diffusion pour une fraction du coût.
- Échelle flexible : On peut choisir entre 1 step (rapide) et quelques steps (équilibré) selon le besoin.
- Compatibilité : Peut être entraîné par distillation à partir de n’importe quel modèle de diffusion existant.
- Pas de planification de bruit : Pas besoin de noise schedule sophistiqué contrairement à la diffusion.
Limites
- Entraînement instable : La loss de consistance entre points adjacents peut être difficile à optimiser.
- Qualité inférieure en 1 step : La génération en un seul pas perd des détails fins par rapport à la diffusion multi-steps.
- Besoin d’un modèle source : La meilleure approche nécessite un modèle de diffusion pré-entraîné coûteux à produire.
- Recherche jeune : Moins mature et testée que les modèles de diffusion ou Flow Matching.
4 cas d’usage concrets
1. Génération d’images temps réel
Pour les applications interactives comme les éditeurs d’images IA, la génération doit être instantanée. Les Consistency Models permettent de générer des images en 1-4 steps au lieu des 50-100 steps habituels, rendant possible l’édition en temps réel.
2. Synthèse vocale faible latence
Dans les assistants vocaux, la latence est cruciale. Les CMs réduisent le temps de génération audio de plusieurs secondes à quelques millisecondes, rendant la conversation plus naturelle.
3. Design de molécules haute throughput
En chimie computationnelle, il faut générer et évaluer des millions de molécules candidates. Les CMs permettent une génération ultra-rapide de structures moléculaires 3D valides.
4. Data augmentation en ligne
Les Consistency Models peuvent générer des données synthétiques à la volée pendant l’entraînement d’autres modèles, sans le coût prohibitif de la diffusion conventionnelle.
Voir aussi
- Structures de données avancées en Python : Dictionnaires et Ensembles
- Maîtriser l’Échange de Compteurs en Python : Guide Complet et Astuces

