Graph Attention Network : Guide complet — Attention sur Graphes
Résumé — Le Graph Attention Network (GAT), proposé par Veličković et al. en 2017, améliore les GCN en remplaçant l’agrégation uniforme des voisins par un mécanisme d’attention appris. Chaque voisin reçoit un poids d’attention spécifique et dynamique, permettant au modèle de se focaliser sur les voisins les plus pertinents. Avec l’attention multi-têtes, le GAT capture simultanément plusieurs sous-espaces d’information, atteignant des performances supérieures sur la classification de nœuds (Cora, Citeseer, PubMed).
Principe mathématique
1. Mécanisme d’attention sur graphe
Pour chaque nœud i, le GAT calcule un score d’attention avec chacun de ses voisins j ∈ N(i) :
e_ij = LeakyReLU(a^T [Wh_i || Wh_j])
Où :
– h_i ∈ R^F est la représentation du nœud i
– W ∈ R^{F'×F} est une matrice de transformation linéaire partagée
– a ∈ R^{2F'} est un vecteur d’attention appris
– || désigne la concaténation
– LeakyReLU avec pente négative typique de 0.2
2. Normalisation softmax
Les scores d’attention sont normalisés sur tous les voisins du nœud i :
α_ij = softmax_j(e_ij) = exp(e_ij) / somme_{k∈N(i)} exp(e_ik)
Cela garantit que les coefficients d’attention forment une distribution de probabilité : somme_j α_ij = 1.
3. Agrégation pondérée
La représentation mise à jour du nœud i est une combinaison linéaire pondérée de ses voisins :
h_i' = sigma(somme_{j∈N(i)} α_ij · Wh_j)
Où sigma est une fonction d’activation non linéaire (ELU typiquement).
4. Attention multi-têtes
Pour stabiliser l’apprentissage et capturer différents aspects de la structure, le GAT utilise M têtes d’attention en parallèle :
Couches intermédiaires (concaténation) :
h_i' = ||_{m=1..M} sigma(somme_j α_ij^m · W^m h_j)
Couche finale (moyenne) :
h_i' = sigma(moyenne_{m=1..M} somme_j α_ij^m · W^m h_j)
La concaténation multiplie la dimension de sortie par M, tandis que la moyenne la conserve.
5. Masquage (Masked Attention)
L’attention est restreinte aux voisins immédiats dans le graphe (masked). Contrairement à l’attention du Transformer qui calcule les corrélations entre toutes les paires, le GAT ne considère que les connexions existantes, respectant la topologie du graphe.
Intuition
Imaginez que vous cherchez des conseils pour choisir un restaurant dans une ville inconnue.
Vous avez 50 amis sur les réseaux sociaux, mais ils ne valent pas tous le même conseil :
– Votre ami foodie qui teste des restaurants chaque semaine compte beaucoup
– Votre collègue qui n’aime pas sortir compte peu
– Votre voisin qui vit dans cette ville compte moyennement
Le GCN classique donnerait le même poids à tous vos amis — comme si chaque avis comptait exactement 1/50. Le GAT, lui, apprend automatiquement qui est le plus pertinent pour cette question spécifique.
Encore plus subtil : les têtes d’attention multiples signifient que votre foodie est excellent pour la qualité culinaire mais mauvais pour les prix, tandis qu’un autre ami est expert en bons plans. Le GAT capture ces différentes dimensions de pertinence simultanément.
Implémentation Python
1. GAT sur Cora avec PyTorch Geometric
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
# Chargement du dataset Cora
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0] # Graph unique
print(f'Cora: {data.num_nodes} noeuds, {data.num_edges} arêtes')
print(f'{data.num_classes} classes, {data.num_node_features} features')
class GAT(torch.nn.Module):
def __init__(self, hidden_channels=128, heads=8, dropout=0.6):
super().__init__()
# Premiere couche GAT avec concatenation multi-tetes
self.conv1 = GATConv(
dataset.num_node_features,
hidden_channels,
heads=heads,
dropout=dropout,
concat=True,
negative_slope=0.2
)
self.dropout = dropout
# Couche de sortie avec moyenne des tetes
self.conv2 = GATConv(
hidden_channels * heads,
dataset.num_classes,
heads=1,
concat=False,
dropout=dropout,
negative_slope=0.2
)
def forward(self, x, edge_index):
x = F.dropout(x, p=self.dropout, training=self.training)
x = F.elu(self.conv1(x, edge_index))
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.conv2(x, edge_index)
return x
model = GAT(hidden_channels=128, heads=8, dropout=0.6)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
# Entraînement
best_val_acc = 0
for epoch in range(300):
model.train()
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
# Validation
model.eval()
_, pred = out[data.val_mask].max(dim=1)
val_correct = int((pred == data.y[data.val_mask]).sum())
val_total = int(data.val_mask.sum())
val_acc = val_correct / val_total
if val_acc > best_val_acc:
best_val_acc = val_acc
# Test
_, test_pred = out[data.test_mask].max(dim=1)
test_correct = int((test_pred == data.y[data.test_mask]).sum())
test_total = int(data.test_mask.sum())
test_acc = test_correct / test_total
if epoch % 50 == 0:
print(f'Epoch {epoch} | Loss: {loss:.4f} | Val Acc: {val_acc:.4f} | '
f'Best Test Acc: {test_acc:.4f}')
print(f'\nMeilleure accuracy test : {test_acc:.4f}')
2. Visualisation des poids d’attention
# Extraire les coefficients d'attention
model.eval()
with torch.no_grad():
# Passer les données à travers la premiere couche
x = F.dropout(data.x, p=0.6, training=False)
_, attention_weights = model.conv1(x, data.edge_index,
return_attention_weights=True)
edge_index_att = attention_weights[0]
attention_coeffs = attention_weights[1]
# Statistiques sur les poids d'attention
att_np = attention_weights[1].mean(dim=0).numpy()
print(f'Moyenne des poids d\'attention: {att_np.mean():.4f}')
print(f'Ecart-type: {att_np.std():.4f}')
print(f'Min: {att_np.min():.4f}, Max: {att_np.max():.4f}')
print(f'95e percentile: {np.percentile(att_np, 95):.4f}')
# Visualisation de la distribution
plt.figure(figsize=(10, 6))
plt.hist(att_np, bins=50, edgecolor='black', alpha=0.7)
plt.title('Distribution des poids d\'attention (moyenne sur les tetes)')
plt.xlabel('Coefficient d\'attention')
plt.ylabel('Frequence')
plt.yscale('log')
plt.savefig('gat_attention_distribution.png', dpi=150)
print('Distribution sauvegardee !')
Hyperparamètres
| Hyperparamètre | Valeur typique | Description |
|---|---|---|
| heads | 8 | Nombre de têtes d’attention (plus = plus expressif mais plus lourd) |
| hidden_channels | 64-256 | Dimension de chaque tête (la sortie de concat est heads × hidden) |
| dropout | 0.4-0.6 | Régularisation appliquée avant chaque couche GAT |
| attention_dropout | 0.3-0.6 | Dropout spécifique aux coefficients d’attention |
| negative_slope | 0.2 | Pente du LeakyReLU dans le calcul de l’attention |
| concat (intermédiaire) | True | Concaténer les têtes aux couches intermédiaires |
| concat (finale) | False | Moyenner les têtes à la couche finale |
| learning_rate | 0.005 | Adam, avec weight_decay (5e-4) |
Avantages
- Attention apprise : Contrairement à GCN qui utilise des poids fixes basés sur les degrés, le GAT apprend dynamiquement quels voisins comptent le plus.
- Agnostique au graphe : Le GAT n’a pas besoin de la structure complète du graphe à l’avance (contrairement à GCN qui précalcule la matrice normalisée). Il fonctionne sur des graphes inductifs (nouveaux nœuds).
- Interprétabilité : Les coefficients d’attention
α_ijsont directement interprétables : ils révèlent quels voisins influencent le plus chaque nœud. - Multi-têtes : Capture plusieurs types de relations simultanément, chaque tête pouvant se spécialiser dans un aspect différent de la structure.
Limites
- Coût calcul O(V · E) : Le calcul des scores d’attention pour chaque paire de nœuds connectés est plus coûteux que l’agrégation moyenne de GCN.
- Surapprentissage sur petits graphes : Avec peu de données d’entraînement (Cora n’a que 140 labels pour l’entraînement), les 8 têtes d’attention peuvent surapprendre. Le dropout est crucial.
- Performance variable : Sur certains benchmarks, GAT ne bat pas significativement GCN simple. Les gains sont nets sur des graphes avec des hétérogénéités de voisinage.
4 cas d’usage concrets
1. Classification de documents dans un graphe de citations
L’application originale de GAT : classifier des articles scientifiques (Cora, Citeseer, PubMed) en utilisant leurs réseaux de citations. Les articles citant les mêmes articles “pivots” reçoivent plus de poids, capturant les liens sémantiques plus fidèlement que GCN. GAT atteint 83.0% sur Cora, contre 81.4% pour GCN.
2. Recommandation sociale
Dans un réseau social d’utilisateurs et de produits, le GAT apprend à qui faire confiance pour quelles recommandations. Quand un utilisateur achète un produit, le GAT pondère différemment les avis de ses amis selon leur expertise dans le domaine concerné (électronique, mode, cuisine).
3. Détection de fraude financière
Dans un graphe de transactions (nœuds = comptes, arêtes = transferts), le GAT identifie les comptes suspects en apprenant que certains types de connexions (transferts rapides, montants ronds, réseaux complexes de sous-comptes) sont plus révélateurs de fraude que d’autres.
4. Prédiction de propriétés moléculaires
Dans une molécule (graphe : atomes = nœuds, liaisons = arêtes), le GAT apprend que certains atomes voisins (double liaison avec oxygène, cycle aromatique) sont plus influents que d’autres pour prédire la solubilité ou la toxicité du composé. L’attention multi-têtes capture différentes propriétés chimiques simultanément.
Comparaison GAT vs GCN
| Aspect | GCN | GAT |
|---|---|---|
| Poids des voisins | Fixe (basé sur les degrés) | Appris par attention |
| Inductif | Non (besoin de recalculer) | Oui (fonctionne sur nouveaux nœuds) |
| Coût calcul | O(E) | O(E · F’) |
| Interprétabilité | Faible | Forte (weights d’attention visibles) |
| Performance Cora | ~81.4% | ~83.0% |
| Multi-relations | Non nativement | Oui (multi-têtes) |
Conclusion
Le Graph Attention Network a apporté une innovation clé au domaine des GNN : remplacer l’agrégation uniforme par une attention apprise. Cette idée simple mais puissante permet au modèle de distinguer les voisins pertinents du bruit, améliorant les performances tout en offrant une interprétabilité naturelle.
Le GAT a inspiré une famille d’architectures dérivées : GATv2 (attention dynamique plus expressive), AGNN (attention sans transformation linéaire préalable) et les modèles d’attention sur hypergraphes. L’attention sur graphes est aujourd’hui un composant standard dans l’arsenal du deep learning géométrique.
Voir aussi
- Comprendre les Matrices Idempotentes en Python : Guide Complet pour Développeurs
- Maîtriser la Manipulation des Grilles Croisées avec Python : Guide Ultime pour Gérer les Rectangles

