Mean Shift : Guide Complet — Principes, Exemples et Implémentation Python

Mean Shift : Guide Complet — Principes, Exemples et Implémentation Python

Mean Shift : Guide complet — Principes, Exemples et Implémentation Python

Résumé

Mean Shift (décalage moyen) est un algorithme de clustering non paramétrique fondé sur l’estimation de densité par noyau. Contrairement à K-Means, il n’exige pas que l’utilisateur fixe le nombre de clusters à l’avance. Au lieu de cela, Mean Shift recherche automatiquement les modes (pics) de la distribution de densité sous-jacente des données, et affecte chaque point d’observation au mode vers lequel il converge. Cette approche fondée sur la densité lui permet de détecter des clusters de formes arbitraires et d’adapter le nombre de clusters à la structure réelle des données.

Dans ce guide complet, nous explorons les principes mathématiques du Mean Shift, son intuition géométrique, son implémentation pratique avec scikit-learn, le réglage des hyperparamètres, et quatre cas d’usage concrets.


Principe mathématique de Mean Shift

Estimation de densité par noyau

Le fondement théorique de Mean Shift repose sur l’estimation de densité par noyau (Kernel Density Estimation, KDE). Étant donné un ensemble de n points de données {x₁, x₂, …, xₙ} dans ℝᴰ, l’estimateur de densité par noyau au point x s’écrit :

f̂(x) = (1 / n·hᴰ) · Σᵢ₊₁ⁿ K((x – xᵢ) / h)

où :

  • K est la fonction noyau (généralement une gaussienne ou un noyau d’Epanechnikov),
  • h est le paramètre de lissage appelé bandwidth (bande passante),
  • d est la dimension de l’espace des données.

L’estimateur KDE attribue à chaque point x une valeur de densité proportionnelle au nombre de points d’observation situés dans son voisinage, pondérés par le noyau. Plus un point se trouve dans une région dense de données, plus sa densité estimée est élevée.

Gradient ascent vers le mode local

L’objectif de Mean Shift est de trouver les modes de la fonction de densité estimée f̂(x). Un mode est un maximum local de la fonction de densité — c’est-à-dire un point où le gradient s’annule et la matrice hessienne est définie négative.

Pour atteindre un mode, Mean Shift utilise la montée de gradient (gradient ascent) sur l’estimateur de densité. Le gradient de f̂(x) s’exprime comme :

∇f̂(x) = (2 / n·hᴰ⁺²) · Σᵢ₊₁ⁿ (x – xᵢ) · g(||(x – xᵢ)/h||²)

où g est la fonction dérivée du noyau (profile derivative). Ce gradient pointe toujours dans la direction de la plus forte augmentation de la densité locale.

Le vecteur de mean shift

Le vecteur de mean shift mₕ(x) au point x est défini comme la moyenne pondérée des voisins dans la bande de bandwidth h, relocalisée par rapport à x :

mₕ(x) = [Σᵢ₊₁ⁿ xᵢ · g(||(x – xᵢ)/h||²)] / [Σᵢ₊₁ⁿ g(||(x – xᵢ)/h||²)] – x

Ce vecteur possède une propriété fondamentale : il est proportionnel au gradient normalisé de l’estimateur de densité :

∇f̂(x) ∝ f̂_G(x) · mₕ(x)

où f̂_G(x) est l’estimateur KDE utilisant le noyau G(z) = g(||z||²). Cela signifie que le vecteur de mean shift pointe toujours dans la direction de la densité croissante.

Convergence vers les modes de la densité

L’algorithme Mean Shift procède par itérations successives. À chaque étape t, on met à jour la position d’un point selon :

x⁽ᵗ⁺¹⁾ = x⁽ᵗ⁾ + mₕ(x⁽ᵗ⁾)

c’est-à-dire :

x⁽ᵗ⁺¹⁾ = [Σᵢ₊₁ⁿ xᵢ · g(||(x⁽ᵗ⁾ – xᵢ)/h||²)] / [Σᵢ₊₁ⁿ g(||(x⁽ᵗ⁾ – xᵢ)/h||²)]

Cette itération est garantie de converger vers un mode local de la densité estimée, sous des conditions générales sur le noyau. La preuve de convergence repose sur le fait que la séquence {f̂(x⁽ᵗ⁾)} est monotone croissante et bornée supérieurement.

Une fois que tous les points ont convergé vers leurs modes respectifs, on regroupe dans un même cluster tous les points qui aboutissent au même mode. Le nombre de clusters est donc déterminé automatiquement : c’est le nombre de modes distincts trouvés.


Intuition géométrique : des billes dans une vallée

Pour comprendre Mean Shift sans formules, imaginez le scénario suivant.

Vous disposez d’un terrain en trois dimensions dont l’altitude représente la densité de vos données. Les régions où les points sont nombreux forment des collines ; les régions vides sont des vallées.

Maintenant, placez une bille à chaque point de données, comme si vous semiez des billes sur ce relief. Chaque bille roule vers le haut de la pente la plus raide (c’est l’équivalent de la montée de gradient). Les billes voisines convergent naturellement vers le même sommet de colline.

Toutes les billes qui aboutissent au même sommet appartiennent au même cluster. Les points situés sur les flancs de la colline glissent vers le centre de gravité local, puis continuent à remonter jusqu’à atteindre le pic.

Cette analogie explique plusieurs propriétés clés :

  • Pas besoin de spécifier le nombre de clusters : le nombre de sommets naturels détermine automatiquement le nombre de groupes.
  • Clusters de forme arbitraire : contrairement à K-Means qui favorise des clusters sphériques, les bassins d’attraction de Mean Shift peuvent suivre des formes complexes.
  • Sensibilité au bandwidth : si la bande est trop large, toutes les billes convergent vers un seul sommet. Si elle est trop étroite, chaque petite bosse devient un mode et on obtient trop de micro-clusters.

L’estimation heuristique de bandwidth suivant la règle de Silverman est souvent utilisée comme point de départ, puis affinée empiriquement.


Implémentation Python avec scikit-learn

Exemple de base avec MeanShift

Voici une implémentation complète utilisant scikit-learn :

from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import numpy as np

# Générer des données synthétiques
X, y_true = make_blobs(
    n_samples=500,
    centers=4,
    cluster_std=1.0,
    random_state=42
)

# Estimation automatique du bandwidth
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=200)
print(f"Bandwidth estimé : {bandwidth:.3f}")

# Exécuter Mean Shift
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)
labels = ms.labels_
cluster_centers = ms.cluster_centers_

n_clusters = len(set(labels))
print(f"Nombre de clusters détectés : {n_clusters}")

# Visualisation
plt.figure(figsize=(10, 6))
colors = plt.cm.tab10(np.linspace(0, 1, n_clusters))

for i, color in enumerate(colors):
    mask = labels == i
    plt.scatter(X[mask, 0], X[mask, 1], c=[color],
                label=f"Cluster {i}", alpha=0.6, s=30)

# Tracer les centres
plt.scatter(cluster_centers[:, 0], cluster_centers[:, 1],
            c="red", marker="X", s=200, edgecolors="black",
            label="Centres", zorder=5)

plt.title(f"Mean Shift — {n_clusters} clusters détectés (bandwidth={bandwidth:.2f})")
plt.legend()
plt.tight_layout()
plt.show()

Comparaison avec K-Means

Voici comment comparer Mean Shift et K-Means sur le même jeu de données :

from sklearn.cluster import KMeans

# K-Means (en supposant qu'on connaisse le bon K)
kmeans = KMeans(n_clusters=4, random_state=42, n_init="auto")
kmeans_labels = kmeans.fit_predict(X)

# Mean Shift
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms_labels = ms.fit_predict(X)

# Comparaison visuelle
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# K-Means
ax1 = axes[0]
for i in range(4):
    mask = kmeans_labels == i
    ax1.scatter(X[mask, 0], X[mask, 1], alpha=0.6, s=30,
                label=f"Cluster {i}")
ax1.scatter(kmeans.cluster_centers_[:, 0], kmeans.cluster_centers_[:, 1],
            c="red", marker="X", s=200, edgecolors="black")
ax1.set_title("K-Means (K=4 fixé)")
ax1.legend()

# Mean Shift
ax2 = axes[1]
unique_labels = set(ms_labels)
for i, label in enumerate(unique_labels):
    mask = ms_labels == label
    ax2.scatter(X[mask, 0], X[mask, 1], alpha=0.6, s=30,
                label=f"Cluster {i}")
ax2.scatter(ms.cluster_centers_[:, 0], ms.cluster_centers_[:, 1],
            c="red", marker="X", s=200, edgecolors="black")
ax2.set_title(f"Mean Shift ({len(unique_labels)} clusters automatiques)")
ax2.legend()

plt.tight_layout()
plt.show()

Visualisation de la convergence

On peut observer le chemin de convergence de chaque point vers son mode :

# Exécuter Mean Shift avec calcul des centres
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)

# Visualiser la convergence pour un sous-échantillon
sample_indices = np.random.choice(len(X), size=80, replace=False)

fig, ax = plt.subplots(figsize=(8, 6))

# Points de données
ax.scatter(X[:, 0], X[:, 1], c=ms.labels_, cmap="tab10",
           alpha=0.3, s=20)

# Lignes de convergence
for idx in sample_indices:
    point = X[idx]
    center = ms.cluster_centers_[ms.labels_[idx]]
    ax.plot([point[0], center[0]], [point[1], center[1]],
            color="gray", alpha=0.15, linewidth=0.8)

# Centres finaux
ax.scatter(ms.cluster_centers_[:, 0], ms.cluster_centers_[:, 1],
           c="red", marker="X", s=200, edgecolors="black", zorder=5)
ax.set_title("Chemins de convergence vers les modes")
plt.tight_layout()
plt.show()

Hyperparamètres de Mean Shift

Le choix des hyperparamètres est crucial pour obtenir de bons résultats avec Mean Shift.

bandwidth (bande passante)

C’est le paramètre le plus important. Il contrôle le rayon de la fenêtre de Parzen utilisée pour estimer la densité.

  • Trop petit : chaque petit groupe de points devient son propre cluster → sur-segmentation
  • Trop grand : tous les points fusionnent en un seul cluster → sous-segmentation
  • Valeur idéale : correspond à l’échelle caractéristique des structures que vous souhaitez détecter

L’heuristique estimate_bandwidth(X, quantile=0.2) calcule le quantile des distances par paires, ce qui fournit un bon point de départ. Ajustez ensuite le quantile (entre 0.1 et 0.5) selon la granularité souhaitée.

bin_seeding

Lorsque bin_seeding=True, l’algorithme discrétise l’espace en bacs (bins) et n’initialise les graines de clustering qu’une fois par bac occupé. Cela réduit considérablement le temps de calcul, surtout sur de grands jeux de données. La taille des bacs est égale au bandwidth.

min_bin_freq

Ce paramètre contrôle le nombre minimal de points qu’un bac doit contenir pour être considéré comme une graine valide. Avec min_bin_freq=1 (valeur par défaut), chaque bac occupé devient une graine. Augmenter cette valeur ignore les régions peu denses et accélère le calcul, au prix de potentielles omissions de petits clusters.

cluster_all

Détermine le traitement des points orphelins qui ne convergent pas clairement vers un mode.

  • cluster_all=True (par défaut) : tous les points sont assignés, y compris les cas limites qui sont rattachés au cluster le plus proche
  • cluster_all=False : les points non assignés reçoivent le label -1

n_jobs

Nombre de processeurs utilisés en parallèle. Sur un processeur moderne, n_jobs=-1 exploite tous les cœurs disponibles et réduit significativement le temps de calcul de l’estimation de densité.


Avantages et limites de Mean Shift

Avantages

  1. Pas besoin de fixer le nombre de clusters — le nombre de modes est déterminé automatiquement par la structure des données.
  2. Clusters de forme arbitraire — les bassins d’attraction ne sont pas contraints à être sphériques ou convexes.
  3. Robuste aux outliers — les points isolés convergent vers leur mode propre ou sont identifiés comme orphelins (selon cluster_all).
  4. Fondement théorique solide — basé sur l’estimation de densité par noyau, avec des garanties de convergence prouvées.
  5. Interprétabilité — chaque cluster correspond à un mode identifiable de la distribution sous-jacente.

Limites

  1. Coût computationnel élevé — la complexité naive est O(n²·d·T) où T est le nombre d’itérations. Pour n > 10 000 points, cela devient prohibitif sans bin_seeding.
  2. Sensibilité au bandwidth — choisir un bandwidth inadapté donne des résultats médiocres. Il n’existe pas de méthode universelle automatique pour le sélectionner.
  3. Difficulté en haute dimension — en grande dimension, l’estimation de densité par noyau souffre du fléau de la dimensionnalité : tous les points deviennent également proches et la notion de densité locale perd son sens.
  4. Clusters de densités très différentes — un bandwidth unique peut mal fonctionner si les clusters ont des densités très hétérogènes (un cluster très dense et un autre très diffus).
  5. Pas de modèle probabiliste — contrairement aux modèles de mélanges gaussiens, Mean Shift ne fournit pas de probabilité d’appartenance ni de vraisemblance.

4 cas d’usage concrets du Mean Shift

Cas 1 — Segmentation d’images

Mean Shift est historiquement l’algorithme de référence en segmentation d’images. Chaque pixel est représenté comme un point dans un espace combinant coordonnées spatiales (x, y) et couleur (R, G, B). Le bandwidth contrôle le compromis entre la fidélité aux couleurs et la régularité spatiale.

from sklearn.cluster import MeanShift
from sklearn.cluster import estimate_bandwidth
import numpy as np
from skimage import data

# Charger une image
image = data.astronaut()
height, width, channels = image.shape

# Préparer les données : concaténer position et couleur
X_spatial = np.mgrid[0:height, 0:width].reshape(2, -1).T
X_color = image.reshape(-1, 3).astype(float)
X = np.hstack([X_spatial * 0.3, X_color])

# Mean Shift
bandwidth = estimate_bandwidth(X, quantile=0.05, n_samples=500)
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X[::10])  # échantillonner pour accélérer

# Reconstruire l'image segmentée
labels_full = ms.predict(X)
segmented = labels_full.reshape(height, width)
print(f"Image segmentée en {len(set(labels_full))} régions")

Cas 2 — Analyse de trajectoires GPS

En géolocalisation, Mean Shift permet d’identifier les lieux fréquemment visités (points d’intérêt) à partir de traces GPS. Les points se concentrent naturellement autour des lieux visités, formant des modes de densité spatiale.

# Coordonnées GPS simulées
gps_data = np.vstack([
    np.random.randn(200, 2) * 0.001 + [48.8566, 2.3522],
    np.random.randn(150, 2) * 0.002 + [48.8738, 2.2950],
    np.random.randn(100, 2) * 0.001 + [48.8448, 2.3736],
    np.random.randn(80, 2) * 0.003 + [48.8600, 2.3400],
])

bandwidth = estimate_bandwidth(gps_data, quantile=0.15)
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(gps_data)

for i, center in enumerate(ms.cluster_centers_):
    print(f"Lieu fréquent #{i+1} : lat={center[0]:.4f}, lon={center[1]:.4f} "
          f"({np.sum(ms.labels_==i)} points)")

Cas 3 — Détection de communautés dans les réseaux sociaux

En représentant les utilisateurs comme des points dans un espace latent (embeddings issus de Node2Vec, par exemple), Mean Shift détecte automatiquement les communautés sans imposer leur nombre. Les groupes d’utilisateurs aux profils similaires forment des modes denses.

from sklearn.decomposition import PCA

# X_embedded : matrice d'embeddings (n_users, 300)
pca = PCA(n_components=50)
X_reduced = pca.fit_transform(X_embedded)

bandwidth = estimate_bandwidth(X_reduced, quantile=0.1, n_samples=500)
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True, n_jobs=-1)
ms.fit(X_reduced)

n_communities = len(set(ms.labels_))
print(f"{n_communities} communautés détectées")
for c in range(n_communities):
    print(f"  Communauté {c}: {np.sum(ms.labels_==c)} utilisateurs")

Cas 4 — Identification de régimes climatiques

En climatologie, Mean Shift est utilisé pour identifier des régimes de circulation atmosphérique à partir de données de géopotentiel. Chaque mode correspond à un patron de circulation récurrent, ce qui aide à comprendre la variabilité climatique naturelle.

# Données climatiques simulées
X_climate = np.vstack([
    np.random.randn(500, 3) * 0.5 + [1, 0.5, -0.3],
    np.random.randn(350, 3) * 0.4 + [-1, -0.5, 0.3],
    np.random.randn(250, 3) * 0.6 + [0.5, -1, 0.8],
    np.random.randn(400, 3) * 0.3 + [-0.3, 1, -0.5],
])

bandwidth = estimate_bandwidth(X_climate, quantile=0.2)
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X_climate)

print(f"{len(set(ms.labels_))} régimes climatiques identifiés")
for i, center in enumerate(ms.cluster_centers_):
    n_points = np.sum(ms.labels_ == i)
    print(f"  Régime {i+1}: centre={center.round(2)}, {n_points} jours")

Quand choisir Mean Shift plutôt qu’un autre algorithme ?

Choisissez Mean Shift quand :

  • Le nombre de clusters est inconnu et vous souhaitez que l’algorithme le détermine lui-même
  • Vous travaillez avec des données de dimension modérée (d < 50)
  • Les clusters ont potentiellement des formes non sphériques
  • Le jeu de données est de taille raisonnable (n < 10 000 sans bin_seeding, ou n < 100 000 avec bin_seeding)
  • Vous avez besoin d’une méthode fondée sur la densité mais que DBSCAN échoue à capturer la structure

Évitez Mean Shift quand :

  • Les données sont en très haute dimension (> 100) — privilégiez une réduction de dimensionnalité préalable
  • Le jeu de données dépasse quelques centaines de milliers de points — le coût computationnel devient prohibitif
  • Les clusters ont des densités très hétérogènes — DBSCAN ou HDBSCAN seront plus adaptés
  • Vous avez besoin d’affectations probabilistes — tournez-vous vers les Gaussian Mixture Models

Voir aussi

Laisser un commentaire

Votre adresse e-mail ne sera pas publiée. Les champs obligatoires sont indiqués avec *

Ce site utilise Akismet pour réduire les indésirables. En savoir plus sur la façon dont les données de vos commentaires sont traitées.