Vision Transformer (ViT)
Summary
The Vision Transformer (ViT), introduced by Dosovitskiy and colleagues at Google Brain in 2020, represents a major breakthrough in computer vision. It showed that an architecture built on multi-head self-attention and MLP blocks, with no convolution layers in its encoder, could match or surpass state-of-the-art convolutional networks on large-scale image classification when pretrained on enough data.
The core idea of the Vision Transformer is particularly elegant: instead of sliding small convolution filters over the image, you cut it into square patches, project each patch into a fixed-dimensional vector space, and feed this sequence of vectors into a standard Transformer encoder. This paradigm shift lets every layer model relationships between any two regions of the image, whereas each convolutional layer only looks at a local neighborhood.
In this complete guide, we will explore in detail the mathematical principle of the Vision Transformer, its fundamental intuition, its step-by-step implementation in PyTorch, its key hyperparameters, as well as its advantages, limitations, and concrete use cases.
Mathematical Principle of the Vision Transformer
The operating principle of the Vision Transformer relies on a series of precise mathematical steps, each playing an indispensable role in transforming a raw image into a semantic representation usable for classification or other vision tasks.
Step 1 — Cutting the image into patches
Consider an input image $x \in \mathbb{R}^{H \times W \times C}$, where $H$ is the height in pixels, $W$ is the width, and $C$ is the number of color channels (usually 3 for RGB).
A fixed patch size of $P \times P$ pixels is chosen, with $H$ and $W$ assumed to be multiples of $P$. The image is then divided into a non-overlapping grid of patches. The total number of patches is:
$$
N = \frac{H \times W}{P^2}
$$
For example, a $224 \times 224$ pixel image divided into $16 \times 16$ patches gives $N = \frac{224 \times 224}{16^2} = 196$ patches. Each patch is a tensor of dimensions $P \times P \times C$, which is “flattened” into a vector of dimension $P^2 \cdot C$ (here $16^2 \cdot 3 = 768$).
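Step 1 can be sketched in a few lines of PyTorch. This is a minimal illustration, assuming images arrive in PyTorch's channels-first (B, C, H, W) layout rather than the $H \times W \times C$ notation above; `image_to_patches` is a hypothetical helper name, not a library function.

```python
import torch


def image_to_patches(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Split images (B, C, H, W) into flattened, non-overlapping patches
    of shape (B, N, P*P*C), with N = (H // P) * (W // P)."""
    b, c, h, w = x.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "H and W must be multiples of P"
    # (B, C, H/P, P, W/P, P): split both spatial axes into a grid of P-sized tiles
    x = x.reshape(b, c, h // p, p, w // p, p)
    # (B, H/P, W/P, P, P, C): grid axes first, pixels and channels last
    x = x.permute(0, 2, 4, 3, 5, 1)
    # (B, N, P*P*C): one row per patch, in raster order (left to right, top to bottom)
    return x.reshape(b, (h // p) * (w // p), p * p * c)


images = torch.randn(2, 3, 224, 224)
patches = image_to_patches(images, patch_size=16)
print(patches.shape)  # torch.Size([2, 196, 768])
```

The first row of `patches` is simply the top-left $16 \times 16$ crop, flattened in $(P, P, C)$ order.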
Step 2 — Linear Projection (Patch Embedding)
Each flattened patch is projected into a representation space of dimension $d_{\text{model}}$ using a linear transformation (a dense layer without non-linearity):
$$
z_0^{(p)} = x_p \cdot E \quad \text{where} \quad E \in \mathbb{R}^{(P^2 \cdot C) \times d_{\text{model}}}
$$
Here, $x_p \in \mathbb{R}^{P^2 \cdot C}$ represents the $p$-th flattened patch, and $z_0^{(p)} \in \mathbb{R}^{d_{\text{model}}}$ is its embedding after projection. The matrix $E$ is learned during training and is shared by all patches. In practice, this linear projection is often implemented equivalently as a 2D convolution with kernel size $P \times P$, stride $P$, and $d_{\text{model}}$ output filters: because the stride equals the kernel size, each filter position covers exactly one patch.
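The equivalence between the strided convolution and the per-patch matrix $E$ can be checked numerically. The sketch below, assuming ViT-B/16 sizes, copies the convolution kernel into an `nn.Linear` layer and compares both paths. It flattens each patch channel-first, in $(C, P, P)$ order, because that is the order `F.unfold` produces and the order in which the kernel reshapes; any fixed flattening order works, as long as $E$ uses the same one.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
P, C, D = 16, 3, 768  # patch size, channels, d_model (ViT-B/16 values)

conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
linear = nn.Linear(P * P * C, D)

# Copy the conv kernel (D, C, P, P) into the dense layer as a (D, C*P*P) matrix.
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(D, -1))
    linear.bias.copy_(conv.bias)

x = torch.randn(2, C, 224, 224)

# Path 1: strided convolution, then flatten the 14 x 14 grid into a sequence.
out_conv = conv(x).flatten(2).transpose(1, 2)  # (2, 196, 768)

# Path 2: explicit patch extraction in (C, P, P) order, then a dense layer.
patches = F.unfold(x, kernel_size=P, stride=P)  # (2, C*P*P, 196)
out_linear = linear(patches.transpose(1, 2))    # (2, 196, 768)

print(torch.allclose(out_conv, out_linear, atol=1e-4))  # True
```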
Step 3 — Adding the [CLS] Classification Token
Borrowing from BERT, a Transformer model for natural language processing, a special vector $z_{\text{cls}}$ is inserted at the beginning of the embedding sequence. This token, called the [CLS] token, is a learnable vector of dimension $d_{\text{model}}$. Its role is to aggregate global information from the entire image through the attention layers.
$$
z_0 = [z_{\text{cls}}; z_0^{(1)}; z_0^{(2)}; \dots; z_0^{(N)}] \in \mathbb{R}^{(N+1) \times d_{\text{model}}}
$$
The complete sequence therefore contains $N + 1$ vectors: the classification token followed by the $N$ patch embeddings.
Step 4 — Positional Embeddings
Unlike convolutional networks, the attention mechanism is inherently agnostic to position: it has no built-in notion of spatial order, so shuffling the input patches would simply shuffle the outputs in the same way. To address this limitation, a positional embedding $E_{\text{pos}}$ is added to each vector in the sequence:
$$
z_0 \leftarrow z_0 + E_{\text{pos}} \quad \text{where} \quad E_{\text{pos}} \in \mathbb{R}^{(N+1) \times d_{\text{model}}}
$$
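The claim that attention alone has no notion of order can be verified directly. This sketch, assuming a small `nn.MultiheadAttention` layer and random stand-ins for the patch embeddings and for $E_{\text{pos}}$, reverses the token order and compares the outputs with and without positional embeddings.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_tokens = 64, 10
attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True).eval()

x = torch.randn(1, n_tokens, d_model)      # patch embeddings, without E_pos
e_pos = torch.randn(1, n_tokens, d_model)  # stand-in positional embeddings
perm = torch.arange(n_tokens).flip(0)      # a fixed shuffle: reverse the patch order


def self_attend(z):
    return attn(z, z, z, need_weights=False)[0]


with torch.no_grad():
    # Without E_pos: permuting the inputs only permutes the outputs.
    without_pos = torch.allclose(self_attend(x)[:, perm], self_attend(x[:, perm]), atol=1e-5)
    # With E_pos: position p always receives E_pos[p], so the outputs genuinely change.
    with_pos = torch.allclose(self_attend(x + e_pos)[:, perm],
                              self_attend(x[:, perm] + e_pos), atol=1e-5)

print(without_pos, with_pos)  # True False
```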
Positional embeddings are typically learned during training (learnable positional embeddings), although some variants use fixed sinusoidal embeddings inspired by the original Transformer from Vaswani et al.
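For the fixed alternative, here is a minimal sketch of the 1D sine/cosine scheme of Vaswani et al., assuming patches are numbered in raster order (left to right, top to bottom) with the [CLS] token at position 0. The function name is illustrative; 2D variants that encode rows and columns separately also exist.

```python
import math

import torch


def sinusoidal_position_embedding(n_positions: int, d_model: int) -> torch.Tensor:
    """Fixed embeddings: PE[pos, 2i] = sin(pos / 10000^(2i/d)),
    PE[pos, 2i+1] = cos(pos / 10000^(2i/d))."""
    assert d_model % 2 == 0, "d_model must be even"
    position = torch.arange(n_positions, dtype=torch.float32).unsqueeze(1)  # (n, 1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
    )  # (d/2,) frequencies 1 / 10000^(2i/d)
    pe = torch.zeros(n_positions, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
    return pe


# 196 patches + 1 [CLS] token for a 224 x 224 image with P = 16.
e_pos = sinusoidal_position_embedding(197, 768)
print(e_pos.shape)  # torch.Size([197, 768])
```

Because these embeddings are not trained, a model would typically store them with `register_buffer` rather than as an `nn.Parameter`.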
Step 5 — Passing Through the Transformer Encoder
The sequence $z_0$ then passes through $L$ identical Transformer encoder layers (same structure, separate weights). Each layer $\ell$ is composed of two fundamental sublayers:
Multi-Head Attention (MHA):
$$
z_\ell' = \text{MHA}(\text{LayerNorm}(z_{\ell-1})) + z_{\ell-1}
$$
Feed-Forward Network (MLP):
$$
z_\ell = \text{MLP}(\text{LayerNorm}(z_\ell')) + z_\ell'
$$
Each sublayer is preceded by layer normalization (LayerNorm, the so-called pre-norm arrangement) and wrapped in a residual connection; the MLP consists of two linear layers with a GELU non-linearity. The multi-head attention mechanism computes, for each pair of positions $(i, j)$ in the sequence, an attention score measuring how relevant token $j$ is to token $i$. Thanks to this global attention, each patch can interact with all other patches, including those at opposite ends of the image.
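The two equations of Step 5 translate almost line by line into a PyTorch module. This is a sketch rather than a reference implementation: `PreNormEncoderBlock` is a hypothetical name, and the dropout placement follows common practice. PyTorch's `nn.TransformerEncoderLayer` with `norm_first=True` implements the same pre-norm ordering.

```python
import torch
import torch.nn as nn


class PreNormEncoderBlock(nn.Module):
    """One encoder layer: z' = MHA(LN(z)) + z, then z = MLP(LN(z')) + z'."""

    def __init__(self, d_model=768, n_heads=12, mlp_ratio=4, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(  # two linear layers with a GELU in between
            nn.Linear(d_model, d_model * mlp_ratio),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * mlp_ratio, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, z):
        h = self.norm1(z)
        attn_out, _ = self.attn(h, h, h, need_weights=False)  # self-attention: Q = K = V
        z = z + attn_out                  # first residual connection
        z = z + self.mlp(self.norm2(z))   # second residual connection
        return z


block = PreNormEncoderBlock()
tokens = torch.randn(2, 197, 768)  # (batch, N + 1, d_model)
print(block(tokens).shape)  # torch.Size([2, 197, 768])
```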
Step 6 — Classification Head (MLP Head)
After the $L$ encoding layers, the final state of the [CLS] token, denoted $z_L^{(\text{cls})}$, is extracted and passed through a classification head typically consisting of a normalization layer (LayerNorm) followed by a linear layer projecting to the number of classes $K$:
$$
y = \text{MLP}(z_L^{(\text{cls})}) \in \mathbb{R}^K
$$
In classification, a softmax function is then applied to obtain a probability distribution over the $K$ possible classes:
$$
\hat{y} = \text{softmax}(y)
$$
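In practice the softmax is rarely applied inside the model during training. The sketch below, using random logits for a hypothetical batch of 4 images and $K = 10$ classes, shows that the softmax rows sum to 1 and that PyTorch's `F.cross_entropy` already combines log-softmax and negative log-likelihood, which is why a classifier should pass it the raw logits $y$ rather than $\hat{y}$.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)          # y for a batch of 4 images, K = 10 classes
labels = torch.tensor([3, 0, 7, 9])

probs = F.softmax(logits, dim=-1)    # y_hat: each row is a distribution over K classes
print(torch.allclose(probs.sum(dim=-1), torch.ones(4)))  # True

# F.cross_entropy applies log-softmax internally, so it must receive raw logits.
loss = F.cross_entropy(logits, labels)
manual = F.nll_loss(torch.log(probs), labels)
print(torch.allclose(loss, manual, atol=1e-6))  # True
```

Since softmax is monotonic, the predicted class is the same whether the argmax is taken over $y$ or over $\hat{y}$.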
Intuition: A Puzzle Analyzed Globally
To properly understand the intuition behind the Vision Transformer, let’s compare it to the traditional convolutional network (CNN) approach.
A CNN traverses the image with convolution filters of fixed size (e.g., $3 \times 3$ or $5 \times 5$). Each filter examines a small local neighborhood of pixels, detects simple patterns (edges, textures), and then deeper layers combine these patterns into more complex features. However, for a CNN to capture relationships between distant regions of the image, many layers must be stacked: the effective receptive field only grows progressively with network depth.
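The progressive growth of the receptive field can be made concrete with the standard recurrence $r_\ell = r_{\ell-1} + (k_\ell - 1)\, j_{\ell-1}$, where $k_\ell$ is the kernel size of layer $\ell$ and $j_{\ell-1}$ is the product of the strides of the earlier layers. A small sketch, ignoring padding and pooling details; `receptive_field` is a hypothetical helper.

```python
def receptive_field(layers):
    """Theoretical receptive field (in input pixels) of a stack of conv layers,
    each given as a (kernel_size, stride) pair."""
    r, jump = 1, 1  # a single pixel, cumulative stride of 1
    for k, s in layers:
        r += (k - 1) * jump
        jump *= s
    return r


print(receptive_field([(3, 1)] * 3))   # 7: three 3x3 convs see a 7x7 window
print(receptive_field([(3, 2)] * 5))   # 63: strides make it grow much faster
# Stride-1 3x3 convs need 112 layers before one output spans all 224 pixels of a row.
print(next(n for n in range(1, 500) if receptive_field([(3, 1)] * n) >= 224))  # 112
```

Real CNNs such as ResNet-50 downsample with strided convolutions and pooling, so their theoretical receptive field grows much faster than the stride-1 case suggests.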
The Vision Transformer, on the other hand, adopts a radically different strategy. Imagine cutting an image into pieces like a puzzle. Each piece (patch) becomes one token in the Transformer's input sequence. The key ingredient is the attention mechanism: each fragment can exchange information directly with all the others, regardless of their spatial distance in the image.
In a single attention layer, a patch in the top-left corner can exchange information with a patch in the bottom-right corner. This ability to model global relationships from the first layer is the fundamental advantage of ViT. Where a CNN must “see” progressively further through the stacking of layers, the Transformer sees everything from the start.
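Of course, this power comes at a price: the computational cost of multi-head attention grows quadratically with the number of patches, $O(N^2)$. With a fixed patch size, $N$ is proportional to the number of pixels, so doubling the image's height and width quadruples $N$ and multiplies the attention cost by about 16, which quickly becomes expensive for high-resolution images.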
Python Implementation with PyTorch
Here is a complete and pedagogical implementation of the Vision Transformer in PyTorch, step by step.
Imports and Basic Configuration
```python
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """Transforms the image into patches and projects each patch
    into the d_model space."""

    def __init__(self, img_size=224, patch_size=16, in_channels=3, d_model=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        # Conv2d with kernel_size = stride = patch_size: a linear projection per patch
        self.projection = nn.Conv2d(
            in_channels, d_model,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: (batch, channels, height, width)
        x = self.projection(x)  # (batch, d_model, n_h, n_w)
        x = x.flatten(2)        # (batch, d_model, n_patches)
        x = x.transpose(1, 2)   # (batch, n_patches, d_model)
        return x
```
Classification Token and Positional Embeddings
```python
class VisionTransformer(nn.Module):
    def __init__(
        self, img_size=224, patch_size=16, in_channels=3,
        d_model=768, n_heads=12, n_layers=12,
        mlp_ratio=4, dropout=0.1, n_classes=1000
    ):
        super().__init__()
        self.d_model = d_model
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, d_model
        )
        n_patches = self.patch_embed.n_patches

        # Learnable [CLS] token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        # Learnable positional embeddings (one per patch, plus one for [CLS])
        self.pos_embed = nn.Parameter(
            torch.zeros(1, n_patches + 1, d_model)
        )
        self.dropout = nn.Dropout(dropout)

        # Transformer encoder (L layers). norm_first=True gives the pre-norm
        # ordering of Step 5 (LayerNorm before MHA and before the MLP);
        # PyTorch's default, norm_first=False, would be post-norm.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * mlp_ratio,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=n_layers,
            enable_nested_tensor=False,  # the nested-tensor fast path does not support norm_first
        )

        # Classification head: final LayerNorm, then a linear layer to the K classes
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, n_classes)

        # Parameter initialization
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        batch_size = x.shape[0]
        # 1. Patch embedding
        x = self.patch_embed(x)  # (batch, n_patches, d_model)
        # 2. Prepend the CLS token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (batch, n_patches+1, d_model)
        # 3. Add positional embeddings
        x = x + self.pos_embed
        x = self.dropout(x)
        # 4. Pass through the Transformer encoder
        x = self.transformer_encoder(x)  # (batch, n_patches+1, d_model)
        # 5. Extract the CLS token and classify
        cls_output = x[:, 0]  # first token = CLS
        cls_output = self.norm(cls_output)
        logits = self.head(cls_output)  # raw logits; softmax is left to the loss or inference
        return logits
```
Minimal Training Function
```python
def train_vit(model, dataloader, criterion, optimizer, device):
    """Training loop for one epoch."""
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)  # e.g. nn.CrossEntropyLoss, which expects raw logits
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * images.size(0)
        _, predicted = logits.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    avg_loss = total_loss / total
    accuracy = 100.0 * correct / total
    return avg_loss, accuracy
```
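The training loop has a natural evaluation counterpart. The sketch below is an assumption about how one might write it, not part of the original implementation: it reports the same metrics but disables gradients and switches the model to `eval()` so dropout is turned off. The smoke test at the end uses a stand-in linear model and a plain list of batches in place of a real ViT and `DataLoader`.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def evaluate_vit(model, dataloader, criterion, device):
    """Evaluation pass: no gradients, no optimizer step, dropout disabled."""
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        total_loss += criterion(logits, labels).item() * images.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return total_loss / total, 100.0 * correct / total


# Smoke test with stand-ins: a linear classifier on 4 features (not a ViT)
# and a list of (inputs, labels) batches instead of a DataLoader.
torch.manual_seed(0)
toy_model = nn.Linear(4, 3)
toy_batches = [(torch.randn(5, 4), torch.randint(0, 3, (5,))) for _ in range(2)]
val_loss, val_acc = evaluate_vit(toy_model, toy_batches, nn.CrossEntropyLoss(), "cpu")
```

Remember to call `model.train()` again before the next training epoch; `train_vit` above already does this at its start.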
ResNet vs Vision Transformer Comparison
| Aspect | ResNet-50 (CNN) | ViT-B/16 (Transformer) |
|---|---|---|
| Fundamental operation | Local convolution | Global self-attention |
| Receptive field | Progressive (grows with depth) | Global from the 1st layer |
| Inductive bias | Strong (locality, translation equivariance) | Weak (few a priori assumptions) |
| Parameters | ~25 million | ~86 million |
| FLOPs (224 × 224 input) | ~4.1 G | ~17.6 G |
| Training data typically needed | Moderate (ImageNet-1k is enough) | Large (ImageNet-21K or JFT-300M pretraining in the original paper) |
| Parallelization | Good (localized convolutions) | Excellent (dense matrix multiplications) |
| Interpretability | Visualizable filters | Explicit attention maps |
This comparison reveals a fundamental trade-off: the ResNet, thanks to its strong inductive bias (the assumption that neighboring pixels are correlated), generalizes well with little data. The ViT, with its weaker inductive bias, needs much more data to learn these regularities by itself, but in return, it achieves superior performance once trained at scale.
Key Hyperparameters of the Vision Transformer
Tuning hyperparameters is crucial to achieving good performance with a Vision Transformer architecture. Here are the main levers to adjust.
Patch Size (patch_size)
The patch size $P$ determines the number of patches and thus the length of the input sequence. The most common values are 8, 14, 16, and 32.
- Smaller patch (P = 8): more patches, thus finer details captured, but a much higher cost: halving $P$ multiplies $N$ by 4 and the $O(N^2)$ attention cost by roughly 16.
- Larger patch (P = 16 or 32): fewer patches, thus faster computation, but loss of fine details.
In practice, P = 16 is an excellent compromise for most applications. The ViT-B/16 model (Base, patch 16) is the most widely used variant in the literature and in industrial applications.
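The effect of the patch size on sequence length and attention cost is simple arithmetic. This sketch counts tokens for a $224 \times 224$ input, including the [CLS] token, so the score count is $(N+1)^2$ per head and per layer; `vit_sequence_stats` is a hypothetical helper.

```python
def vit_sequence_stats(img_size: int, patch_size: int):
    """Return (N, number of attention scores per head and layer) for a square
    image, counting the extra [CLS] token (N + 1 tokens in total)."""
    assert img_size % patch_size == 0, "img_size must be a multiple of patch_size"
    n = (img_size // patch_size) ** 2
    return n, (n + 1) ** 2


for p in (32, 16, 14, 8):
    n, pairs = vit_sequence_stats(224, p)
    print(f"P={p:>2}: N={n:>3}, attention scores per head={pairs:,}")
# P=32: N= 49, attention scores per head=2,500
# P=16: N=196, attention scores per head=38,809
# P=14: N=256, attention scores per head=66,049
# P= 8: N=784, attention scores per head=616,225
```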
Hidden Dimension (hidden_size / d_model)
The dimension of the representation space, denoted $d_{\text{model}}$, controls the model’s expressive capacity. The standard configurations are:
- ViT-Tiny: $d_{\text{model}} = 192$, ~5 million parameters
- ViT-Small: $d_{\text{model}} = 384$, ~22 million parameters
- ViT-Base: $d_{\text{model}} = 768$, ~86 million parameters
- ViT-Large: $d_{\text{model}} = 1024$, ~307 million parameters
- ViT-Huge: $d_{\text{model}} = 1280$, ~632 million parameters
Increasing $d_{\text{model}}$ generally improves performance at a significantly higher computational and memory cost.
Number of Layers (n_layers)
The number of encoder layers $L$ determines the model’s depth. The usual values are:
- ViT-Tiny: 12 layers
- ViT-Small: 12 layers
- ViT-Base: 12 layers
- ViT-Large: 24 layers
- ViT-Huge: 32 layers
A higher number of layers allows the model to build increasingly abstract representations, similar to depth in CNNs. However, beyond a certain threshold, the marginal returns decrease and the risk of vanishing gradients increases (although residual connections mitigate this problem).
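The parameter counts quoted for each variant follow almost entirely from $d_{\text{model}}$ and $L$: each encoder layer contributes $12\,d_{\text{model}}^2 + 13\,d_{\text{model}}$ parameters. The sketch below counts parameters for the architecture implemented in this guide (patch size 16, MLP ratio 4, 1000 classes, learned positional embeddings); for the default configuration it should agree with `sum(p.numel() for p in model.parameters())` on the `VisionTransformer` class above. Published figures can differ slightly because they depend on details such as the classification head. `vit_param_count` is a hypothetical helper.

```python
def vit_param_count(d_model, n_layers, patch_size=16, img_size=224,
                    in_channels=3, mlp_ratio=4, n_classes=1000):
    """Parameters of a ViT with a conv patch embedding, [CLS] token, learned
    positional embeddings, pre-norm encoder layers, final LayerNorm and linear head."""
    d, n = d_model, (img_size // patch_size) ** 2
    patch_embed = in_channels * patch_size ** 2 * d + d   # conv kernel + bias
    cls_and_pos = d + (n + 1) * d                          # [CLS] + positional table
    attention = 3 * d * d + 3 * d + d * d + d              # Q/K/V projections + output projection
    mlp = 2 * mlp_ratio * d * d + mlp_ratio * d + d        # two linear layers with biases
    layer_norms = 2 * 2 * d                                # two LayerNorms (scale + shift)
    per_layer = attention + mlp + layer_norms
    head = 2 * d + d * n_classes + n_classes               # final LayerNorm + linear head
    return patch_embed + cls_and_pos + n_layers * per_layer + head


for name, d, layers in [("Tiny", 192, 12), ("Small", 384, 12),
                        ("Base", 768, 12), ("Large", 1024, 24)]:
    print(f"ViT-{name}/16: {vit_param_count(d, layers) / 1e6:.1f}M")
# ViT-Tiny/16: 5.7M
# ViT-Small/16: 22.1M
# ViT-Base/16: 86.6M
# ViT-Large/16: 304.3M
```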
Number of Attention Heads (n_heads)
The number of heads divides the $d_{\text{model}}$ space into independent subspaces, each learning a different type of attention relationship. Typically:
$$
d_{\text{head}} = \frac{d_{\text{model}}}{n_{\text{heads}}}
$$
For example, with $d_{\text{model}} = 768$ and $n_{\text{heads}} = 12$, each head operates in a 64-dimensional subspace. Too few heads limit the diversity of learned relationships, while too many shrink each head's subspace and dilute its capacity.
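The head split is just a reshape. The sketch below writes out multi-head self-attention explicitly, assuming ViT-B sizes and random projection matrices with biases omitted; the function name is illustrative. Each head computes scaled dot-product attention, $\text{softmax}(QK^\top / \sqrt{d_{\text{head}}})\,V$, in its own 64-dimensional subspace, and the heads are then concatenated back to $d_{\text{model}}$.

```python
import math

import torch


def multi_head_self_attention(x, w_qkv, w_out, n_heads):
    """x: (B, T, d_model); w_qkv: (d_model, 3*d_model); w_out: (d_model, d_model)."""
    b, t, d_model = x.shape
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
    d_head = d_model // n_heads                            # e.g. 768 // 12 = 64
    q, k, v = (x @ w_qkv).chunk(3, dim=-1)                 # each (B, T, d_model)
    # Split d_model into n_heads subspaces: (B, n_heads, T, d_head)
    q, k, v = (m.reshape(b, t, n_heads, d_head).transpose(1, 2) for m in (q, k, v))
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_head)   # (B, n_heads, T, T)
    weights = scores.softmax(dim=-1)                       # one attention map per head
    heads = weights @ v                                    # (B, n_heads, T, d_head)
    # Concatenate the heads back into d_model, then mix them with the output projection
    out = heads.transpose(1, 2).reshape(b, t, d_model) @ w_out
    return out, weights


torch.manual_seed(0)
d_model, n_heads = 768, 12
x = torch.randn(2, 197, d_model)
w_qkv = torch.randn(d_model, 3 * d_model) / math.sqrt(d_model)
w_out = torch.randn(d_model, d_model) / math.sqrt(d_model)
out, weights = multi_head_self_attention(x, w_qkv, w_out, n_heads)
print(out.shape, weights.shape)  # torch.Size([2, 197, 768]) torch.Size([2, 12, 197, 197])
```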
Dropout Rate (dropout)
Dropout regularization is essential to avoid overfitting, especially when the ViT is trained from scratch. Commonly used values are:
- Training from scratch: dropout = 0.1 to 0.3
- Fine-tuning: dropout = 0.0 to 0.1 (less necessary since the model is already pretrained)
Stochastic depth (randomly skipping the residual branches of entire layers during training) is also often used as a complementary regularization technique, particularly for very deep models.
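A minimal sketch of stochastic depth, similar in spirit to the `DropPath` module found in libraries such as timm but written here from scratch (the class name and `drop_prob` are our choices): during training it zeroes a residual branch for a random subset of samples and rescales the others, so the layer is effectively skipped for those samples while the expected output stays unchanged.

```python
import torch
import torch.nn as nn


class DropPath(nn.Module):
    """Per-sample stochastic depth for a residual branch; identity at inference."""

    def __init__(self, drop_prob: float = 0.1):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dimensions.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.empty(mask_shape, dtype=x.dtype, device=x.device).bernoulli_(keep_prob)
        return x * mask / keep_prob  # rescale survivors so E[output] = x


# Inside a pre-norm block it wraps each branch: z = z + drop_path(attn(norm1(z)))
drop_path = DropPath(drop_prob=0.5)
branch = torch.ones(8, 197, 768)  # stand-in for an attention or MLP branch output
out = drop_path(branch)           # training mode: each sample is all 0s or all 2s
drop_path.eval()
same = torch.equal(drop_path(branch), branch)  # identity at inference
```

In deep models the drop probability is often increased linearly with depth, so the last layers are skipped most often.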
Advantages and Limitations of the Vision Transformer
Advantages
- Global modeling of spatial relationships: Unlike CNNs whose receptive field is local and only widens progressively, the ViT’s attention mechanism captures dependencies between all pairs of regions in the image, from the first layer. This holistic view is particularly powerful for understanding the global composition of a scene.
- Excellent scalability: Vision Transformer performance keeps improving as training data and model size grow. The original paper reports that with pretraining on JFT-300M (about 300 million images), ViT matches or outperforms the best convolutional architectures of the time.
- Interpretability via attention maps: Attention weights between patches naturally provide an intuitive visualization of what the model “looks at” to make its decision. Individual attention heads can often be interpreted as focusing on different aspects of the image.
- Architectural unification: ViT uses the same fundamental building block (the Transformer encoder) as language models. This unification enables the design of multimodal architectures naturally combining text and images, such as CLIP, Flamingo, or image generation models based on Transformers.
- Strong transfer learning: Once pretrained on a large image corpus, ViT transfers remarkably well to target tasks, even with little annotated data available.
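To turn raw attention weights into the kind of maps mentioned above, one common technique is attention rollout (Abnar and Zuidema, 2020): average the heads of each layer, add the identity to account for the residual connections, renormalize, and multiply the layers together. The sketch below uses random stand-in attention maps with ViT-B/16 shapes; with a real model you would need to capture the per-head attention weights yourself, for example with a custom encoder block, since `nn.TransformerEncoderLayer` does not return them.

```python
import torch


def attention_rollout(attentions):
    """attentions: list of L tensors (n_heads, T, T), rows already softmax-normalized.
    Returns a (T, T) map of how much each output token draws from each input token."""
    t = attentions[0].shape[-1]
    rollout = torch.eye(t)
    for attn in attentions:
        a = attn.mean(dim=0)                  # average the heads
        a = a + torch.eye(t)                  # account for the residual connection
        a = a / a.sum(dim=-1, keepdim=True)   # re-normalize rows
        rollout = a @ rollout                 # compose with the earlier layers
    return rollout


# Random stand-in attention maps for ViT-B/16: 12 layers, 12 heads, 197 tokens.
torch.manual_seed(0)
maps = [torch.randn(12, 197, 197).softmax(dim=-1) for _ in range(12)]
rollout = attention_rollout(maps)
# Row 0 = relevance of each input token for [CLS]; drop [CLS] itself, reshape to 14 x 14.
cls_relevance = rollout[0, 1:].reshape(14, 14)
print(cls_relevance.shape)  # torch.Size([14, 14])
```

Upsampling `cls_relevance` to $224 \times 224$ and overlaying it on the input image gives the familiar heat-map visualization.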
Limitations
- Massive data requirement: Without pretraining on very large corpora (JFT-300M, ImageNet-21K) or strong data augmentation and regularization, ViT tends to generalize worse than CNNs of comparable size. Its weak inductive bias is a disadvantage when data is scarce: it does not “know” a priori that neighboring pixels are correlated or that translation preserves semantics.
- Quadratic computational complexity: The attention matrix contains a score for every pair of tokens, implying a cost of $O(N^2 \cdot d_{\text{model}})$ per layer. For a $224 \times 224$ image with $P = 16$, this represents $196^2 = 38\,416$ pairs (about 38,000) per head and per layer. For high-resolution images, this cost becomes prohibitive.
- Less performant on dense tasks: For dense prediction tasks such as semantic segmentation or object detection, convolutional architectures have often retained an advantage in precision and contour fineness, although Transformer-based designs such as Segmenter or MaskFormer are gradually closing this gap.
- Sensitivity to adversarial perturbations: Some studies have reported that vision transformers can be more vulnerable to adversarial attacks than CNNs in certain scenarios, particularly when perturbations exploit the global attention structure, although the evidence on this point remains mixed.
4 Concrete Use Cases of the Vision Transformer
Use Case 1: Medical Image Classification
The Vision Transformer performs well in classifying radiographic images, MRI slices, and dermatological photographs. Its ability to capture global relationships is particularly relevant: in a chest X-ray, an anomaly in one region can be correlated with another distant region of the lung, a relationship that a classical CNN might miss if its receptive field is not wide enough. Several research groups have reported ViT-based models that match or exceed CNN baselines on chest X-ray benchmarks such as CheXpert and MIMIC-CXR, for example for detecting pneumonia and other thoracic pathologies.
Use Case 2: Satellite Scene Classification
Satellite images often cover vast geographical areas where objects of interest are spread across the entire scene. The Vision Transformer, with its global attention, is naturally suited to this configuration: it can establish correlations between distant regions to classify the type of scene (urban area, forest, agricultural zone, body of water). Specialized variants like SatMAE also integrate multi-temporal and multi-spectral data for environmental analysis.
Use Case 3: Document Recognition and Intelligent OCR
For understanding scanned documents (invoices, forms, contracts), ViT allows simultaneous analysis of the global visual structure (layout, columns, tables) and local characteristics (fonts, logos, signatures). Coupled with a language model for processing extracted text, the Vision Transformer forms the basis of end-to-end document processing systems that can outperform traditional rule-based pipelines.
Use Case 4: Multimedia Content Verification and Deepfake Detection
Detecting deepfakes and image manipulations is an area where the Transformer’s global vision offers significant advantages. Generation artifacts left by GANs or diffusion models are often subtle and distributed across the entire image. By analyzing global correlations between patches, a Vision Transformer can detect inconsistencies that escape a CNN focused on local patterns. Some recent studies report ViT-based detectors reaching accuracies above 95% on specific deepfake detection benchmarks, although performance often drops on manipulation methods not seen during training.