GraphSAGE : Guide complet — Généralisation d’Embeddings de Graphes
Résumé — GraphSAGE (Graph Sample and Aggregate), introduit par Hamilton, Ying et Leskovec en 2017, est une approche inductive pour générer des embeddings de noeuds. Contrairement aux méthodes transductives (DeepWalk, node2vec, GCN classique) qui nécessitent de connaître tous les noeuds pendant l’entraînement, GraphSAGE apprend une fonction d’agrégation qui peut générer des embeddings pour des noeuds jamais vus pendant l’entraînement, en utilisant uniquement leurs features et leur voisinage. Cette propriété inductive est critique pour les applications où le graphe évolue dynamiquement (réseaux sociaux, recommandation de produits, webs).
Principe mathématique
1. Propagation par message
Comme les GNN classiques, GraphSAGE agrège les features du voisinage pour mettre à jour chaque noeud.
2. Formulation générique
À chaque couche k, le noeud v agrège les embeddings de ses voisins N(v) :
h_v^{(k)} = sigma(W_k · AGGREGATE([h_v^{(k-1)}, {h_u^{(k-1)} pour u dans N(v)}]))
La clé de GraphSAGE est la fonction AGGREGATE, qui doit être invariante à la permutation des voisins et différentiable.
3. Types d’agrégateurs
GraphSAGE propose trois agrégateurs flexibles :
Mean Aggregator (similaire au GCN mais sans normalization par degré) :
AGG = mean([h_v^{(k-1)}, mean(h_u^{(k-1)} pour u dans N(v))])
h_v^{(k)} = sigma(W · AGG)
C’est l’agrégateur le plus simple et le plus efficace en pratique.
LSTM Aggregator : Utilise un LSTM pour traiter les embeddings des voisins dans un ordre aléatoire (car le LSTM n’est pas invariant à la permutation par nature). Le LSTM capture des dépendances non linéaires entre voisins mais est plus coûteux en calcul.
Pooling Aggregator : Applique une transformation non linéaire suivie de max-pooling :
AGG = max({sigma(W_pool · h_u^{(k-1)} + b) pour u dans N(v)})
Le max-pooling agit comme un détecteur de patterns : chaque dimension du pooling répond à un pattern différent parmi les voisins.
4. Échantillonnage de voisinage
Pour les très grands graphes, GraphSAGE échantillonne un sous-ensemble fixe de voisins à chaque couche :
S_1 voisins échantillonnés à la couche 1
S_2 voisins échantillonnés à la couche 2
Par exemple, avec K=2 couches, S_1=25 et S_2=10, chaque noeud agrège au maximum 10+25=35 voisins, indépendamment du degré réel du noeud.
Intuition
Un réseau transductif (node2vec, GCN classique) est comme un annuaire téléphonique : il connaît tout le monde mais est perdu face à un inconnu. GraphSAGE est comme un détective : il observe les fréquentations et les traits d’une personne qu’il n’a jamais rencontrée et déduit qui elle est.
Imaginez que vous arrivez dans une nouvelle ville. Au lieu de mémoriser chaque habitant (transductif), vous apprenez à reconnaître les quartiers par leur ambiance : les cafés, les magasins, les écoles présents. Quand vous rencontrez un nouvel habitant, vous le classez par le quartier où il habite. GraphSAGE fait exactement cela pour les noeuds.
Implémentation Python
1. Mean Aggregator
import torch
import torch.nn as nn
import torch.nn.functional as F
class MeanAggregator(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.weight = nn.Linear(2 * in_dim, out_dim)
def forward(self, self_feat, neighbor_feats):
neighbor_mean = neighbor_feats.mean(dim=-2)
combined = torch.cat([self_feat, neighbor_mean], dim=-1)
return F.relu(self.weight(combined))
class GraphSAGE(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim, num_layers=2):
super().__init__()
self.layers = nn.ModuleList()
self.layers.append(MeanAggregator(in_dim, hidden_dim))
for _ in range(num_layers - 1):
self.layers.append(MeanAggregator(hidden_dim, hidden_dim))
self.classifier = nn.Linear(hidden_dim, out_dim)
def forward(self, adj, features, sample_sizes):
h = features
for i, s_size in enumerate(sample_sizes):
neighbors = self.sample_neighbors(adj, s_size)
neighbor_feats = torch.index_select(h, 0, neighbors.view(-1))
neighbor_feats = neighbor_feats.view(adj.size(0), s_size, -1)
self_feat = h
h = self.layers[i](self_feat, neighbor_feats)
h = F.normalize(h, p=2, dim=-1)
return self.classifier(h)
2. Entraînement
import torch.optim as optim
model = GraphSAGE(in_dim=1433, hidden_dim=128, out_dim=7, num_layers=2)
optimizer = optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
sample_sizes = [25, 10]
for epoch in range(200):
model.train()
out = model(adj, features, sample_sizes)
loss = F.cross_entropy(out[mask_train], labels[mask_train])
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch % 50 == 0:
model.eval()
with torch.no_grad():
out = model(adj, features, sample_sizes)
test_acc = (out[mask_test].argmax(1) == labels[mask_test]).float().mean()
print(f'Epoch {epoch} | Loss: {loss:.4f} | Test acc: {test_acc:.3f}')
Hyperparamètres
| Hyperparamètre | Valeur typique | Description |
|---|---|---|
| num_layers | 2 | Nombre de couches d’agrégation (K=2 suffit pour la plupart des tâches) |
| sample_sizes | [25, 10] | Nombre de voisins échantillonnés par couche (S_1=25, S_2=10) |
| aggregator | mean | mean/pooling/lstm (mean est le plus courant et efficace) |
| hidden_dim | 128-256 | Dimensions de l’espace latent caché par couche |
| lr | 1e-2 | Learning rate Adam |
| dropout | 0.5 | Régularisation entre les couches |
Avantages
- Inductif : Génère des embeddings pour des noeuds jamais vus, contrairement aux méthodes transductives comme DeepWalk ou node2vec qui nécessitent de réentraîner quand le graphe change.
- Scalable : L’échantillonnage de voisinage contrôle la complexité computationnelle pour les très grands graphes avec des millions de noeuds.
- Flexible : Différents agrégateurs (mean, LSTM, pooling) pour différents besoins. L’agrégateur mean est rapide, le LSTM capture des patterns complexes.
- Features riches : Peut utiliser n’importe quel type de features de noeud (texte, image, attributs) contrairement aux méthodes basées uniquement sur la structure du graphe.
- Transfert facile : Le modèle entraîné sur un graphe peut être transféré à un autre graphe de même type de features.
Limites
- Perte d’information structurale globale : L’échantillonnage réduit la vision du voisinage, perdant des relations lointaines et des patterns structurels globaux.
- Hyperparamètres sensibles : sample_sizes influence fortement la performance et la mémoire. Un mauvais choix peut détruire les performances.
4 cas d’usage concrets
1. Recommandation de produits
Dans un réseau utilisateur-produit, GraphSAGE génère des embeddings pour de nouveaux produits sans re-entraînement, permettant l’ajout en temps réel au catalogue.
2. Analyse de réseaux sociaux
Détection de communautés avec des membres qui rejoignent le réseau en continu. Les embeddings des nouveaux utilisateurs sont générés à la volée sans recalculer tout le graphe.
3. Classification de protéines
Classification de nouvelles protéines basée sur leur structure et fonction voisines, utile dans la découverte de médicaments.
4. Détection de spam
Les comptes frauduleux ont souvent des patterns de connexion anormaux que les agrégateurs capturent facilement, même pour de nouveaux comptes.
GraphSAGE et les reseaux de recommendation
GraphSAGE est l’algorithme derriere le systeme de recommandation de Pinterest, qui traite des millions de pins et de users. Au lieu de stocker un embedding par utilisateur, ils calculent l’embedding a la volee en agregeant les interactions recentes. Quand un nouveau user s’inscrit, il a immediatement un embedding base sur ses premieres actions.
SAGE vs GAT vs GCN – Tableau recapitulatif
| Propriete | GCN | GraphSAGE | GAT |
|---|---|---|---|
| Type | Transductif | Inductif | Inductif |
| Voisins | Tous (matrice) | Echantillonnes | Echantillonnes |
| Poids | Degre normalise | Uniforme/Appris | Attention |
| Scalabilite | O(N^2) | O(N*k) | O(Nkh) |
| Nouveaux noeuds | Non | Oui | Oui |
GraphSAGE est le meilleur compromis entre performance et scalabilite pour les graphes massifs avec des noeuds dynamiques. C’est l’algorithme le plus deploye dans l’industrie aujourd’hui.
Analyse en profondeur de l’agregation mean-pooling vs LSTM-pooling vs GCN-pooling
Mean Pooling
L’agregation par moyenne est la plus simple et la plus rapide. Elle prend la moyenne element-wise des embeddings de voisins et la concatene avec l’embedding du noeud cible. Cette approche ne tient pas compte de l’ordre des voisins ni de leur importance relative. Malgre sa simplicite, elle fonctionne remarquablement bien sur de nombreux benchmarks et reste le choix par defaut recommande.
LSTM Pooling
Le pooling par LSTM utilise des reseaux LSTM sequentiels pour traiter les voisins un par un dans un ordre aleatoire. L’etat cach final du LSTM sert d’agregation. Cette methode est plus expressive car elle capture des interactions sequentielles entre voisins. Cependant, le fait que l’ordre soit aleatoire limite l’avantage du LSTM par rapport a un simple mean pooling. En pratique, le gain de performance est marginal pour un cout computationnel significativement eleve.
GCN Pooling
Le GCN pooling utilise une moyenne ponderee par le degre normalise. C’est un cas special de l’agregation de GraphSAGE ou les poids sont fixes par la structure du graphe plutot qu’appris. GraphSAGE generalise cette approche en permettant des fonctions d’agregation apprise.
Benchmark sur datasets references
Sur le dataset Reddit (230K noeuds, 114M aretes), GraphSAGE atteint un score F1 micro de 0.954 pour la classification de communautes, comparables au meilleur modeles de l’epoque. Sur PPI (56K noeuds, 818K aretes), GraphSAGE atteint F1 micro de 0.6120 contre 0.504 pour les methodes lineaires. Ces resultats demontrent la superiorite de l’approche d’agregation du voisinage sur les graphes biologiques et sociaux.
La scalabilite de GraphSAGE impressionne : sur un seul GPU, l’entrainement sur le dataset Reddit complet prend environ 25 minutes. Le GCN necessite de charger toute la matrice d’adjacence en memoire ce qui est impossible pour des graphes de cette taille. GraphSAGE, par son echantillonnage, permet l’entrainement sur des graphes de centaines de millions de noeuds.
Voir aussi
- Maîtrisez les Nombres Triffle en Python : Guide Complet pour Développeurs
- Démystifiez le Module ‘abc’ en Python : Guide Complet avec Exemples Pratiques

