Learn AI Series (#54) - Vision Transformers
What will I learn
- You will learn how Vision Transformers (ViT) treat images as sequences of patches instead of using convolutions;
- patch embedding -- converting a 2D image into a 1D sequence of vectors using a single convolution trick;
- the CLS token -- a learnable summary vector that aggregates information from all patches;
- position embeddings for images and why 1D embeddings learn 2D spatial structure on their own;
- when ViTs beat CNNs and when they don't -- the data hunger tradeoff;
- DeiT and data-efficient training strategies that closed the gap;
- hybrid architectures that combine CNN backbones with transformer heads;
- the Swin Transformer and why windowed attention is O(n) instead of O(n^2).
Requirements
- A working modern computer running macOS, Windows or Ubuntu;
- An installed Python 3(.11+) distribution;
- The ambition to learn AI and machine learning.
Difficulty
- Beginner
Curriculum (of the Learn AI Series):
- Learn AI Series (#1) - What Machine Learning Actually Is
- Learn AI Series (#2) - Setting Up Your AI Workbench - Python and NumPy
- Learn AI Series (#3) - Your Data Is Just Numbers - How Machines See the World
- Learn AI Series (#4) - Your First Prediction - No Math, Just Intuition
- Learn AI Series (#5) - Patterns in Data - What "Learning" Actually Looks Like
- Learn AI Series (#6) - From Intuition to Math - Why We Need Formulas
- Learn AI Series (#7) - The Training Loop - See It Work Step by Step
- Learn AI Series (#8) - The Math You Actually Need (Part 1) - Linear Algebra
- Learn AI Series (#9) - The Math You Actually Need (Part 2) - Calculus and Probability
- Learn AI Series (#10) - Your First ML Model - Linear Regression From Scratch
- Learn AI Series (#11) - Making Linear Regression Real
- Learn AI Series (#12) - Classification - Logistic Regression From Scratch
- Learn AI Series (#13) - Evaluation - How to Know If Your Model Actually Works
- Learn AI Series (#14) - Data Preparation - The 80% Nobody Talks About
- Learn AI Series (#15) - Feature Engineering and Selection
- Learn AI Series (#16) - Scikit-Learn - The Standard Library of ML
- Learn AI Series (#17) - Decision Trees - How Machines Make Decisions
- Learn AI Series (#18) - Random Forests - Wisdom of Crowds
- Learn AI Series (#19) - Gradient Boosting - The Kaggle Champion
- Learn AI Series (#20) - Support Vector Machines - Drawing the Perfect Boundary
- Learn AI Series (#21) - Mini Project - Predicting Crypto Market Regimes
- Learn AI Series (#22) - K-Means Clustering - Finding Groups
- Learn AI Series (#23) - Advanced Clustering - Beyond K-Means
- Learn AI Series (#24) - Dimensionality Reduction - PCA
- Learn AI Series (#25) - Advanced Dimensionality Reduction - t-SNE and UMAP
- Learn AI Series (#26) - Anomaly Detection - Finding What Doesn't Belong
- Learn AI Series (#27) - Recommendation Systems - "Users Like You Also Liked..."
- Learn AI Series (#28) - Time Series Fundamentals - When Order Matters
- Learn AI Series (#29) - Time Series Forecasting - Predicting What Comes Next
- Learn AI Series (#30) - Natural Language Processing - Text as Data
- Learn AI Series (#31) - Word Embeddings - Meaning in Numbers
- Learn AI Series (#32) - Bayesian Methods - Thinking in Probabilities
- Learn AI Series (#33) - Ensemble Methods Deep Dive - Stacking and Blending
- Learn AI Series (#34) - ML Engineering - From Notebook to Production
- Learn AI Series (#35) - Data Ethics and Bias in ML
- Learn AI Series (#36) - Mini Project - Complete ML Pipeline
- Learn AI Series (#37) - The Perceptron - Where It All Started
- Learn AI Series (#38) - Neural Networks From Scratch - Forward Pass
- Learn AI Series (#39) - Neural Networks From Scratch - Backpropagation
- Learn AI Series (#40) - Training Neural Networks - Practical Challenges
- Learn AI Series (#41) - Optimization Algorithms - SGD, Momentum, Adam
- Learn AI Series (#42) - PyTorch Fundamentals - Tensors and Autograd
- Learn AI Series (#43) - PyTorch Data and Training
- Learn AI Series (#44) - PyTorch nn.Module - Building Real Networks
- Learn AI Series (#45) - Convolutional Neural Networks - Theory
- Learn AI Series (#46) - CNNs in Practice - Classic to Modern Architectures
- Learn AI Series (#47) - CNN Applications - Detection, Segmentation, Style Transfer
- Learn AI Series (#48) - Recurrent Neural Networks - Sequences
- Learn AI Series (#49) - LSTM and GRU - Solving the Memory Problem
- Learn AI Series (#50) - Sequence-to-Sequence Models
- Learn AI Series (#51) - Attention Mechanisms
- Learn AI Series (#52) - The Transformer Architecture (Part 1)
- Learn AI Series (#53) - The Transformer Architecture (Part 2)
- Learn AI Series (#54) - Vision Transformers (this post)
Solutions to Episode #53 Exercises
Exercise 1: Training loop for the complete Transformer on sequence reversal.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import TensorDataset, DataLoader

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_out = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        batch = Q.size(0)
        Q = self.W_q(Q).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        out = (weights @ V).transpose(1, 2).contiguous().view(
            batch, -1, self.n_heads * self.d_k)
        return self.W_out(out)

def make_causal_mask(seq_len):
    return torch.tril(torch.ones(seq_len, seq_len))

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() *
                        (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class TransformerEncoderBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        x = self.norm1(x + self.drop(self.attn(x, x, x, mask)))
        x = self.norm2(x + self.drop(self.ff(x)))
        return x

class TransformerDecoderBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.cross_attn = MultiHeadAttention(d_model, n_heads)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, enc_out, causal_mask=None, src_mask=None):
        sa = self.self_attn(x, x, x, causal_mask)
        x = self.norm1(x + self.drop(sa))
        ca = self.cross_attn(x, enc_out, enc_out, src_mask)
        x = self.norm2(x + self.drop(ca))
        x = self.norm3(x + self.drop(self.ff(x)))
        return x

class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=64, n_heads=4,
                 n_enc=2, n_dec=2, d_ff=128, max_len=512):
        super().__init__()
        self.d_model = d_model
        self.src_emb = nn.Embedding(src_vocab, d_model)
        self.tgt_emb = nn.Embedding(tgt_vocab, d_model)
        self.pe = PositionalEncoding(d_model, max_len)
        self.enc_layers = nn.ModuleList(
            [TransformerEncoderBlock(d_model, n_heads, d_ff) for _ in range(n_enc)])
        self.dec_layers = nn.ModuleList(
            [TransformerDecoderBlock(d_model, n_heads, d_ff) for _ in range(n_dec)])
        self.out_proj = nn.Linear(d_model, tgt_vocab)
        self.scale = math.sqrt(d_model)

    def encode(self, src, mask=None):
        x = self.pe(self.src_emb(src) * self.scale)
        for layer in self.enc_layers:
            x = layer(x, mask)
        return x

    def decode(self, tgt, enc_out, causal_mask, src_mask=None):
        x = self.pe(self.tgt_emb(tgt) * self.scale)
        for layer in self.dec_layers:
            x = layer(x, enc_out, causal_mask, src_mask)
        return x

    def forward(self, src, tgt):
        causal_mask = make_causal_mask(tgt.size(1)).to(tgt.device)
        enc_out = self.encode(src)
        dec_out = self.decode(tgt, enc_out, causal_mask)
        return self.out_proj(dec_out)

# Training loop
torch.manual_seed(42)
vocab_sz = 20
n_train, n_test = 3000, 500
seq_len = 10
src = torch.randint(2, vocab_sz, (n_train + n_test, seq_len))
tgt = src.flip(1)
# Prepend SOS token (id=0) to target
sos = torch.zeros(n_train + n_test, 1, dtype=torch.long)
tgt_input = torch.cat([sos, tgt[:, :-1]], dim=1)  # shifted right
tgt_output = tgt  # what we predict
X_tr, X_te = src[:n_train], src[n_train:]
y_in_tr, y_in_te = tgt_input[:n_train], tgt_input[n_train:]
y_out_tr, y_out_te = tgt_output[:n_train], tgt_output[n_train:]
loader = DataLoader(TensorDataset(X_tr, y_in_tr, y_out_tr),
                    batch_size=64, shuffle=True)
model = Transformer(src_vocab=vocab_sz, tgt_vocab=vocab_sz,
                    d_model=64, n_heads=4, n_enc=2, n_dec=2, d_ff=128)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(30):
    model.train()
    total_loss = 0
    for sb, tb_in, tb_out in loader:
        logits = model(sb, tb_in)
        loss = nn.CrossEntropyLoss()(logits.reshape(-1, vocab_sz),
                                     tb_out.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        total_loss += loss.item()
    if epoch % 10 == 0 or epoch == 29:
        model.eval()
        with torch.no_grad():
            test_logits = model(X_te, y_in_te)
            preds = test_logits.argmax(-1)
            tok_acc = (preds == y_out_te).float().mean().item()
            seq_acc = (preds == y_out_te).all(dim=1).float().mean().item()
        print(f"Epoch {epoch:>2d}: tok_acc={tok_acc:.1%}, seq_acc={seq_acc:.1%}")
```
On this task the transformer typically reaches high token accuracy (often 95%+) and respectable sequence accuracy within 30 epochs. Keep in mind that this evaluation uses teacher forcing: the decoder is fed the correct previous tokens, so the numbers are optimistic. Exercise 2 measures free-running decoding instead. Compared to the LSTM seq2seq from episode #50, the transformer usually converges faster and reaches higher final accuracy. It has direct access to all source positions through cross-attention at every layer, while the LSTM compresses everything through a sequential hidden state chain.
Exercise 2: Greedy decoding and beam search for the trained Transformer.
```python
@torch.no_grad()
def greedy_decode(model, src, max_len=10, sos_id=0):
    model.eval()
    enc_out = model.encode(src)
    tgt_ids = torch.full((src.size(0), 1), sos_id, dtype=torch.long)
    for _ in range(max_len):
        mask = make_causal_mask(tgt_ids.size(1)).to(src.device)
        dec_out = model.decode(tgt_ids, enc_out, mask)
        logits = model.out_proj(dec_out[:, -1, :])
        next_tok = logits.argmax(-1, keepdim=True)
        tgt_ids = torch.cat([tgt_ids, next_tok], dim=1)
    return tgt_ids[:, 1:]  # strip SOS

@torch.no_grad()
def beam_search(model, src_single, beam_width=3, max_len=10, sos_id=0):
    model.eval()
    enc_out = model.encode(src_single.unsqueeze(0))
    beams = [(torch.tensor([[sos_id]]), 0.0)]
    for step in range(max_len):
        candidates = []
        for seq, score in beams:
            mask = make_causal_mask(seq.size(1))
            dec_out = model.decode(seq, enc_out, mask)
            logits = model.out_proj(dec_out[:, -1, :])
            log_probs = F.log_softmax(logits, dim=-1)
            topk_lp, topk_idx = log_probs.topk(beam_width)
            for i in range(beam_width):
                new_seq = torch.cat([seq, topk_idx[:, i:i+1]], dim=1)
                new_score = score + topk_lp[0, i].item()
                candidates.append((new_seq, new_score))
        beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_width]
    best_seq = beams[0][0][:, 1:]
    return best_seq.squeeze(0)

# Compare on 20 test sequences
test_src = X_te[:20]
test_tgt = y_out_te[:20]
greedy_preds = greedy_decode(model, test_src, max_len=seq_len)
greedy_correct = (greedy_preds == test_tgt).all(dim=1).sum().item()
beam_correct = 0
for i in range(20):
    bp = beam_search(model, test_src[i], beam_width=3, max_len=seq_len)
    if (bp[:seq_len] == test_tgt[i]).all():
        beam_correct += 1
print(f"Greedy: {greedy_correct}/20 perfect sequences")
print(f"Beam (width=3): {beam_correct}/20 perfect sequences")

# Side-by-side for 5 examples
print(f"\n{'Source':>30s} {'Target':>30s} {'Greedy':>30s}")
for i in range(5):
    s = test_src[i].tolist()
    t = test_tgt[i].tolist()
    g = greedy_preds[i].tolist()
    print(f"{str(s):>30s} {str(t):>30s} {str(g):>30s}")
```
For the short reversal task (length 10), greedy and beam search perform very similarly. Beam search shines on longer or more ambiguous sequences where early token mistakes cascade -- the ability to maintain multiple hypotheses lets it recover from locally suboptimal choices.
Exercise 3: KV cache for faster inference.
```python
import time

@torch.no_grad()
def greedy_no_cache(model, src, max_len=50, sos_id=0):
    """Standard greedy: re-runs the decoder on the full target at every step."""
    model.eval()
    enc_out = model.encode(src)
    tgt_ids = torch.full((src.size(0), 1), sos_id, dtype=torch.long)
    for _ in range(max_len):
        mask = make_causal_mask(tgt_ids.size(1))
        dec_out = model.decode(tgt_ids, enc_out, mask)
        logits = model.out_proj(dec_out[:, -1, :])
        next_tok = logits.argmax(-1, keepdim=True)
        tgt_ids = torch.cat([tgt_ids, next_tok], dim=1)
    return tgt_ids[:, 1:]

@torch.no_grad()
def greedy_with_cache(model, src, max_len=50, sos_id=0):
    """Cached: stores self-attention K,V from previous steps, only computes the new token."""
    model.eval()
    enc_out = model.encode(src)
    batch = src.size(0)
    tgt_ids = torch.full((batch, 1), sos_id, dtype=torch.long)
    # One cache per decoder layer for the self-attention keys/values
    cached_kv = [{} for _ in model.dec_layers]
    for step in range(max_len):
        # Only embed the newest token (it sits at position `step`)
        if step == 0:
            x = model.pe(model.tgt_emb(tgt_ids) * model.scale)
        else:
            new_tok = tgt_ids[:, -1:]
            emb = model.tgt_emb(new_tok) * model.scale
            emb = emb + model.pe.pe[:, step:step+1]
            x = emb
        for i, layer in enumerate(model.dec_layers):
            # Self-attention: project only the new token, append its K,V to the cache
            sq = layer.self_attn.W_q(x)
            sk = layer.self_attn.W_k(x)
            sv = layer.self_attn.W_v(x)
            if 'self_k' not in cached_kv[i]:
                cached_kv[i]['self_k'] = sk
                cached_kv[i]['self_v'] = sv
            else:
                cached_kv[i]['self_k'] = torch.cat(
                    [cached_kv[i]['self_k'], sk], dim=1)
                cached_kv[i]['self_v'] = torch.cat(
                    [cached_kv[i]['self_v'], sv], dim=1)
            # Attention using all cached keys/values (no causal mask needed:
            # the cache only holds past and current positions)
            d_k = layer.self_attn.d_k
            n_h = layer.self_attn.n_heads
            Q = sq.view(batch, -1, n_h, d_k).transpose(1, 2)
            K = cached_kv[i]['self_k'].view(batch, -1, n_h, d_k).transpose(1, 2)
            V = cached_kv[i]['self_v'].view(batch, -1, n_h, d_k).transpose(1, 2)
            scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
            weights = F.softmax(scores, dim=-1)
            sa_out = (weights @ V).transpose(1, 2).contiguous().view(
                batch, -1, n_h * d_k)
            sa_out = layer.self_attn.W_out(sa_out)
            x = layer.norm1(x + sa_out)
            # Cross-attention: the encoder output is constant, but its K,V
            # projections are recomputed here (caching them too would save more)
            ca = layer.cross_attn(x, enc_out, enc_out)
            x = layer.norm2(x + ca)
            ff = layer.ff(x)
            x = layer.norm3(x + ff)
        logits = model.out_proj(x[:, -1:, :])
        next_tok = logits.argmax(-1)
        tgt_ids = torch.cat([tgt_ids, next_tok], dim=1)
    return tgt_ids[:, 1:]

# Timing comparison
dummy_src = torch.randint(2, 20, (1, 10))
t0 = time.perf_counter()
for _ in range(5):
    _ = greedy_no_cache(model, dummy_src, max_len=50)
t_no = (time.perf_counter() - t0) / 5
t0 = time.perf_counter()
for _ in range(5):
    _ = greedy_with_cache(model, dummy_src, max_len=50)
t_cached = (time.perf_counter() - t0) / 5
print(f"No cache: {t_no*1000:.1f}ms")
print(f"With cache: {t_cached*1000:.1f}ms")
print(f"Speedup: {t_no/t_cached:.2f}x")
```
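The KV cache avoids recomputing key-value projections for all previous tokens at every step. Without it, generating token 50 means running the full decoder on a sequence of length 50. With it, you only compute the new token's projections and attend using the cached keys and values. The work saved grows with generation length, so in principle the speedup grows too. For a tiny model like this one on a CPU, though, Python and per-operation overhead can dominate, so the measured wall-clock speedup may be much smaller than the theoretical one. This implementation also still recomputes the cross-attention projections of the encoder output at every step. Production inference systems (vLLM, TensorRT-LLM, etc.) all rely on KV caching because it is one of the most important optimizations for autoregressive generation.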
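To see where the theoretical savings come from, here's a back-of-the-envelope count. It is a minimal sketch: `decoder_positions_processed` is a hypothetical helper, and it only counts how many token positions pass through the decoder's projections and feed-forward layers. Attention-score computation, cross-attention and Python overhead are ignored.

```python
def decoder_positions_processed(n_tokens, use_cache):
    """Count token positions pushed through the decoder's projections and
    feed-forward layers while greedily generating n_tokens tokens."""
    if use_cache:
        # Each step only processes the single newest position
        return n_tokens
    # Without a cache, step t (1-indexed) re-processes all t positions so far
    return sum(range(1, n_tokens + 1))  # = n_tokens * (n_tokens + 1) / 2

for n in (10, 50, 200):
    no_cache = decoder_positions_processed(n, use_cache=False)
    cached = decoder_positions_processed(n, use_cache=True)
    print(f"n={n:>3d}: no cache={no_cache:>6,}  cached={cached:>4,}  "
          f"ratio={no_cache / cached:.1f}x")
```

The ratio is (n + 1) / 2, so by this count the savings grow linearly with generation length. Whether you actually see that in wall-clock time depends on how much of the runtime is real arithmetic rather than overhead.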
On to today's episode
Here we go! For the past three episodes we've been deep in the transformer architecture -- episode #52 built the encoder from scratch (scaled dot-product attention, multi-head attention, positional encoding, feed-forward layers), and episode #53 added the decoder (masked self-attention, cross-attention) and assembled the full encoder-decoder transformer. We saw how this architecture replaced RNNs almost entirely and became the foundation of most of today's major AI systems.
But here's something that might seem surprising. For the first few years after "Attention Is All You Need" was published in 2017, transformers were a language thing. CNNs still owned computer vision. And honestly, that made sense -- CNNs have built-in inductive biases (translation invariance, local connectivity) that match how images work. A cat in the top-left corner looks the same as a cat in the bottom-right corner. Nearby pixels are more related than distant pixels. CNNs encode these priors directly in their architecture (we covered this in episodes #45-47).
Transformers, on the other hand, were designed for sequences. They have no built-in concept of locality and no translation invariance. They treat the input as a set of tokens, and any notion of position has to be added explicitly through position encodings or embeddings. Why would you use something like that on images?
The answer, as it turned out, was: because they scale better. And scaling, as we've been seeing throughout this series, is the secret sauce behind practically every breakthrough in modern AI ;-)
The core idea: images as sequences
In October 2020, a team at Google published a paper with the straightforward title "An Image Is Worth 16x16 Words." It introduced the Vision Transformer (ViT), and the idea was almost aggressively simple: chop the image into patches, treat each patch as a "word," and feed the resulting sequence to a standard transformer encoder. No convolutions. No pooling. Just attention.
Let's think about why you can't just flatten the raw pixels. A 224x224 pixel image with 3 color channels is a 224x224x3 tensor -- 150,528 values. You could flatten this into a single sequence of 150,528 tokens, but self-attention is O(n^2) in sequence length. That's roughly 22.7 billion attention scores per head, per layer. Not practical, even on modern GPUs.
The ViT solution: divide the image into non-overlapping square patches. A 224x224 image with 16x16 patches gives you (224/16)^2 = 196 patches. Each patch is 16x16x3 = 768 values. Now your sequence length is 196 -- very manageable. That's the same order of magnitude as a typical sentence.
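The arithmetic above is easy to check, and it's instructive to see how the patch size trades sequence length against patch dimensionality. Here's a small sketch. `vit_sequence_stats` is a hypothetical helper that assumes square images whose side is divisible by the patch size, and it counts attention scores per head, per layer.

```python
def vit_sequence_stats(img_size=224, patch_size=16, channels=3):
    """Sequence length and attention cost: every value as a token vs patches."""
    n_values = img_size * img_size * channels      # raw values, one token each
    n_patches = (img_size // patch_size) ** 2       # non-overlapping square patches
    patch_dim = patch_size * patch_size * channels  # values in one flattened patch
    return {
        "pixel_tokens": n_values,
        "pixel_attn_scores": n_values ** 2,         # per head, per layer
        "n_patches": n_patches,
        "patch_dim": patch_dim,
        "patch_attn_scores": n_patches ** 2,        # per head, per layer
    }

for p in (32, 16, 8):
    s = vit_sequence_stats(patch_size=p)
    print(f"patch {p:>2d}x{p:<2d}: {s['n_patches']:>3d} patches of {s['patch_dim']:>4d} "
          f"values, {s['patch_attn_scores']:>7,} attention scores")

raw = vit_sequence_stats()
print(f"raw values: {raw['pixel_tokens']:,} tokens, "
      f"{raw['pixel_attn_scores']:,} attention scores")
```

Halving the patch size quadruples the number of patches, which multiplies the attention cost by 16. That's why 16x16 became the common default for 224x224 inputs.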
Each patch gets linearly projected to d_model dimensions, creating a patch embedding that serves the same role as a word embedding in language transformers:
```python
import torch
import torch.nn as nn
import math

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, d_model=768):
        super().__init__()
        self.n_patches = (img_size // patch_size) ** 2
        # Conv2d with kernel=stride=patch_size: non-overlapping patches
        self.proj = nn.Conv2d(
            in_channels, d_model,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: (batch, channels, height, width)
        x = self.proj(x)                  # (batch, d_model, n_h, n_w)
        x = x.flatten(2).transpose(1, 2)  # (batch, n_patches, d_model)
        return x

pe = PatchEmbedding()
img = torch.randn(1, 3, 224, 224)
patches = pe(img)
print(f"Image shape: {img.shape}")
print(f"Patch embeddings: {patches.shape}")  # (1, 196, 768)
print(f"Each patch: 16x16x3 = {16*16*3} values -> {patches.shape[-1]}-dim vector")
```
The Conv2d with kernel_size=patch_size and stride=patch_size is a neat trick. Because the kernel never overlaps a neighboring patch, it processes each patch independently, projecting every 16x16x3 patch into a 768-dimensional vector. Mathematically this is identical to flattening each patch to 768 values and multiplying by a weight matrix. The convolution is simply a convenient way to do it in one call, backed by PyTorch's optimized conv kernels. (The patch size, 16x16x3, and d_model both happen to equal 768 in the ViT-Base/16 configuration. That's a coincidence: d_model is a free hyperparameter.) The output is the same either way -- 196 vectors of dimension d_model.
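You can verify the "functionally identical" claim directly. The sketch below copies the conv's weights into an `nn.Linear`, cuts the image into flattened patches with `F.unfold`, and compares the two paths. It assumes the ViT-Base/16 sizes (16x16 patches, 3 channels, d_model=768). The patch ordering of `F.unfold` (channel-major within a patch, patches row by row) matches the layout of `conv.weight.reshape(d_model, -1)`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
patch_size, in_ch, d_model = 16, 3, 768
conv = nn.Conv2d(in_ch, d_model, kernel_size=patch_size, stride=patch_size)

# Same parameters viewed as a Linear layer over flattened patches:
# conv.weight is (d_model, in_ch, p, p) -> reshape to (d_model, in_ch * p * p)
linear = nn.Linear(in_ch * patch_size * patch_size, d_model)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(d_model, -1))
    linear.bias.copy_(conv.bias)

img = torch.randn(2, in_ch, 224, 224)

# Path 1: convolution, then flatten the 14x14 grid into a sequence
out_conv = conv(img).flatten(2).transpose(1, 2)           # (2, 196, 768)

# Path 2: extract non-overlapping patches explicitly, then apply the Linear layer
patches = F.unfold(img, kernel_size=patch_size, stride=patch_size)  # (2, 768, 196)
out_linear = linear(patches.transpose(1, 2))               # (2, 196, 768)

print(out_conv.shape, out_linear.shape)
print("max abs difference:", (out_conv - out_linear).abs().max().item())
```

The difference should be at floating-point rounding level. Which path you use is purely an implementation choice.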
The CLS token and position embeddings
ViT borrows a trick straight from BERT (which we'll study in a future episode): prepend a special learnable [CLS] token to the patch sequence. After passing through the transformer, the CLS token's output serves as the image representation for classification. The idea is that through self-attention, the CLS token can aggregate information from all patches into a single vector -- it "attends to everything" and builds a holistic summary of the image.
Position embeddings tell the transformer where each patch sits in the image. And here's something I found genuinely surprising when I first read the paper: ViT uses standard learned 1D position embeddings (not sinusoidal, not 2D). Just a (197, 768) parameter matrix -- one row per position (196 patches + 1 CLS token) -- that gets added to the patch embeddings:
```python
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        # Pre-norm: normalize once, use the result as query, key and value
        h = self.norm1(x)
        x = x + self.attn(h, h, h)[0]
        x = x + self.ff(self.norm2(x))
        return x

class ViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_ch=3, d_model=768,
                 n_heads=12, n_layers=12, n_classes=1000, d_ff=3072):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_ch, d_model)
        n_patches = self.patch_embed.n_patches
        # Learnable CLS token and position embeddings
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
        self.pos_embed = nn.Parameter(
            torch.randn(1, n_patches + 1, d_model) * 0.02)
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff) for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, n_classes)

    def forward(self, x):
        batch = x.size(0)
        patches = self.patch_embed(x)                # (batch, 196, 768)
        cls = self.cls_token.expand(batch, -1, -1)   # (batch, 1, 768)
        x = torch.cat([cls, patches], dim=1)         # (batch, 197, 768)
        x = x + self.pos_embed                       # add position info
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x[:, 0])                       # CLS token output only
        return self.head(x)                          # (batch, n_classes)

vit = ViT(img_size=224, patch_size=16, d_model=192,
          n_heads=3, n_layers=4, n_classes=10, d_ff=768)
img = torch.randn(2, 3, 224, 224)
logits = vit(img)
print(f"Input: {img.shape}")
print(f"Output logits: {logits.shape}")  # (2, 10)
n_params = sum(p.numel() for p in vit.parameters())
print(f"Parameters: {n_params:,}")
```
A few things worth noticing. First, ViT uses the GELU activation instead of ReLU in the feed-forward blocks. GELU is smoother (it's roughly x * sigmoid(1.702 * x)) and has become the default in transformers since BERT and the GPT models adopted it. Second, the position embeddings are learned: just a parameter matrix that gets optimized during training, not fixed sinusoidal functions like in the original transformer. Third, the transformer blocks use pre-norm (LayerNorm before each sub-layer), which is the modern convention we discussed in episode #53. And fourth, only the CLS token's output goes to the classification head. The other 196 patch positions take part in attention but are discarded for the final prediction.
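To make the GELU remark concrete, here's a quick numerical comparison between PyTorch's exact `F.gelu` (x times the standard normal CDF) and the x * sigmoid(1.702 * x) approximation mentioned above. The input range of [-6, 6] is an arbitrary choice for illustration.

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-6, 6, steps=1201)
gelu_exact = F.gelu(x)                       # x * Phi(x), Phi = standard normal CDF
gelu_sigmoid = x * torch.sigmoid(1.702 * x)  # the sigmoid approximation

print("max |exact - approx|:", (gelu_exact - gelu_sigmoid).abs().max().item())

# Unlike ReLU, GELU lets small negative inputs through (slightly)
print("GELU(-1) =", F.gelu(torch.tensor(-1.0)).item())  # about -0.159
print("ReLU(-1) =", F.relu(torch.tensor(-1.0)).item())  # 0.0
```

The two curves differ by only about 0.02 at worst. The practical difference from ReLU is the smooth, slightly negative dip for negative inputs instead of a hard zero.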
Does 1D position embedding learn 2D structure?
This is genuinely remarkable (and I remember being quite impressed when I first read about it). Even though the position embeddings are 1D -- just indices 0 through 196 with no explicit notion of "row" or "column" -- the learned embeddings develop 2D spatial structure during training.
If you visualize the cosine similarity between position embeddings after training, you see that neighboring patches in the 2D grid have similar embeddings. Patch at position (0, 0) has a similar embedding to patch at position (0, 1) and (1, 0), but a very different embedding from the patch at position (13, 13). The model discovers the 2D grid layout purely from data:
```python
# After training, position embeddings show 2D structure.
# Let's simulate what this looks like conceptually.
n_per_side = 14  # sqrt(196) = 14x14 grid
total = n_per_side ** 2
# In a trained ViT, position embedding similarity looks like this:
# pos_embed[row*14 + col] is similar to pos_embed[row*14 + col+1] (right neighbor)
# pos_embed[row*14 + col] is similar to pos_embed[(row+1)*14 + col] (bottom neighbor)
# pos_embed[row*14 + col] is dissimilar to pos_embed[(13-row)*14 + (13-col)] (far corner)
print(f"Patch grid: {n_per_side}x{n_per_side} = {total} patches")
print(f"Position 0 (top-left) neighbors: 1 (right), {n_per_side} (below)")
print(f"Position 97 (row 6, col 13) neighbors: 96 (left), {97+n_per_side} (below)")
print(f"Position 0 is far from position {total-1} (bottom-right corner)")
print()
print("The model learns this spatial structure without ever being told")
print("that patches form a 2D grid. It figures it out from the data alone.")
```
This is a recurring theme in deep learning that we've seen multiple times throughout this series: give the model minimal inductive bias, enough data, and it discovers the relevant structure on its own. CNNs have 2D locality baked in by design -- the convolutional kernel explicitly processes spatially adjacent pixels. ViTs learn locality from data. The question is whether the data cost of learning something so "obvious" is worth the flexibility.
Having said that, the paper also tested 2D-aware position embeddings (encoding row and column separately) and found... basically no improvement. The 1D embeddings work just as well. The model doesn't need help discovering 2D structure.
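If you want to produce the cosine-similarity map described above yourself, a small helper is enough. This is a sketch under a few assumptions. `pos_embed_similarity` is a hypothetical name. The embedding tensor has shape (1, 1 + n_patches, d_model) with the CLS position first, as in the `ViT` class above. The patches form a square grid. The random tensor below only stands in for a learned parameter.

```python
import torch
import torch.nn.functional as F

def pos_embed_similarity(pos_embed, grid_size, patch_index, has_cls=True):
    """Cosine similarity between one patch's position embedding and every
    patch's, reshaped onto the 2D patch grid."""
    emb = pos_embed[0]                    # (n_tokens, d_model)
    if has_cls:
        emb = emb[1:]                     # drop the CLS position, keep the patches
    assert emb.size(0) == grid_size * grid_size, "expected a square patch grid"
    target = emb[patch_index].unsqueeze(0)            # (1, d_model)
    sims = F.cosine_similarity(target, emb, dim=-1)   # broadcasts -> (n_patches,)
    return sims.view(grid_size, grid_size)            # (grid_size, grid_size)

# Stand-in for a learned (1, 197, 768) parameter. With a trained model you
# would pass something like vit.pos_embed.detach() instead.
pos_embed = torch.randn(1, 197, 768) * 0.02
sim_map = pos_embed_similarity(pos_embed, grid_size=14, patch_index=97)
print(sim_map.shape)          # torch.Size([14, 14])
print(sim_map[6, 13].item())  # patch 97 compared with itself: ~1.0
```

With the random stand-in, the map is just noise apart from the bright self-similarity cell. With a trained ViT you would expect the cells around row 6, column 13 to light up too, which is the 2D structure the paper reports. Plotting `sim_map` with something like `matplotlib.pyplot.imshow` makes this easy to see.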
The data hunger problem
This is where the ViT story gets nuanced, and it's important to understand because it determines when you should actually use a ViT in practice.
The original ViT paper found that ViT trained on ImageNet alone (1.3 million images) performed worse than a comparable CNN (ResNet). CNNs have built-in inductive biases -- local connectivity, translation equivariance -- that help them learn efficiently from limited data. These biases act like strong priors: "nearby pixels are related" and "the same pattern matters regardless of where it appears." Without these biases, ViTs need more examples to discover the same spatial structure from scratch.
But when pre-trained on JFT-300M (300 million images, Google's internal dataset) and then fine-tuned on ImageNet, ViT matched or outperformed the strongest CNNs of the time, while requiring substantially less compute to pre-train. The pattern was clear: with enough data, the lack of inductive bias becomes an advantage, because the model isn't constrained by assumptions that might not perfectly hold.
```python
# Let's compare a small ViT and a comparable CNN for CIFAR-10
# (conceptual -- actually measuring the data-size effect needs GPUs and hours)

# Small ViT for 32x32 images: 4x4 patches -> (32/4)^2 = 64 patches
small_vit = ViT(img_size=32, patch_size=4, d_model=128,
                n_heads=4, n_layers=4, n_classes=10, d_ff=512)
vit_params = sum(p.numel() for p in small_vit.parameters())

# Comparable CNN
cnn = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
    nn.Flatten(), nn.Linear(256, 10)
)
cnn_params = sum(p.numel() for p in cnn.parameters())

print(f"ViT params: {vit_params:>10,}")
print(f"CNN params: {cnn_params:>10,}")
print()
# Rough rules of thumb for training from scratch (not measured by this code)
print("Small datasets (< 100K images): CNN wins clearly")
print("~1M images (ImageNet-1k): CNN still ahead, unless you use DeiT-style training")
print("~10M+ images (ImageNet-21k): roughly tied")
print("100M+ images (JFT-300M): ViT wins")
```
The practical implication: if you have limited data (which is most real-world scenarios), a CNN or a pre-trained ViT is the better choice. Training a ViT from scratch on small datasets is a recipe for overfitting -- the model has too many degrees of freedom and not enough data to constrain them. But fine-tuning a ViT that was pre-trained on a massive dataset? That works beautifully, because the pre-training has already taught the model about spatial structure, edges, textures, and objects. Fine-tuning just adapts that knowledge to your specific task.
DeiT: data-efficient image transformers
The data hunger problem bothered Facebook's (Meta's) research team. In 2021, they published DeiT (Data-efficient Image Transformers), which closed the gap between ViTs and CNNs on ImageNet-scale data through better training recipes -- no extra data required.
The key ingredients:
Stronger data augmentation: RandAugment, random erasing, CutMix, Mixup -- aggressively transforming training images to artificially increase dataset diversity. We touched on data augmentation in episode #45 when discussing CNN training. DeiT takes it much further.
Regularization: stochastic depth (randomly skipping the residual branches of whole transformer blocks during training -- conceptually similar to Dropout, but applied to entire sub-layers instead of individual activations), repeated augmentation (the same image gets multiple different augmentations within the same batch), and label smoothing (softening hard one-hot labels to prevent overconfidence).
Knowledge distillation: a pre-trained CNN teacher (like a RegNet) guides the ViT student through an extra distillation token. This token is similar to the CLS token, but it is trained to match the teacher's predictions instead of the ground-truth labels. The distillation token gets its own classification head:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DeiTClassifier(nn.Module):
    """Simplified DeiT heads: CLS token + distillation token."""
    def __init__(self, d_model=768, n_classes=1000):
        super().__init__()
        # Two separate classification heads
        self.cls_head = nn.Linear(d_model, n_classes)   # trained on true labels
        self.dist_head = nn.Linear(d_model, n_classes)  # trained on teacher

    def forward(self, cls_output, dist_output, labels=None, teacher_logits=None):
        cls_logits = self.cls_head(cls_output)
        dist_logits = self.dist_head(dist_output)
        if self.training and labels is not None and teacher_logits is not None:
            # Hard distillation: the distillation head matches the teacher's argmax
            teacher_labels = teacher_logits.argmax(dim=-1)
            dist_loss = F.cross_entropy(dist_logits, teacher_labels)
            # The CLS head learns from the ground-truth labels
            cls_loss = F.cross_entropy(cls_logits, labels)
            # Hard-distillation objective: equal weight on both terms
            loss = 0.5 * cls_loss + 0.5 * dist_loss
            return cls_logits, dist_logits, loss
        # At inference: average both heads
        return (cls_logits + dist_logits) / 2

classifier = DeiTClassifier(d_model=192, n_classes=10)
cls_out = torch.randn(4, 192)   # CLS token output
dist_out = torch.randn(4, 192)  # distillation token output

# Training step: needs the true labels and the teacher's logits
labels = torch.randint(0, 10, (4,))
teacher_logits = torch.randn(4, 10)  # stand-in for a real CNN teacher
classifier.train()
_, _, loss = classifier(cls_out, dist_out, labels, teacher_logits)
print(f"Training loss: {loss.item():.3f}")

# Inference: averaged logits from both heads
classifier.eval()
logits = classifier(cls_out, dist_out)
print(f"Combined logits: {logits.shape}")  # (4, 10)
```
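Where do `cls_output` and `dist_output` come from? In DeiT the distillation token is simply a second learnable token prepended next to CLS, with its own position embedding, so the sequence grows from 197 to 198 tokens. Here's a minimal sketch. `DeiTTokenizer` is a hypothetical name, and a single PyTorch `nn.TransformerEncoderLayer` stands in for the full encoder stack.

```python
import torch
import torch.nn as nn

class DeiTTokenizer(nn.Module):
    """Hypothetical helper: prepends CLS and distillation tokens to patch embeddings."""
    def __init__(self, n_patches, d_model):
        super().__init__()
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
        self.dist_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
        # One learned position per token: CLS + distillation + patches
        self.pos_embed = nn.Parameter(torch.randn(1, n_patches + 2, d_model) * 0.02)

    def forward(self, patches):                        # (batch, n_patches, d_model)
        batch = patches.size(0)
        cls = self.cls_token.expand(batch, -1, -1)
        dist = self.dist_token.expand(batch, -1, -1)
        x = torch.cat([cls, dist, patches], dim=1)     # (batch, n_patches + 2, d_model)
        return x + self.pos_embed

d_model, n_patches = 192, 196
tokenizer = DeiTTokenizer(n_patches, d_model)
# One pre-norm encoder layer stands in for the full transformer stack
encoder = nn.TransformerEncoderLayer(d_model, nhead=3, dim_feedforward=768,
                                     activation="gelu", batch_first=True,
                                     norm_first=True)
patches = torch.randn(4, n_patches, d_model)
out = encoder(tokenizer(patches))                      # (4, 198, 192)
cls_output, dist_output = out[:, 0], out[:, 1]         # inputs for the two heads
print(out.shape, cls_output.shape, dist_output.shape)
```

Both tokens see all the patches through self-attention and see each other. They differ only in which loss their head is trained against.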
The result: DeiT-B (86M parameters, the same architecture as ViT-B) reached ImageNet accuracy competitive with strong CNNs of similar size -- with training ONLY on ImageNet. No JFT, no 300M images. Just smarter training. The lesson is one we keep returning to in this series: the training recipe matters as much as the architecture. The same model trained naively and trained with DeiT's recipe can differ by several percentage points in accuracy.
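Two of the regularizers from DeiT's recipe are easy to sketch. Stochastic depth is often implemented as a small "drop path" module wrapped around each residual branch. The version below is a minimal illustration: the class name `DropPath` and the `drop_prob` value are assumptions, not taken from the DeiT code. Label smoothing is built into PyTorch's `nn.CrossEntropyLoss` through its `label_smoothing` argument.

```python
import torch
import torch.nn as nn

class DropPath(nn.Module):
    """Stochastic depth sketch: during training, zero out a residual branch for a
    random subset of samples, so those samples skip the sub-layer entirely."""
    def __init__(self, drop_prob=0.1):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dimensions
        shape = (x.size(0),) + (1,) * (x.dim() - 1)
        mask = torch.empty(shape, dtype=x.dtype, device=x.device).bernoulli_(keep_prob)
        # Rescale so the expected value matches evaluation mode
        return x * mask / keep_prob

# Inside a pre-norm block it would wrap each residual branch, e.g.
#   x = x + drop_path(self.attn(h, h, h)[0])
#   x = x + drop_path(self.ff(self.norm2(x)))
torch.manual_seed(0)
drop_path = DropPath(drop_prob=0.5)
branch = torch.ones(8, 197, 192)
dropped = drop_path(branch)
print(dropped[:, 0, 0])  # each sample: either 0.0 or 1 / keep_prob = 2.0

drop_path.eval()
print(torch.equal(drop_path(branch), branch))  # True: no-op at inference

# Label smoothing: built into PyTorch's cross-entropy loss
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
logits = torch.randn(8, 10)
targets = torch.randint(0, 10, (8,))
print(criterion(logits, targets).item())
```

Dropping whole residual branches per sample, rather than individual activations, is what makes this "Dropout for layers": a dropped sample takes the identity shortcut straight through the block.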
Hybrid architectures: the best of both worlds
The either-or framing (CNN vs transformer) turns out to be a false choice. Some of the best-performing architectures combine both approaches, and this is honestly where the practical state of the art is for most applications:
CNN backbone + transformer head: use a CNN (like ResNet-50) to extract feature maps from the image, then treat those feature maps as a sequence and process with transformer layers. The CNN handles low-level feature extraction efficiently (edges, textures, simple shapes -- stuff where locality really matters); the transformer handles global relationships (the cat's ear is related to the cat's tail, even though they're far apart):
```python
import torch
import torch.nn as nn
import torchvision.models as models

class HybridViT(nn.Module):
    def __init__(self, n_classes=1000, d_model=768, n_heads=12, n_layers=6):
        super().__init__()
        # CNN extracts spatial features: 224x224x3 -> 7x7x2048
        resnet = models.resnet50(weights=None)
        # Drop ResNet's global average pool and fc layer, keep the feature maps
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        # Project CNN features to transformer dimension
        self.proj = nn.Linear(2048, d_model)
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
        self.pos_embed = nn.Parameter(
            torch.randn(1, 50, d_model) * 0.02)  # 49 patches + 1 CLS
        # TransformerBlock: the encoder block defined earlier (not repeated here)
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_model * 4)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, n_classes)

    def forward(self, x):
        batch = x.size(0)
        feats = self.backbone(x)                    # (batch, 2048, 7, 7)
        feats = feats.flatten(2).transpose(1, 2)    # (batch, 49, 2048)
        feats = self.proj(feats)                    # (batch, 49, d_model)
        cls = self.cls_token.expand(batch, -1, -1)
        x = torch.cat([cls, feats], dim=1)          # (batch, 50, d_model)
        x = x + self.pos_embed
        for layer in self.layers:
            x = layer(x)
        return self.head(self.norm(x[:, 0]))

hybrid = HybridViT(n_classes=10, d_model=192, n_heads=3, n_layers=2)
img = torch.randn(2, 3, 224, 224)
out = hybrid(img)
print(f"Hybrid ViT output: {out.shape}")  # (2, 10)
n_params = sum(p.numel() for p in hybrid.parameters())
print(f"Parameters: {n_params:,}")
```
The hybrid gets the best of both worlds: CNN's efficient local feature extraction for the early stages, transformer's global attention for the later stages. In practice, this often trains faster and needs less data than a pure ViT, because the CNN backbone already encodes useful spatial priors. You're not asking the transformer to learn "nearby pixels are related" -- the CNN has already handled that.
Convolutional patch embedding: a simpler hybrid approach. Replace ViT's single linear patch projection with a small CNN that processes patches with multiple convolutional layers before projecting to d_model. This gives the model some local inductive bias right at the input stage, which particularly helps with smaller datasets.
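A convolutional patch embedding can be sketched as a short stack of stride-2 convolutions. Four of them downsample by 16, which produces the same 14x14 token grid that a 16x16 patch projection yields on a 224x224 image. This is a minimal sketch under assumptions of our own: the class name `ConvPatchEmbedding`, the channel widths, and the BatchNorm/ReLU choices are illustrative, not a specific published design.

```python
import torch
import torch.nn as nn

class ConvPatchEmbedding(nn.Module):
    """Hypothetical conv stem that replaces ViT's single patch projection."""
    def __init__(self, in_ch=3, d_model=192):
        super().__init__()
        chans = [in_ch, d_model // 8, d_model // 4, d_model // 2]  # assumed widths
        layers = []
        for c_in, c_out in zip(chans[:-1], chans[1:]):
            # 3x3, stride 2: halves H and W while mixing local neighbourhoods
            layers += [nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
                       nn.BatchNorm2d(c_out),
                       nn.ReLU(inplace=True)]
        # Fourth stride-2 conv: total downsampling 16x, projects to d_model
        layers.append(nn.Conv2d(chans[-1], d_model, kernel_size=3, stride=2, padding=1))
        self.stem = nn.Sequential(*layers)

    def forward(self, x):
        x = self.stem(x)                      # (B, d_model, H/16, W/16)
        return x.flatten(2).transpose(1, 2)   # (B, num_patches, d_model)

stem = ConvPatchEmbedding(in_ch=3, d_model=192)
tokens = stem(torch.randn(2, 3, 224, 224))
print(tokens.shape)  # torch.Size([2, 196, 192])
```

The output is a drop-in replacement for the usual patch tokens. You would still prepend the CLS token and add position embeddings afterwards.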
The Swin Transformer: hierarchical vision
One of the biggest practical advances in vision transformers came from Microsoft Research in 2021: the Swin Transformer (Shifted WINdow). It addresses the O(n^2) complexity of self-attention by computing attention within local windows (similar to how a convolution has a local receptive field), then shifting those windows across layers to enable cross-window communication:
```python
import math
import torch
import torch.nn as nn

class WindowAttention(nn.Module):
    """Simplified window-based self-attention."""
    def __init__(self, d_model, n_heads, window_size=7):
        super().__init__()
        self.window_size = window_size
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x, H, W):
        B, N, C = x.shape
        ws = self.window_size
        # Reshape to 2D grid and partition into windows
        x = x.view(B, H, W, C)
        # Pad H and W up to multiples of the window size if needed
        pad_h = (ws - H % ws) % ws
        pad_w = (ws - W % ws) % ws
        if pad_h > 0 or pad_w > 0:
            x = nn.functional.pad(x, (0, 0, 0, pad_w, 0, pad_h))
        Hp, Wp = x.shape[1], x.shape[2]
        nH, nW = Hp // ws, Wp // ws
        # (B, nH, ws, nW, ws, C) -> (B*nH*nW, ws*ws, C)
        windows = x.view(B, nH, ws, nW, ws, C)
        windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous()
        windows = windows.view(-1, ws * ws, C)
        # Standard multi-head attention within each window
        qkv = self.qkv(windows).reshape(-1, ws * ws, 3, self.n_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B*nH*nW, heads, ws*ws, d_k)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(-1, ws * ws, C)
        out = self.proj(out)
        # Reverse the windowing
        out = out.view(B, nH, nW, ws, ws, C)
        out = out.permute(0, 1, 3, 2, 4, 5).contiguous()
        out = out.view(B, Hp, Wp, C)
        if pad_h > 0 or pad_w > 0:
            out = out[:, :H, :W, :]
        # reshape, not view: after cropping the padding the tensor is not contiguous
        return out.reshape(B, H * W, C)

wa = WindowAttention(d_model=96, n_heads=3, window_size=7)
x = torch.randn(2, 56 * 56, 96)  # 56x56 feature map, 96 channels
out = wa(x, H=56, W=56)
print(f"Input: {x.shape}")
print(f"Output: {out.shape}")
print("Window size: 7x7 = 49 tokens per window")
print(f"Number of windows: {56//7} x {56//7} = {(56//7)**2}")
print(f"Attention cost: O(49^2 * {(56//7)**2}) vs full O({56*56}^2)")
```
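With 7x7 windows on a 56x56 feature map, each window contains 49 tokens. Attention inside one window therefore needs 49^2 = 2,401 query-key dot products. There are 64 windows, so the total is 64 * 2,401 = 153,664 (about 154K). Full self-attention over all 3,136 tokens needs 3136^2 = 9,834,496 (about 9.8 million). That's a 64x reduction, and the number is no coincidence: the saving always equals the number of windows. Window attention costs N * 49 pairs for N tokens, which is linear in N. Full attention costs N^2. So the gap widens as the resolution grows.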
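The arithmetic above can be checked for several resolutions with a few lines of plain Python. This sketch counts query-key pairs only and ignores the per-pair O(d) cost and the projections. The helper name `attention_pairs` is hypothetical, and it assumes H and W are multiples of the window size.

```python
def attention_pairs(H, W, ws):
    """Query-key pair counts for full vs window attention on an H x W token grid."""
    n = H * W
    n_windows = (H // ws) * (W // ws)        # assumes H, W divisible by ws
    full = n * n                             # every token attends to every token
    windowed = n_windows * (ws * ws) ** 2    # each token attends within its window only
    return full, windowed

for side in (56, 112, 224):
    full, windowed = attention_pairs(side, side, ws=7)
    print(f"{side}x{side}: full={full:,}  windowed={windowed:,}  "
          f"saving={full // windowed}x")
```

Doubling the side length multiplies the token count by 4. That multiplies the full-attention cost by 16 but the windowed cost only by 4, so the saving grows from 64x to 256x to 1024x.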
The "shifted" part is what makes it work across windows. In alternating layers, the window grid is shifted by half a window, rounded down (3 tokens for the 7x7 windows used here). As a result, patches that were at the edge of one window in layer L end up in the middle of a different window in layer L+1. Information gradually propagates across the entire image through this shifting mechanism.
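You can see the effect of the shift by tracking which tokens share a window, without computing any attention at all. The sketch below uses a cyclic shift via `torch.roll`, which is how Swin implements the shift. It assumes a small 14x14 grid with 7x7 windows. The helpers `partition` and `window_of` are hypothetical names for illustration.

```python
import torch

def partition(grid, ws):
    """Split an (H, W) grid of token ids into non-overlapping ws x ws windows."""
    H, W = grid.shape
    return (grid.reshape(H // ws, ws, W // ws, ws)
                .permute(0, 2, 1, 3)          # (win_row, win_col, ws, ws)
                .reshape(-1, ws * ws))        # one row of token ids per window

def window_of(windows):
    """Map each token id to the index of the window that contains it."""
    return {tok: w for w, row in enumerate(windows.tolist()) for tok in row}

H = W = 14
ws = 7
shift = ws // 2                                # 3 tokens for a 7x7 window
ids = torch.arange(H * W).view(H, W)           # token id = row * W + col

regular = window_of(partition(ids, ws))        # windows in layer L
rolled = torch.roll(ids, shifts=(-shift, -shift), dims=(0, 1))
shifted = window_of(partition(rolled, ws))     # windows in layer L+1

a = 6 * W + 6    # token at (6, 6): bottom-right corner of the first window
b = 7 * W + 7    # token at (7, 7): top-left corner of the diagonal window
print("same window in layer L:  ", regular[a] == regular[b])   # False
print("same window in layer L+1:", shifted[a] == shifted[b])   # True
```

The cyclic shift also wraps tokens from opposite borders into the same window. For example, tokens (0, 0) and (13, 13) end up together after the roll. The real Swin Transformer adds an attention mask so that such wrapped-around tokens do not attend to each other. This sketch leaves the mask out.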
Swin also uses a hierarchical structure -- it progressively merges neighbouring patches (similar to how CNNs use pooling to reduce spatial resolution), creating multi-scale feature maps. This makes it directly usable for dense prediction tasks like object detection and segmentation, where you need features at multiple scales.
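Patch merging can be sketched in a few lines. Each 2x2 group of neighbouring tokens is concatenated into one token with 4C channels, then normalised and linearly projected down to 2C. This halves the resolution and doubles the channel count, much like a strided convolution stage in a CNN. This is a minimal sketch modelled on the Swin design. It assumes even H and W, and the class name `PatchMerging` is our own.

```python
import torch
import torch.nn as nn

class PatchMerging(nn.Module):
    """Sketch of Swin-style patch merging: each 2x2 block of tokens -> one token."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.view(B, H, W, C)
        # The four tokens of every 2x2 block (assumes H and W are even)
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3], dim=-1)     # (B, H/2, W/2, 4C)
        x = x.view(B, (H // 2) * (W // 2), 4 * C)
        return self.reduction(self.norm(x))         # (B, H/2 * W/2, 2C)

merge = PatchMerging(dim=96)
merged = merge(torch.randn(2, 56 * 56, 96), H=56, W=56)
print(merged.shape)  # torch.Size([2, 784, 192])
```

Stacking window-attention blocks with merging steps in between gives the pyramid of feature maps (56x56, 28x28, 14x14, ...) that detection and segmentation heads expect.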
Swin Transformer dominated computer vision benchmarks in 2021-2022 and remains a widely used general-purpose vision backbone.
When to use what -- a practical guide
After covering all these variants, let me give you my honest, practical advice (because I know the choice can be paralysing):
Use a pre-trained ViT when: you have a standard image classification problem and can leverage a model pre-trained on a large dataset (ImageNet-21k or similar). Fine-tuning a pre-trained ViT is competitive with or better than fine-tuning a pre-trained CNN. Hugging Face has dozens of pre-trained ViT checkpoints you can download and fine-tune in 20 lines of code.
Use a CNN when: your dataset is small (< 10K images), you're training from scratch, or you want the built-in locality and (approximate) translation equivariance of convolutions. Also when you need to run on edge devices -- small CNNs (MobileNet, EfficientNet-B0) are still usually much more efficient than ViTs for on-device inference.
Use Swin or a hybrid when: you need dense prediction (object detection, segmentation, instance segmentation) or you want maximum performance and can afford the engineering complexity.
Don't overthink it: for most practical problems, the choice between a pre-trained ViT and a pre-trained CNN matters less than proper data augmentation, learning rate scheduling, and hyperparameter tuning. I argue that the majority of practical computer vision projects would get more benefit from spending an extra hour on data quality than from switching between architectures.
```python
# Quick comparison: pre-trained model sizes and ImageNet top-1 accuracy.
# The figures come from different setups (input resolution, pre-training data),
# so treat them as rough reference points, not a head-to-head benchmark.
architectures = {
    "ResNet-50":       {"params": "25.6M", "top1": "76.1%", "best_for": "edge, small data"},
    "EfficientNet-B4": {"params": "19.3M", "top1": "82.9%", "best_for": "efficiency"},
    "ViT-B/16":        {"params": "86.6M", "top1": "84.5%", "best_for": "fine-tuning"},
    "DeiT-B":          {"params": "86.6M", "top1": "83.1%", "best_for": "no extra data"},
    "Swin-B":          {"params": "87.8M", "top1": "85.2%", "best_for": "detection/seg"},
}

print(f"{'Architecture':<20s} {'Params':>8s} {'Top-1':>8s} {'Best for':<25s}")
print("-" * 65)
for name, info in architectures.items():
    print(f"{name:<20s} {info['params']:>8s} {info['top1']:>8s} {info['best_for']:<25s}")
```
What comes next
We've now covered the three dominant paradigms in neural network architecture: CNNs (episodes #45-47) for spatial data, RNNs/LSTMs (episodes #48-49) for sequential data, and transformers (episodes #52-54) for... well, everything. The transformer's flexibility means it can process images (ViT, today), text (the original transformer), audio, video, and even protein structures. It's the closest thing we have to a universal architecture.
The natural question is: if transformers are this general, what happens when you scale them up and train them on enormous amounts of data? And what happens when you use them not just for classification, but for generation? Those questions lead to some fascinating territory that we'll explore soon enough. The pieces are in place.
The bottom line
- Vision Transformers treat images as sequences of patches (typically 16x16), linearly projected to embedding vectors -- turning a 2D image into a 1D sequence that a standard transformer can process;
- A learnable CLS token is prepended to the sequence and used for classification after the final transformer layer -- it aggregates information from all patches through self-attention;
- Learned 1D position embeddings discover 2D spatial structure during training -- no explicit 2D encoding needed. The model figures out the grid layout from data alone;
- ViTs need more data than CNNs to train from scratch (lack of inductive bias) but outperform CNNs when pre-trained on large datasets -- the tradeoff is flexibility vs data efficiency;
- DeiT showed that better training recipes (augmentation, regularization, distillation from a CNN teacher) close the data efficiency gap without extra data;
- Hybrid architectures (CNN backbone + transformer head) combine local feature extraction with global attention -- often the most practical choice;
- The Swin Transformer uses windowed attention with shifting for O(n) complexity and hierarchical feature maps -- dominant for detection and segmentation tasks;
- For practitioners: use pre-trained models and fine-tune. The architecture choice matters less than training strategy and data quality ;-)
Exercises
Exercise 1: Build a complete ViT for CIFAR-10 and compare it against a CNN baseline. Create a ViT with img_size=32, patch_size=4 (so 64 patches), d_model=128, n_heads=4, n_layers=4, d_ff=512. For the CNN, use a simple 3-layer ConvNet (Conv->ReLU->Pool repeated 3 times, then a linear head). Train both on CIFAR-10 for 20 epochs with the same optimizer (Adam, lr=1e-3) and batch size (128). Compare training curves and final test accuracy. The CNN should learn faster initially, but with data augmentation (random horizontal flip + random crop) the ViT should narrow the gap. Print epoch-by-epoch train loss and test accuracy for both models.
Exercise 2: Implement patch embedding visualization. Create a PatchEmbedding module with img_size=32, patch_size=4. Load a single CIFAR-10 image, pass it through the patch embedding, and reconstruct the image from patches. Specifically: (a) show the original image dimensions, (b) show the number and size of patches, (c) for each of the 64 patches, extract the original 4x4x3 pixel block and verify it matches the patch embedding input, (d) visualize the position embedding similarity matrix (65x65 cosine similarity, including CLS token) for a randomly initialized ViT and print which position pairs have the highest similarity -- before training they should be roughly random. Save the similarity matrix to a file.
Exercise 3: Build a window attention module and compare its computational cost against full self-attention. Implement the WindowAttention class from this episode. Create a feature map of shape (1, 56*56, 96) and measure wall-clock time for: (a) full self-attention across all 3136 positions, and (b) window attention with window_size=7 (64 windows of 49 tokens each). Run each 10 times and report average time. Also compute the theoretical FLOPs ratio: full attention is O(N^2 * d) where N=3136, window attention is O(num_windows * w^2 * d) where w=49 is the number of tokens per 7x7 window and num_windows=64. Print both the theoretical and measured speedup factors.