TD3 : Guide Complet — Twin Delayed Deep Deterministic Policy Gradient

TD3 : Guide Complet — Twin Delayed Deep Deterministic Policy Gradient

TD3 : Guide complet — Twin Delayed Deep Deterministic Policy Gradient

Résumé — Le TD3 (Twin Delayed Deep Deterministic Policy Gradient) est un algorithme de reinforcement learning proposé par Fujimoto et al. en 2018 comme amélioration majeure du DDPG. TD3 corrige trois failles fondamentales du DDPG qui causaient une surestimation systématique des Q-values : il utilise deux critic networks et prend le minimum pour la target (double Q-learning délayé), il met à jour l’actor moins fréquemment que les critics (delayed policy update, tous les 2 steps), et il ajoute un bruit clipé aux actions cibles (target policy smoothing). Ces trois corrections rendent TD3 la référence pour les problèmes d’action continue, surpassant significativement le DDPG original sur tous les benchmarks MuJoCo.


Principe mathématique

1. Les trois failles du DDPG

Le DDPG souffre de trois problèmes majeurs:

  1. Surestimation des Q-values : Un seul critic a tendance à surestimer les valeurs Q, ce qui se propage dans les mises à jour et dégrade les performances. Comme le Q-learning classique, le target max(Q(s’,a’)) introduit un biais positif.
  2. Mise à jour trop fréquente de l’actor : L’actor est mis à jour à chaque step alors que le critic n’a pas encore convergé. Une politique mise à jour sur un critic instable est comme un étudiant qui change de méthode d’apprentissage tous les jours.
  3. Variance élevée des targets : Les actions cibles sont déterministes, rendant la Q-estimation fragile et sur-confiante.

2. Correction 1 : Double Q-Learning (Twin Critics)

TD3 maintient deux critic networks Q_1 et Q_2 avec leurs propres paramètres. Pour calculer la target, on utilise le minimum des deux :

y = r + gamma · min(Q_1'(s', a'), Q_2'(s', a'))

Le minimum des deux critics élimine le biais de surestimation : si l’un surestime, l’autre compense. Les deux critics sont entraînés indépendamment sur les mêmes données.

3. Correction 2 : Delayed Policy Update

L’actor et les cibles ne sont mis à jour que tous les d steps (d=2 typiquement) :

Si t % d == 0:
    mu_theta ← mu_theta + ∇_mu J · lr
    theta' ← tau · theta + (1-tau) · theta'  (soft update)

Pendant ce temps, les two critics sont mis à jour à chaque step. Cela permet au critic de converger correctement avant d’être utilisé pour guider la politique.

4. Correction 3 : Target Policy Smoothing

On ajoute un bruit clipé aux actions cibles pour régulariser la Q-estimation :

a' = mu'(s') + clip(epsilon, -c, c)
où epsilon ~ N(0, sigma)

Ce bruit a deux effets bénéfiques : 1) il empêche l’agent d’exploiter des imperfections d’un seul critic, 2) il lisse la Q-function autour des actions, rendant la politique plus robuste aux perturbations.

5. Algorithme complet

Initialiser theta, phi_1, phi_2, theta_target, phi_1_target, phi_2_target
Replay buffer B
Pour chaque episode:
  Observer s_0
  Pour chaque step t:
    1. Choisir a_t = mu_theta(s_t) + N(0, sigma_explore)
    2. Observer r_t, s_{t+1}, done
    3. Stocker (s_t, a_t, r_t, s_{t+1}, done) dans B
    4. Sample mini-batch de B
    5. a' = mu_target(s') + clip(N(0, sigma), -c, c)
    6. y = r + gamma · (1-done) · min(Q_1_target(s', a'), Q_2_target(s', a'))
    7. Minimiser L_i = MSE(Q_i(s, a), y) pour i = 1, 2
    8. Si t % policy_delay == 0:
       a. ∇_mu J = E[∇_a Q_1(s, a) · ∇_mu mu(s)]
       b. mu_theta ← mu_theta + lr · ∇_mu J
       c. Soft update des targets (theta, phi_1, phi_2)

Intuition

Imaginez DDPG comme un élève brillant mais impulsif. Il surestime ses capacités (biais de surestimation Q), change de stratégie trop vite (actor mis à jour avant que le critic ne converge) et a une confiance aveugle en ses propres jugements (actions déterministes sans scepticisme).

TD3 est le même élève mais avec un mentor sage : deux critiques évaluent séparément la situation et le mentor garde l’avis le plus conservateur (minimum des deux Q) ; on attend avant de changer de stratégie (delayed update : on laisse le temps au critique de bien évaluer) ; et on ne croit pas aveuglément aux notes parfaites mais on ajoute un peu de scepticisme (bruit sur les actions cibles). Résultat : l’élève prend des décisions plus saines et performe mieux en examen.


Implémentation Python

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from copy import deepcopy


class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, action_dim), nn.Tanh()
        )
        self.max_action = max_action

    def forward(self, x):
        return self.max_action * self.net(x)


class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1)
        )

    def forward(self, s, a):
        return self.net(torch.cat([s, a], dim=1)).squeeze(-1)


class TD3:
    def __init__(self, state_dim, action_dim, max_action, lr=3e-4, gamma=0.99,
                 tau=0.005, policy_delay=2, policy_noise=0.2, noise_clip=0.5):
        self.gamma = gamma
        self.tau = tau
        self.policy_delay = policy_delay
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.total_it = 0

        self.actor = Actor(state_dim, action_dim, max_action)
        self.actor_target = deepcopy(self.actor)
        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)

        self.critic1 = Critic(state_dim, action_dim)
        self.critic2 = Critic(state_dim, action_dim)
        self.critic1_target = deepcopy(self.critic1)
        self.critic2_target = deepcopy(self.critic2)
        self.critic1_opt = torch.optim.Adam(self.critic1.parameters(), lr=lr)
        self.critic2_opt = torch.optim.Adam(self.critic2.parameters(), lr=lr)

    def select_action(self, state, noise=0.1):
        state = torch.FloatTensor(state).unsqueeze(0)
        action = self.actor(state).cpu().data.numpy().flatten()
        action += np.random.normal(0, noise, size=action.shape)
        return np.clip(action, -1, 1)

    def train(self, replay_buffer, batch_size=256):
        s, a, r, s_next, done = replay_buffer.sample(batch_size)
        s, a, r, s_next, done = map(lambda x: torch.FloatTensor(x), [s, a, r, s_next, done])
        r = r.unsqueeze(-1)
        done = done.unsqueeze(-1)

        # Target policy smoothing
        noise = (torch.randn_like(a) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
        next_a = self.actor_target(s_next) + noise

        # Double Q target
        target_q1 = self.critic1_target(s_next, next_a)
        target_q2 = self.critic2_target(s_next, next_a)
        target_q = r + (1 - done) * self.gamma * torch.min(target_q1, target_q2)

        # Critic update
        current_q1 = self.critic1(s, a)
        current_q2 = self.critic2(s, a)
        loss1 = F.mse_loss(current_q1, target_q)
        loss2 = F.mse_loss(current_q2, target_q)
        self.critic1_opt.zero_grad()
        self.critic2_opt.zero_grad()
        loss1.backward()
        loss2.backward()
        self.critic1_opt.step()
        self.critic2_opt.step()

        # Delayed actor update
        self.total_it += 1
        if self.total_it % self.policy_delay == 0:
            # Policy gradient
            actor_loss = -self.critic1(s, self.actor(s)).mean()
            self.actor_opt.zero_grad()
            actor_loss.backward()
            self.actor_opt.step()

            # Soft update targets
            for param, target in zip(self.actor.parameters(), self.actor_target.parameters()):
                target.data.copy_(self.tau * param.data + (1 - self.tau) * target.data)
            for param, target in zip(self.critic1.parameters(), self.critic1_target.parameters()):
                target.data.copy_(self.tau * param.data + (1 - self.tau) * target.data)
            for param, target in zip(self.critic2.parameters(), self.critic2_target.parameters()):
                target.data.copy_(self.tau * param.data + (1 - self.tau) * target.data)

        return loss1.item(), loss2.item(), actor_loss.item() if self.total_it % self.policy_delay == 0 else None

Hyperparamètres

Hyperparamètre Valeur typique Description
policy_delay 2 Mettre à jour l’actor tous les N steps (valeur par défaut dans le papier: 2)
policy_noise 0.2 Écart-type du bruit ajouté aux actions cibles
noise_clip 0.5 Plage de clipping du bruit (0.5 dans le papier original)
gamma 0.99 Facteur d’actualisation (discount factor)
tau 0.005 Taux de soft update pour les cibles
lr 3e-4 Learning rate Adam pour tous les optimiseurs
hidden_dim 256 Taille des couches cachées (256 dans le papier)

Avantages

  1. Performance supérieure au DDPG : Sur les benchmarks MuJoCo (HalfCheetah, Walker2d, Ant, Hopper), TD3 surpasse systématiquement le DDPG, parfois de façon significative. Les trois corrections éliminent la divergence causée par la surestimation.
  2. Stabilité : Moins sensible aux hyperparamètres que le DDPG grâce aux double critics et au delayed update.
  3. Simplicité : Aucune architecture spéciale nécessaire — juste deux copies des networks critic et un timer pour le delayed update.

Limites

  1. Exploration : Comme le DDPG, TD3 utilise une exploration par bruit Gaussien sur les actions, qui n’est pas optimale pour des politiques complexes. Des méthodes comme SAC avec maximisation d’entropie offrent une exploration plus intelligente.
  2. Deux critics : Double le coût computationnel pour les mises à jour de critique (2 forward + 2 backward au lieu de 1).
  3. Hyperparamètres supplémentaires : policy_delay, policy_noise, noise_clip augmentent l’espace de recherche par rapport au DDPG. Un mauvais choix peut dégrader les performances.

4 cas d’usage concrets

1. Robotique — bras manipulateurs

Pour contrôler un bras robotique avec des articulations continues, TD3 apprend des trajectoires précises et fluides. Les double critics évitent que le robot ne se trompe sur la qualité d’un mouvement et ne rate sa cible.

2. Conduite autonome

Dans des simulateurs de conduite, TD3 contrôle l’accélération, le freinage et la direction en continu. Le delayed update garantit que les décisions critiques (freiner? tourner?) sont basées sur une évaluation fiable.

3. Trading algorithmique

TD3 peut apprendre à ajuster continuellement la taille des positions d’un portefeuille. Les double critics évitent la surestimation des profits potentiels, tandis que le target policy smoothing rend la stratégie plus robuste aux variations du marché.

4. Controle de drones

Le pilotage d’un drone nécessite un contrôle précis des moteurs (action continue). TD3 apprend des politiques de vol stables, et le target policy smoothing empêche les commandes brusques qui pourraient causer un crash.


TD3 vs autres algorithmes d’action continue

TD3 vs SAC

Le SAC utilise la maximisation d’entropie pour une exploration plus intelligente, tandis que TD3 utilise du bruit Gaussien simple. SAC est généralement plus sample-efficient, mais TD3 est plus simple à implémenter et peut surpasser SAC sur certains environnements avec un bon tuning.

TD3 vs PPO

PPO works en mode policy gradient pur (sans critic target double) avec clipping de la fonction objective. PPO est plus stable mais peut être moins performant sur des tâches nécessitant une précision continue fine. TD3 est off-policy (réutilise les expériences) tandis que PPO est on-policy.

TD3 vs DDPG

C’est une amélioration directe : même architecture de base (actor-critic, replay buffer, soft update), mais avec les trois corrections qui éliminent les faiblesses du DDPG. TD3 est le DDPG que le DDPG aurait dû être.


Conclusion

TD3 est devenue la baseline de référence pour les problèmes d’action continue. Ses trois corrections sont simples mais efficaces, démontrant que les problèmes fondamentaux du reinforcement learning (surestimation, variance, instabilité) peuvent être résolus par des modifications architecturales ciblées plutôt que par des changements radicaux d’algorithme.

Le succès de TD3 a inspiré des extensions comme TD3+BC qui ajoute une contrainte de comportement pour l’apprentissage offline (apprentissage à partir de données sans interaction avec l’environnement), un domaine en pleine croissance.


Voir aussi


Laisser un commentaire

Votre adresse e-mail ne sera pas publiée. Les champs obligatoires sont indiqués avec *

Ce site utilise Akismet pour réduire les indésirables. En savoir plus sur la façon dont les données de vos commentaires sont traitées.