Consistency Models : Guide Complet — Génération en un Seul Pas

Consistency Models : Guide Complet — Génération en un Seul Pas

Consistency Models : Guide complet — Génération en un Seul Pas

Résumé — Les Consistency Models, introduits par Song, Dhariwal et Chen (OpenAI, 2023), sont une approche de génération ultra-rapide qui apprend une fonction capable de mapper n’importe quel point d’une trajectoire ODE de diffusion directement vers le point final (l’échantillon généré) en un seul passage. Contrairement aux modèles de diffusion qui nécessitent 50 à 1000 étapes d’intégration, les Consistency Models peuvent générer en 1 step (instantanément) ou en quelques steps pour une qualité supérieure. C’est une forme de distillation : on entraîne le modèle à être cohérent avec lui-même le long des trajectoires de diffusion.


Principe mathématique

1. La fonction de consistance

Un Consistency Model apprend une fonction fθ(x_t, t) qui, pour n’importe quel point x_t sur une trajectoire ODE de diffusion (quelle que soit l’étape t), retourne toujours le même résultat final x_0. C’est la contrainte de consistance :

fθ(x_t, t) = fθ(x_{t'}, t') pour tous t, t' sur la même trajectoire

En particulier, fθ(x_T, T) = x_0 où T est le temps final (bruit pur) et x_0 est l’échantillon généré.

2. Entraînement par consistance

La clé de l’entraînement est de forcer la consistance entre deux points adjacents sur la même trajectoire ODE. La loss de consistance est :

L(θ) = E[||fθ(x_{t_{n+1}}, t_{n+1}) - fθ^{µ}(x_{t_n}, t_n)||²]

Où fθ^{µ} est une moyenne exponentielle des paramètres (EMA) qui sert de cible stable. On discrétise la trajectoire ODE en N temps t_1 < t_2 < … < t_N et on entraîne le modèle à être cohérent entre temps adjacents.

3. Paramétrisation

Pour satisfaire la condition de bord (f(x_0, 0) = x_0), on paramétrise :

f_θ(x_t, t) = c_{skip}(t) · x_t + c_{out}(t) · F_θ(x_t, t)

Où c_{skip}(0) = 1 et c_{out}(0) = 0 assurent que f_θ(x_0, 0) = x_0 automatiquement.

4. Génération

La génération est triviale :
1 step : fθ(x_T, T) où x_T ~ N(0,I) — instantané
Few steps : on peut itérer pour améliorer la qualité

5. Distillation de modèles de diffusion

Une approche alternative est la distillation directe d’un modèle de diffusion pré-entraîné : on utilise le modèle de diffusion pour générer des trajectoires et on entraîne le consistency model à reproduire le résultat final depuis n’importe quel point intermédiaire. Cette approche donne de meilleurs résultats car elle bénéficie de la qualité du modèle de diffusion source.


Intuition

La descente de gradient dans un modèle de diffusion, c’est comme descendre une montagne de 1000 marches, une par une. C’est lent mais ça marche. Un Consistency Model, c’est comme avoir un téléphérique : peu importe où tu es sur la montagne (en haut dans le brouillard ou à mi-pente), le téléphérique te ramène toujours au même endroit en bas — en un seul voyage.

Le réseau apprend une fonction qui dit « peu importe l’étape de débruitage où tu te trouves, voici à quoi ressemble l’image finale ». En échangeant un tout petit peu de qualité contre un gain de vitesse énorme (1000 steps → 1 step), on obtient une génération quasi instantanée.

C’est comme la différence entre un GPS qui te dit « tourne à droite dans 200m, puis à gauche dans 500m » (diffusion, step par step) et un pilote automatique qui dit « la destination est là, allons-y directement » (consistency model, 1 step).


Implémentation Python

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class ConsistencyModel(nn.Module):
    def __init__(self, data_dim=2, hidden_dim=256):
        super().__init__()
        self.data_dim = data_dim
        self.register_buffer('ema_params', None)

        # Feature extraction
        self.t_embed = nn.Sequential(
            nn.Linear(1, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.SiLU()
        )
        self.x_embed = nn.Linear(data_dim, hidden_dim)

        # Backbone
        self.net = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, data_dim)
        )

    def c_skip(self, t):
        return torch.ones_like(t)

    def c_skip_at_zero(self, t):
        """Returns c_skip(t) with c_skip(0) = 1"""
        sigma = t
        return 0.25 / (sigma + 0.25)

    def c_out(self, t):
        """Returns c_out(t) with c_out(0) = 0"""
        sigma = t
        return 0.25 * sigma / (sigma + 0.25)

    def forward(self, x, t):
        t = t.view(-1, 1)
        c_skip = self.c_skip_at_zero(t)
        c_out = self.c_out(t)
        F_theta = self.net(
            torch.cat([self.x_embed(x), self.t_embed(t)], dim=1)
        )
        return c_skip * x + c_out * F_theta


class ConsistencyTrainer:
    def __init__(self, model, lr=1e-3, timesteps=40):
        self.model = model
        self.ema_model = None
        self.lr = lr
        self.timesteps = timesteps
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        self.ts = torch.linspace(0.002, 1.0, timesteps)

    def consistency_loss(self, x_0):
        """Loss de consistance entre temps adjacents"""
        batch = x_0.size(0)
        # Choisir un temps aléatoire
        n = torch.randint(0, self.timesteps - 1, (batch,))
        t_n = self.ts[n]
        t_np1 = self.ts[n + 1]

        # Ajouter du bruit pour simuler la forward process
        sigma_n = t_n.view(-1, 1)
        sigma_np1 = t_np1.view(-1, 1)

        z = torch.randn_like(x_0)
        x_tn = x_0 + sigma_n * z
        x_tnp1 = x_0 + sigma_np1 * z

        # Prédictions
        f_n = self.model(x_tn, t_n)
        f_np1 = self.model(x_tnp1, t_np1)

        # Loss de consistance : les deux doivent être égaux
        loss = F.mse_loss(f_n, f_np1)
        return loss

    def train_step(self, x_0):
        self.optimizer.zero_grad()
        loss = self.consistency_loss(x_0)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()

        # EMA update
        if self.ema_model is None:
            self.ema_model = self.deepcopy_model(self.model)
        else:
            self.update_ema(0.995)

        return loss.item()

    @torch.no_grad()
    def sample(self, n, steps=1):
        """Generation en 1 step ou few steps"""
        x = torch.randn(n, self.model.data_dim)
        t = torch.ones(n)
        x_gen = self.model(x, t)
        return x_gen

    def deepcopy_model(self, src):
        dst = type(src)()
        dst.load_state_dict(src.state_dict())
        return dst

    def update_ema(self, decay):
        for p, ema_p in zip(self.model.parameters(),
                            self.ema_model.parameters()):
            ema_p.data.mul_(decay).add_(p.data, alpha=1 - decay)


# Entraînement sur des données 2D (spirales)
def make_spiral(n=5000):
    r = torch.rand(n) * 2
    t = torch.rand(n) * 2 * math.pi
    x = r * torch.cos(t + r)
    y = r * torch.sin(t + r)
    return torch.stack([x, y], dim=1)

data = make_spiral()
model = ConsistencyModel(data_dim=2, hidden_dim=256)
trainer = ConsistencyTrainer(model, lr=1e-3, timesteps=40)

for epoch in range(500):
    idx = torch.randint(0, data.size(0), (512,))
    batch = data[idx]
    loss = trainer.train_step(batch)
    if epoch % 50 == 0:
        samples = trainer.sample(100)
        print(f'Epoch {epoch} | Loss: {loss:.4f}')

# Generation instantanee
samples_1step = trainer.sample(100, steps=1)
print(f'Generated {samples_1step.size(0)} samples in 1 step')

Hyperparamètres

Hyperparamètre Valeur typique Description
timesteps 18-404 Nombre de temps discrétisés (plus = qualité mais plus lent)
ema_decay 0.995 Taux de décroissance EMA pour la cible stable
lr 1e-3 Learning rate Adam
sigma_min 0.002 Bruit minimum (condition de bord à x_0)

Avantages

  1. Génération ultra-rapide : 1 step contre 50-1000 pour la diffusion conventionnelle, soit un accélération de 50x à 1000x.
  2. Qualité préservée : Avec les Consistency Models avancés (CD), la qualité atteint 95% de celle de la diffusion pour une fraction du coût.
  3. Échelle flexible : On peut choisir entre 1 step (rapide) et quelques steps (équilibré) selon le besoin.
  4. Compatibilité : Peut être entraîné par distillation à partir de n’importe quel modèle de diffusion existant.
  5. Pas de planification de bruit : Pas besoin de noise schedule sophistiqué contrairement à la diffusion.

Limites

  1. Entraînement instable : La loss de consistance entre points adjacents peut être difficile à optimiser.
  2. Qualité inférieure en 1 step : La génération en un seul pas perd des détails fins par rapport à la diffusion multi-steps.
  3. Besoin d’un modèle source : La meilleure approche nécessite un modèle de diffusion pré-entraîné coûteux à produire.
  4. Recherche jeune : Moins mature et testée que les modèles de diffusion ou Flow Matching.

4 cas d’usage concrets

1. Génération d’images temps réel

Pour les applications interactives comme les éditeurs d’images IA, la génération doit être instantanée. Les Consistency Models permettent de générer des images en 1-4 steps au lieu des 50-100 steps habituels, rendant possible l’édition en temps réel.

2. Synthèse vocale faible latence

Dans les assistants vocaux, la latence est cruciale. Les CMs réduisent le temps de génération audio de plusieurs secondes à quelques millisecondes, rendant la conversation plus naturelle.

3. Design de molécules haute throughput

En chimie computationnelle, il faut générer et évaluer des millions de molécules candidates. Les CMs permettent une génération ultra-rapide de structures moléculaires 3D valides.

4. Data augmentation en ligne

Les Consistency Models peuvent générer des données synthétiques à la volée pendant l’entraînement d’autres modèles, sans le coût prohibitif de la diffusion conventionnelle.


Voir aussi


Laisser un commentaire

Votre adresse e-mail ne sera pas publiée. Les champs obligatoires sont indiqués avec *

Ce site utilise Akismet pour réduire les indésirables. En savoir plus sur la façon dont les données de vos commentaires sont traitées.