Learn AI Series (#47) - CNN Applications - Detection, Segmentation, Style Transfer
What will I learn
- You will learn image classification beyond basic accuracy -- top-K predictions and confidence scores;
- object detection concepts -- region proposals, anchor boxes, and why detection is harder than classification;
- YOLO -- real-time object detection in a single forward pass;
- image segmentation -- pixel-level classification with encoder-decoder architectures;
- neural style transfer -- separating and recombining content and style using CNN features;
- building a practical image classifier and feature extractor from pretrained models.
Requirements
- A working modern computer running macOS, Windows or Ubuntu;
- An installed Python 3(.11+) distribution;
- The ambition to learn AI and machine learning.
Difficulty
- Beginner
Curriculum (of the Learn AI Series):
- Learn AI Series (#1) - What Machine Learning Actually Is
- Learn AI Series (#2) - Setting Up Your AI Workbench - Python and NumPy
- Learn AI Series (#3) - Your Data Is Just Numbers - How Machines See the World
- Learn AI Series (#4) - Your First Prediction - No Math, Just Intuition
- Learn AI Series (#5) - Patterns in Data - What "Learning" Actually Looks Like
- Learn AI Series (#6) - From Intuition to Math - Why We Need Formulas
- Learn AI Series (#7) - The Training Loop - See It Work Step by Step
- Learn AI Series (#8) - The Math You Actually Need (Part 1) - Linear Algebra
- Learn AI Series (#9) - The Math You Actually Need (Part 2) - Calculus and Probability
- Learn AI Series (#10) - Your First ML Model - Linear Regression From Scratch
- Learn AI Series (#11) - Making Linear Regression Real
- Learn AI Series (#12) - Classification - Logistic Regression From Scratch
- Learn AI Series (#13) - Evaluation - How to Know If Your Model Actually Works
- Learn AI Series (#14) - Data Preparation - The 80% Nobody Talks About
- Learn AI Series (#15) - Feature Engineering and Selection
- Learn AI Series (#16) - Scikit-Learn - The Standard Library of ML
- Learn AI Series (#17) - Decision Trees - How Machines Make Decisions
- Learn AI Series (#18) - Random Forests - Wisdom of Crowds
- Learn AI Series (#19) - Gradient Boosting - The Kaggle Champion
- Learn AI Series (#20) - Support Vector Machines - Drawing the Perfect Boundary
- Learn AI Series (#21) - Mini Project - Predicting Crypto Market Regimes
- Learn AI Series (#22) - K-Means Clustering - Finding Groups
- Learn AI Series (#23) - Advanced Clustering - Beyond K-Means
- Learn AI Series (#24) - Dimensionality Reduction - PCA
- Learn AI Series (#25) - Advanced Dimensionality Reduction - t-SNE and UMAP
- Learn AI Series (#26) - Anomaly Detection - Finding What Doesn't Belong
- Learn AI Series (#27) - Recommendation Systems - "Users Like You Also Liked..."
- Learn AI Series (#28) - Time Series Fundamentals - When Order Matters
- Learn AI Series (#29) - Time Series Forecasting - Predicting What Comes Next
- Learn AI Series (#30) - Natural Language Processing - Text as Data
- Learn AI Series (#31) - Word Embeddings - Meaning in Numbers
- Learn AI Series (#32) - Bayesian Methods - Thinking in Probabilities
- Learn AI Series (#33) - Ensemble Methods Deep Dive - Stacking and Blending
- Learn AI Series (#34) - ML Engineering - From Notebook to Production
- Learn AI Series (#35) - Data Ethics and Bias in ML
- Learn AI Series (#36) - Mini Project - Complete ML Pipeline
- Learn AI Series (#37) - The Perceptron - Where It All Started
- Learn AI Series (#38) - Neural Networks From Scratch - Forward Pass
- Learn AI Series (#39) - Neural Networks From Scratch - Backpropagation
- Learn AI Series (#40) - Training Neural Networks - Practical Challenges
- Learn AI Series (#41) - Optimization Algorithms - SGD, Momentum, Adam
- Learn AI Series (#42) - PyTorch Fundamentals - Tensors and Autograd
- Learn AI Series (#43) - PyTorch Data and Training
- Learn AI Series (#44) - PyTorch nn.Module - Building Real Networks
- Learn AI Series (#45) - Convolutional Neural Networks - Theory
- Learn AI Series (#46) - CNNs in Practice - Classic to Modern Architectures
- Learn AI Series (#47) - CNN Applications - Detection, Segmentation, Style Transfer (this post)
Solutions to Episode #46 Exercises
Exercise 1: Build a mini-VGG network for CIFAR-10 with 3 VGG blocks (two conv+BN+ReLU per block), channels 3->64->128->256, global average pooling, single linear head.
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
class MiniVGG(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            # Block 1: 3 -> 64
            nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2),
            # Block 2: 64 -> 128
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.MaxPool2d(2),
            # Block 3: 128 -> 256
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(256, 10)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)
model = MiniVGG()
print(f"MiniVGG parameters: {sum(p.numel() for p in model.parameters()):,}")
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])
train_data = datasets.CIFAR10('.', train=True, download=True, transform=transform_train)
test_data = datasets.CIFAR10('.', train=False, download=True, transform=transform_test)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = DataLoader(test_data, batch_size=256)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
loss_fn = nn.CrossEntropyLoss()
for epoch in range(20):
    model.train()
    for X, y in train_loader:
        loss = loss_fn(model(X), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()  # one scheduler step per epoch (T_max=20)
    if epoch % 5 == 0:
        model.eval()
        with torch.no_grad():  # no gradients needed for evaluation
            correct = sum(
                (model(X).argmax(1) == y).sum().item()
                for X, y in test_loader)
        print(f"Epoch {epoch}: test_acc={correct/len(test_data):.1%}")
# Expect ~90-92% test accuracy after 20 epochs
The VGG-style double-conv blocks give the network more capacity per spatial scale before pooling reduces dimensions. Stacking two 3x3 convolutions is the classic "fewer params, same receptive field" trick we discussed: the pair covers the same 5x5 receptive field as a single 5x5 convolution, with fewer weights. Global average pooling at the end keeps the parameter count down as well -- there's only one Linear(256, 10) layer for classification, compared to VGG-16's massive 4096-unit fully connected layers.
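To put a number on that last point, here is a small arithmetic sketch in plain Python. The VGG-16 layer sizes are those of the standard ImageNet configuration (a 512x7x7 feature map flattened into three fully connected layers); the helper name `linear_params` is just for illustration.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out

# VGG-16 classifier: flattened 512x7x7 feature map -> 4096 -> 4096 -> 1000
vgg16_fc = [
    linear_params(512 * 7 * 7, 4096),
    linear_params(4096, 4096),
    linear_params(4096, 1000),
]
vgg16_head = sum(vgg16_fc)

# MiniVGG: global average pooling leaves 256 values, then Linear(256, 10)
mini_head = linear_params(256, 10)

print(f"VGG-16 FC head: {vgg16_head:,} params")
print(f"MiniVGG head:   {mini_head:,} params")
```

The first fully connected layer alone accounts for over 100 million of VGG-16's parameters, which is exactly what global average pooling sidesteps.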
Exercise 2: Transfer learning comparison -- frozen backbone vs full fine-tuning with differential learning rates.
from torchvision import models
def train_and_eval(model, train_loader, test_loader, optimizer, n_epochs, label):
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(n_epochs):
        model.train()
        for X, y in train_loader:
            loss = loss_fn(model(X), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    # Evaluate once, after the final epoch
    model.eval()
    with torch.no_grad():
        correct = sum(
            (model(X).argmax(1) == y).sum().item()
            for X, y in test_loader)
    # Use the loader's own dataset size instead of relying on a global variable
    print(f"{label}: final test_acc={correct/len(test_loader.dataset):.1%}")
# (a) Frozen backbone
resnet_frozen = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
resnet_frozen.fc = nn.Linear(512, 10)
for param in resnet_frozen.parameters():
    param.requires_grad = False
for param in resnet_frozen.fc.parameters():
    param.requires_grad = True
opt_frozen = torch.optim.Adam(resnet_frozen.fc.parameters(), lr=1e-3)
train_and_eval(resnet_frozen, train_loader, test_loader, opt_frozen, 10, "Frozen backbone")
# (b) Full fine-tuning with differential LRs
resnet_full = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
resnet_full.fc = nn.Linear(512, 10)
opt_full = torch.optim.AdamW([
    # Stem: conv1 together with bn1, which would otherwise never be updated
    {'params': list(resnet_full.conv1.parameters())
               + list(resnet_full.bn1.parameters()), 'lr': 1e-5},
    {'params': resnet_full.layer1.parameters(), 'lr': 1e-5},
    {'params': resnet_full.layer2.parameters(), 'lr': 5e-5},
    {'params': resnet_full.layer3.parameters(), 'lr': 1e-4},
    {'params': resnet_full.layer4.parameters(), 'lr': 1e-4},
    {'params': resnet_full.fc.parameters(), 'lr': 1e-3},
], weight_decay=0.01)
train_and_eval(resnet_full, train_loader, test_loader, opt_full, 10, "Differential LR")
The fully unfrozen model with differential learning rates outperforms the frozen version because it can adapt the pretrained features to CIFAR-10's specific distribution (32x32 images, 10 classes -- quite different from ImageNet's 224x224 with 1000 classes). The frozen version is limited to using ImageNet features as-is, which are still useful but not optimal. The small backbone learning rates protect the pretrained features from being destroyed.
Exercise 3: Function that counts parameters by layer type.
from collections import defaultdict
def count_by_type(model):
    stats = defaultdict(lambda: {'count': 0, 'params': 0})
    for module in model.modules():
        if len(list(module.children())) > 0:
            continue  # skip container modules
        name = module.__class__.__name__
        stats[name]['count'] += 1
        stats[name]['params'] += sum(p.numel() for p in module.parameters())
    return dict(stats)

# Test on MiniVGG
vgg_stats = count_by_type(MiniVGG())
print("MiniVGG layer breakdown:")
for name, info in sorted(vgg_stats.items(), key=lambda x: -x[1]['params']):
    if info['params'] > 0:
        print(f" {name:>15s}: {info['count']:>3d} layers, {info['params']:>10,} params")

# Test on ResNet-18
resnet_stats = count_by_type(models.resnet18(weights=models.ResNet18_Weights.DEFAULT))
print("\nResNet-18 layer breakdown:")
for name, info in sorted(resnet_stats.items(), key=lambda x: -x[1]['params']):
    if info['params'] > 0:
        print(f" {name:>15s}: {info['count']:>3d} layers, {info['params']:>10,} params")
In both models, Conv2d layers dominate the parameter count. The Linear layer is typically second, though with global average pooling the final Linear layer is quite small. BatchNorm2d layers have very few parameters (just gamma and beta per channel) despite being present after every convolution. This makes it clear why the debate about "wide vs deep" CNNs is really a debate about how to distribute convolutional parameters.
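The MiniVGG breakdown can also be checked by hand. The sketch below recomputes the counts from the layer shapes of Exercise 1, using the usual formulas (a conv layer has in x out x k x k weights plus one bias per output channel; BatchNorm has a learnable gamma and beta per channel). The helper names are illustrative and not part of the exercise code.

```python
def conv_params(c_in, c_out, k=3):
    """Weights (c_in * c_out * k * k) plus one bias per output channel."""
    return c_in * c_out * k * k + c_out

def bn_params(channels):
    """Learnable gamma and beta per channel (running stats are buffers, not parameters)."""
    return 2 * channels

# (in_channels, out_channels) of the six conv layers in MiniVGG
conv_shapes = [(3, 64), (64, 64), (64, 128), (128, 128), (128, 256), (256, 256)]

conv_total = sum(conv_params(i, o) for i, o in conv_shapes)
bn_total = sum(bn_params(o) for _, o in conv_shapes)
linear_total = 256 * 10 + 10
total = conv_total + bn_total + linear_total

print(f"Conv2d:      {conv_total:>10,}")
print(f"BatchNorm2d: {bn_total:>10,}")
print(f"Linear:      {linear_total:>10,}")
print(f"Total:       {total:>10,}")
```

The convolutions hold well over 99% of the parameters, and the last two conv layers alone account for more than three quarters of the total.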
On to today's episode
We've spent two episodes building a solid understanding of CNNs -- first the theory (episode #45: convolutions, pooling, feature hierarchies, receptive fields), then the practice (episode #46: LeNet through ResNet, transfer learning, training on CIFAR-10). You now know how to build, train, and fine-tune convolutional networks for image classification.
But classification is only the beginning. Telling the difference between a cat and a dog is useful, sure. But what if you need to know where the cat is in the image? What if you need to outline exactly which pixels belong to the cat and which are background? What if you want to render your vacation photo in the style of a Van Gogh painting? All of these tasks build on the same CNN feature hierarchies, but apply them in increasingly creative ways. Here we go!
Beyond classification: the vision task spectrum
Let's map out the full range of computer vision tasks, because understanding how they relate to each other is important for picking the right approach:
Image classification: one image -> one label. "This is a cat." The simplest vision task, and what we've focused on in episodes #45-46. The model processes the entire image and outputs a class probability distribution.
Object detection: one image -> multiple bounding boxes, each with a label and a confidence score. "There's a cat at coordinates (50, 80, 200, 300) with 94% confidence, and a dog at (300, 100, 450, 350) with 87% confidence." Much harder -- the model must locate and classify an unknown number of objects simultaneously.
Semantic segmentation: one image -> a class label for every single pixel. "These pixels are cat, these are background, these are furniture." There's no concept of individual objects here -- all cat pixels get the same label regardless of how many cats are in the scene.
Instance segmentation: semantic segmentation + object detection combined. One image -> a class label AND an instance ID for every pixel. "These pixels are cat #1, these are cat #2, these are dog #1." This is the hardest of the four -- it requires simultaneously detecting individual objects, classifying them, and producing pixel-precise masks for each one.
Each task requires more output information from the network, and correspondingly more sophisticated architectures. Classification outputs a vector. Detection outputs boxes. Segmentation outputs a full-resolution label map. Let's look at how the architectures differ.
import torch
import torch.nn as nn
# What each task outputs for a 224x224 input:
print("Vision task output comparison:")
print(f" Classification: (1, num_classes) = (1, 1000)")
print(f" Detection: list of (box, class, score) per object")
print(f" Segmentation: (1, num_classes, 224, 224) = (1, 21, 224, 224)")
print(f" Instance seg: per-object mask + class + score")
# Parameter scale comparison
classifier = nn.Linear(512, 1000)
seg_head = nn.Conv2d(512, 21, 1) # 1x1 conv for per-pixel classification
print(f"\nClassifier head params: {sum(p.numel() for p in classifier.parameters()):,}")
print(f"Segmentation head params: {sum(p.numel() for p in seg_head.parameters()):,}")
Notice something interesting: the segmentation head is actually smaller than the classification head. Segmentation replaces the big nn.Linear with a 1x1 convolution that classifies each spatial position independently. The complexity isn't in the head -- it's in the backbone and the decoder that recover spatial resolution after all the pooling and striding.
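To see why a 1x1 convolution amounts to "the same linear classifier applied at every pixel", here is a toy sketch in plain Python with nested lists instead of tensors. The function name, weights and feature values are made up for illustration; in PyTorch, nn.Conv2d does this for you.

```python
def conv1x1(fmap, weight, bias):
    """Apply a 1x1 convolution to a (C, H, W) feature map given as nested lists.

    weight has shape (num_classes, C), bias has shape (num_classes,).
    """
    n_ch, height, width = len(fmap), len(fmap[0]), len(fmap[0][0])
    out = [[[0.0] * width for _ in range(height)] for _ in weight]
    for y in range(height):
        for x in range(width):
            pixel = [fmap[c][y][x] for c in range(n_ch)]  # channel vector at (y, x)
            for k, (row, b) in enumerate(zip(weight, bias)):
                out[k][y][x] = sum(w * v for w, v in zip(row, pixel)) + b
    return out

fmap = [[[1, 2], [3, 4]],   # channel 0
        [[0, 1], [1, 0]]]   # channel 1
weight = [[1.0, -1.0], [0.5, 2.0]]  # 2 classes x 2 channels
bias = [0.0, 1.0]

scores = conv1x1(fmap, weight, bias)
class_map = [[max(range(len(scores)), key=lambda k: scores[k][y][x])
              for x in range(2)] for y in range(2)]
n_params = sum(len(row) for row in weight) + len(bias)

print("Class scores:", scores)
print("Class map:   ", class_map)
print("Parameters:  ", n_params)
```

The parameter count depends only on the number of channels and classes, not on the image size, while the number of output values grows with every pixel.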
Object detection: the two-stage approach
The original approach to object detection was a two-stage pipeline, and understanding it makes the later single-stage methods (like YOLO) much clearer.
R-CNN (Girshick, 2014) was the breakthrough: first propose ~2,000 rectangular regions that might contain objects (using a classical algorithm called selective search), then run a CNN classifier on each cropped region. It worked beautifully but was painfully slow -- 2,000 separate CNN forward passes per image. At test time, processing one image took nearly a minute on a GPU.
Fast R-CNN (2015) improved this dramatically by running the CNN once on the full image to produce a shared feature map, then cropping features for each proposed region from that shared map. One CNN forward pass instead of 2,000. But the region proposal step (selective search) was still slow and ran on CPU.
Faster R-CNN (2015) completed the evolution by replacing selective search with a Region Proposal Network (RPN) -- a small CNN that runs on top of the shared feature map and proposes regions directly. The RPN shares its convolutional features with the classification network, making region proposals nearly free.
The key concept that makes this work: anchor boxes. The RPN places a grid of predefined box shapes at every position in the feature map -- different aspect ratios (tall, square, wide) and different scales (small, medium, large). For each anchor, the network predicts two things: (1) is there an object here? (binary classification) and (2) how should I shift and resize this anchor box to better fit the actual object? (bounding box regression). This converts the open-ended search problem "find all objects and draw boxes around them" into a fixed set of binary classification and regression problems -- which is exactly the kind of thing neural networks are good at.
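A minimal sketch of both halves of the anchor idea, in plain Python: generating the anchors for one feature-map position, and applying predicted offsets to one anchor. The scales and ratios are illustrative values, and the function names are hypothetical. The decoding follows the parameterization from the Faster R-CNN paper (shift the center relative to the anchor size, scale width and height exponentially).

```python
import math

def make_anchors(center_x, center_y, scales, ratios):
    """Anchors (x1, y1, x2, y2) centered on one position.

    ratio = height / width; every anchor of a given scale keeps area scale**2.
    """
    anchors = []
    for s in scales:
        for r in ratios:
            w = s / math.sqrt(r)
            h = s * math.sqrt(r)
            anchors.append((center_x - w / 2, center_y - h / 2,
                            center_x + w / 2, center_y + h / 2))
    return anchors

def decode(anchor, deltas):
    """Apply predicted (dx, dy, dw, dh) offsets to one anchor box."""
    x1, y1, x2, y2 = anchor
    wa, ha = x2 - x1, y2 - y1
    xa, ya = x1 + wa / 2, y1 + ha / 2
    dx, dy, dw, dh = deltas
    x, y = xa + dx * wa, ya + dy * ha            # shift center relative to anchor size
    w, h = wa * math.exp(dw), ha * math.exp(dh)  # rescale width and height
    return (x - w / 2, y - h / 2, x + w / 2, y + h / 2)

anchors = make_anchors(100, 100, scales=[32, 64, 128], ratios=[0.5, 1.0, 2.0])
print(f"Anchors per position: {len(anchors)}")

square = anchors[1]                              # scale 32, ratio 1.0
shifted = decode(square, (0.25, 0.0, 0.0, 0.0))  # move right by a quarter of the width
print(f"Anchor:  {square}")
print(f"Decoded: {shifted}")
```

With 9 anchors per position, a 40x40 feature map already yields 14,400 candidate boxes, each needing only a yes/no score and four offsets.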
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_V2_Weights
# Load pretrained Faster R-CNN (COCO dataset: 80 object classes)
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
detector = fasterrcnn_resnet50_fpn_v2(weights=weights)
detector.eval()
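# COCO category names. The list has 91 entries: the 80 object classes plus
# '__background__' and 'N/A' placeholders for unused label IDs
coco_classes = weights.meta["categories"]
print(f"Number of category entries: {len(coco_classes)}")
print(f"First 10 entries: {coco_classes[:10]}")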
# Inference on a random image with values in [0, 1], the range the model expects
# (in practice you'd load a real photo)
x = torch.rand(1, 3, 480, 640)
with torch.no_grad():
    predictions = detector(x)
# The output is a list of dicts, one per image
pred = predictions[0]
boxes = pred['boxes'] # (N, 4) bounding box coordinates
scores = pred['scores'] # (N,) confidence scores
labels = pred['labels'] # (N,) class indices
# Filter by confidence threshold
threshold = 0.5
mask = scores > threshold
print(f"\nDetected {mask.sum().item()} objects above {threshold} confidence")
for i in range(min(3, mask.sum().item())):
    idx = mask.nonzero()[i].item()
    print(f" {coco_classes[labels[idx]]}: score={scores[idx]:.2f}, "
          f"box={boxes[idx].tolist()}")
Torchvision provides pretrained Faster R-CNN out of the box. The model uses a ResNet-50 backbone with a Feature Pyramid Network (FPN) -- a clever multi-scale architecture that produces feature maps at different resolutions so the detector can find both small and large objects effectively. On random noise you'll get garbage detections (or none above the threshold), but feed in a real photo and you'll see surprisingly accurate bounding boxes around recognized objects. Fine-tuning for custom object classes follows the same transfer learning pattern we covered in episode #46 -- replace the head, freeze or gently tune the backbone ;-)
YOLO: detection in one shot
The two-stage approach is accurate but adds complexity and latency. YOLO (You Only Look Once, Redmon et al., 2016) collapsed the entire detection pipeline into a single forward pass, and the name really says it all:
The idea: divide the image into an S x S grid (e.g., 13x13). Each grid cell is responsible for predicting a fixed number of bounding boxes. For each box, the cell predicts: position and size (4 values), objectness confidence (1 value), and class probabilities (C values, one per class). One forward pass through the network produces all detections simultaneously. No region proposals, no second stage, no post-hoc classification -- just one shot.
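Two details of this scheme are easy to sketch in plain Python: which grid cell is "responsible" for an object, and how objectness and class probabilities combine into a final score. This follows the rule from the original YOLO paper (the cell containing the object's center is responsible); later versions refine the assignment. The function names and example numbers are assumptions for illustration.

```python
S = 13  # grid size

def responsible_cell(cx, cy, img_w, img_h, grid=S):
    """Grid cell (row, col) that contains the object's center point."""
    col = min(int(cx / img_w * grid), grid - 1)
    row = min(int(cy / img_h * grid), grid - 1)
    return row, col

def class_scores(objectness, class_probs):
    """Class-specific confidence: objectness times each class probability."""
    return [objectness * p for p in class_probs]

# An object centered in the middle of a 640x480 image
row, col = responsible_cell(cx=320, cy=240, img_w=640, img_h=480)
print(f"Responsible cell: row={row}, col={col}")

# One predicted box: fairly sure there is an object, most likely class 1
scores = class_scores(0.9, [0.1, 0.8, 0.1])
best = max(range(len(scores)), key=lambda k: scores[k])
print(f"Best class: {best} with score {scores[best]:.2f}")
```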
# Conceptual YOLO output structure
# Real YOLO implementations are more complex, but the core idea is:
S = 13 # grid size
B = 3 # boxes per grid cell
C = 80 # number of classes
# Each grid cell predicts B bounding boxes
# Each box has: (x, y, w, h, objectness) + C class probabilities
output_per_cell = B * (5 + C)
total_output = S * S * output_per_cell
print(f"YOLO grid: {S}x{S} = {S*S} cells")
print(f"Boxes per cell: {B}")
print(f"Output per cell: {output_per_cell}")
print(f"Total predictions per image: {S * S * B} boxes")
print(f"Total output tensor size: {total_output}")
# After the forward pass, non-maximum suppression (NMS) filters
# overlapping detections to keep only the best box per object
def simple_nms(boxes, scores, iou_threshold=0.5):
    """Non-maximum suppression: remove overlapping detections.
    (In practice, torchvision.ops.nms does the same job.)"""
    order = scores.argsort(descending=True)
    keep = []
    while len(order) > 0:
        i = order[0].item()
        keep.append(i)
        if len(order) == 1:
            break
        # Compute IoU between kept box and remaining
        remaining = order[1:]
        ious = compute_iou(boxes[i].unsqueeze(0), boxes[remaining])  # shape: (len(remaining),)
        # Keep only boxes with low overlap
        order = remaining[ious < iou_threshold]
    return keep

def compute_iou(box, boxes):
    """Intersection over Union between one box (1, 4) and multiple boxes (N, 4)."""
    x1 = torch.max(box[:, 0], boxes[:, 0])
    y1 = torch.max(box[:, 1], boxes[:, 1])
    x2 = torch.min(box[:, 2], boxes[:, 2])
    y2 = torch.min(box[:, 3], boxes[:, 3])
    intersection = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)
    area_a = (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])
    area_b = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    union = area_a + area_b - intersection
    return intersection / (union + 1e-6)
# Demo with dummy data
dummy_boxes = torch.tensor([
[10, 10, 50, 50], # box 1
[12, 12, 52, 52], # box 2 (overlaps heavily with box 1)
[100, 100, 150, 150] # box 3 (no overlap)
], dtype=torch.float)
dummy_scores = torch.tensor([0.9, 0.85, 0.7])
kept = simple_nms(dummy_boxes, dummy_scores, iou_threshold=0.5)
print(f"\nNMS example: {len(dummy_boxes)} input boxes -> {len(kept)} kept")
print(f"Kept indices: {kept}")
The tradeoff between two-stage (Faster R-CNN) and one-stage (YOLO) detectors has narrowed considerably in modern versions. YOLO is dramatically faster (real-time, 30+ FPS on consumer GPUs) and has closed most of the accuracy gap through years of architectural improvements. YOLOv5 still relies on anchor boxes, while YOLOv8 and more recent versions switched to anchor-free designs; all of them combine multi-scale feature pyramids with sophisticated training strategies, rivaling Faster R-CNN's accuracy while running many times faster (the exact speedup depends on model size and hardware). For most practical applications in 2026, YOLO variants are the default choice.
Non-maximum suppression (NMS) is critical for both approaches. When the network predicts hundreds of overlapping boxes for the same object, NMS keeps only the highest-confidence box and removes all boxes that overlap with it above an IoU threshold. The compute_iou function calculates the Intersection over Union -- the ratio of the overlap area to the total area covered by both boxes. IoU is also the basis for evaluating detection accuracy: under the common PASCAL VOC convention, a prediction counts as "correct" if it has the right class and its IoU with the ground truth box exceeds 0.5, while COCO-style evaluation averages over IoU thresholds from 0.5 to 0.95.
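As a worked example of the 0.5 criterion, the sketch below scores three hypothetical predictions against one ground truth box, using a plain-Python IoU on (x1, y1, x2, y2) tuples. The box coordinates are made up for illustration.

```python
def iou(a, b):
    """Intersection over Union of two (x1, y1, x2, y2) boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)

ground_truth = (10, 10, 50, 50)
predictions = {
    "tight": (12, 12, 52, 52),      # off by two pixels
    "loose": (30, 30, 70, 70),      # covers only one corner of the object
    "miss": (100, 100, 150, 150),   # somewhere else entirely
}

hits = []
for name, box in predictions.items():
    value = iou(ground_truth, box)
    if value > 0.5:
        hits.append(name)
    print(f"{name:>6s}: IoU = {value:.3f}")
print(f"Counted as correct at IoU > 0.5: {hits}")
```

Note how quickly IoU drops: the "loose" box still covers a quarter of the object, yet its IoU is only about 0.14.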
Image segmentation: classifying every pixel
Classification tells you what is in the image. Segmentation tells you exactly which pixels belong to each class. This requires a fundamentally different output structure -- instead of a single class vector, the network produces a full-resolution label map where every pixel gets classified.
The architecture challenge is immediately obvious: the conv+pool backbone we've been using reduces spatial dimensions (32x32 -> 16x16 -> 8x8 -> 4x4). That's great for classification (we want to discard spatial detail and keep only semantic information), but terrible for segmentation (we need the spatial detail to produce per-pixel labels). How do you go back up?
Encoder-decoder architecture: the encoder (a standard CNN backbone like ResNet) compresses the image into a low-resolution, high-channel feature map. The decoder (a series of upsampling layers) expands it back to the original resolution. The encoder captures what is in the image; the decoder recovers where.
U-Net (Ronneberger et al., 2015) is the most influential segmentation architecture, and the name comes from its U-shaped structure. The encoder contracts (spatial size shrinks, channels grow), the decoder expands (spatial size grows, channels shrink), and -- this is the crucial part -- skip connections bridge corresponding levels of the encoder and decoder. These skip connections pass fine-grained spatial details (edges, textures, exact boundaries) from the encoder directly to the decoder, compensating for the information lost during pooling.
class UNetBlock(nn.Module):
    """Double convolution block used in U-Net."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU())

    def forward(self, x):
        return self.conv(x)

class MiniUNet(nn.Module):
    def __init__(self, n_classes=3):
        super().__init__()
        # Encoder (downsampling path)
        self.enc1 = UNetBlock(3, 64)
        self.enc2 = UNetBlock(64, 128)
        self.enc3 = UNetBlock(128, 256)
        self.pool = nn.MaxPool2d(2)
        # Bottleneck
        self.bottleneck = UNetBlock(256, 512)
        # Decoder (upsampling path)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = UNetBlock(512, 256)  # 512 because of skip connection concat
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = UNetBlock(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = UNetBlock(128, 64)
        self.output = nn.Conv2d(64, n_classes, 1)  # 1x1 conv for per-pixel class

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)                    # (B, 64, H, W)
        e2 = self.enc2(self.pool(e1))        # (B, 128, H/2, W/2)
        e3 = self.enc3(self.pool(e2))        # (B, 256, H/4, W/4)
        # Bottleneck
        b = self.bottleneck(self.pool(e3))   # (B, 512, H/8, W/8)
        # Decoder with skip connections
        d3 = self.dec3(torch.cat([self.up3(b), e3], dim=1))  # concat along channels
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.output(d1)
unet = MiniUNet(n_classes=3)
x = torch.randn(1, 3, 128, 128)
mask = unet(x)
print(f"Input: {x.shape}")
print(f"Output: {mask.shape}") # (1, 3, 128, 128) -- per-pixel class scores
print(f"Parameters: {sum(p.numel() for p in unet.parameters()):,}")
# Each pixel gets classified independently
predicted_classes = mask.argmax(dim=1)
print(f"Predicted class map: {predicted_classes.shape}") # (1, 128, 128)
The skip connections are what make U-Net work so well. Without them, the decoder would have to reconstruct fine spatial details (sharp edges, thin structures) purely from the low-resolution bottleneck -- which contains rich semantic information but very coarse spatial information. With skip connections, the decoder gets both: high-level semantics from the bottleneck flowing upward, and fine-grained spatial detail from the encoder flowing sideways. This is why U-Net originally excelled at biomedical segmentation where precise boundaries matter (cell boundaries, tumor margins).
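The channel and size bookkeeping behind those skip connections can be traced without running any network. The plain-Python sketch below mirrors the MiniUNet layout (three encoder stages, 2x2 pooling, stride-2 upsampling that halves the channels); the function itself is hypothetical and only does arithmetic.

```python
def trace_unet(h, w, enc_channels=(64, 128, 256), bottleneck=512):
    """Return (stage, channels, height, width) tuples for a MiniUNet-style network."""
    trace, skips = [], []
    for i, ch in enumerate(enc_channels, start=1):
        trace.append((f"enc{i}", ch, h, w))
        skips.append((ch, h, w))
        h, w = h // 2, w // 2                 # 2x2 max pooling (floors odd sizes)
    trace.append(("bottleneck", bottleneck, h, w))
    ch = bottleneck
    for i in range(len(enc_channels), 0, -1):
        skip_ch, skip_h, skip_w = skips[i - 1]
        h, w = h * 2, w * 2                   # stride-2 transposed conv doubles the size
        up_ch = ch // 2                       # ...and halves the channels
        if (h, w) != (skip_h, skip_w):
            raise ValueError(f"level {i}: upsampled {h}x{w} vs skip {skip_h}x{skip_w}")
        trace.append((f"cat{i}", up_ch + skip_ch, h, w))  # concat along channels
        ch = skip_ch
        trace.append((f"dec{i}", ch, h, w))
    return trace

for stage, ch, h, w in trace_unet(128, 128):
    print(f"{stage:>10s}: {ch:>3d} channels, {h}x{w}")
```

This also shows a practical constraint: with three pooling steps, the input height and width should be divisible by 8. Calling trace_unet(100, 100) raises an error because 25 pools down to 12, which upsamples to 24 and no longer matches the 25x25 skip feature map. The torch.cat call in MiniUNet would fail for the same reason.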
The loss function for segmentation is typically per-pixel cross-entropy -- treating each pixel as an independent classification problem. For imbalanced classes (background dominates most images, and the object of interest might cover only 5% of pixels), dice loss or focal loss give better results by focusing the learning signal on the rarer classes:
def dice_loss(pred, target, n_classes=3, smooth=1.0):
    """Dice loss for segmentation -- handles class imbalance better
    than cross-entropy by measuring overlap directly."""
    pred_softmax = torch.softmax(pred, dim=1)
    total_loss = 0
    for c in range(n_classes):
        pred_c = pred_softmax[:, c]
        target_c = (target == c).float()
        intersection = (pred_c * target_c).sum()
        union = pred_c.sum() + target_c.sum()
        dice = (2 * intersection + smooth) / (union + smooth)
        total_loss += (1 - dice)
    return total_loss / n_classes
# Demo
pred = torch.randn(2, 3, 64, 64) # predictions for 2 images, 3 classes
target = torch.randint(0, 3, (2, 64, 64)) # ground truth pixel labels
loss = dice_loss(pred, target)
print(f"Dice loss: {loss.item():.4f}")
# Compare with standard cross-entropy
ce_loss = nn.CrossEntropyLoss()(pred, target)
print(f"Cross-entropy loss: {ce_loss.item():.4f}")
Having said that, in practice most people just start with nn.CrossEntropyLoss and only switch to dice loss if they notice the model ignoring small or rare classes. Like most things in deep learning, the simpler baseline works more often than you'd expect.
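The "model ignoring a rare class" failure is easy to reproduce with a toy mask. The plain-Python sketch below uses a hard Dice score on discrete labels (not the soft, differentiable loss defined earlier) on a flattened 100-pixel image where only 5 pixels are foreground. The data and function names are made up for illustration.

```python
def pixel_accuracy(pred, target):
    """Fraction of pixels whose predicted label matches the ground truth."""
    return sum(p == t for p, t in zip(pred, target)) / len(target)

def dice_score(pred, target, cls):
    """Hard Dice score for one class: 2 * overlap / (predicted + actual)."""
    inter = sum(1 for p, t in zip(pred, target) if p == cls and t == cls)
    total = sum(p == cls for p in pred) + sum(t == cls for t in target)
    return 2 * inter / total if total else 1.0

target = [1] * 5 + [0] * 95                # 5% foreground
all_background = [0] * 100                 # model that ignores the rare class
decent = [1, 1, 1, 1, 0, 1] + [0] * 94     # 4 hits, 1 miss, 1 false alarm

for name, pred in [("all background", all_background), ("decent", decent)]:
    acc = pixel_accuracy(pred, target)
    dice = dice_score(pred, target, cls=1)
    print(f"{name:>15s}: accuracy={acc:.2f}, foreground dice={dice:.2f}")
```

A model that never predicts foreground still gets 95% pixel accuracy, while its foreground Dice is zero. That gap is the signal to watch for before switching losses.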
Upsampling methods: getting back to full resolution
The decoder needs to increase spatial dimensions, and there are several ways to do this. Each has tradeoffs worth understanding:
Transposed convolutions (also called deconvolutions, though that's technically a misnomer): learned upsampling. The network learns the upsampling weights through backpropagation, giving it maximum flexibility. The downside: transposed convolutions can produce checkerboard artifacts -- visible grid patterns in the output caused by uneven overlap of the output tiles, which happens when the kernel size is not divisible by the stride. We used nn.ConvTranspose2d in our U-Net above with kernel_size=2 and stride=2, a configuration where the output tiles don't overlap at all.
Bilinear interpolation + convolution: fixed upsampling followed by a regular convolution to refine. This avoids checkerboard artifacts entirely since the interpolation is smooth by construction. Many modern architectures prefer this approach:
# Three upsampling approaches compared
x = torch.randn(1, 64, 8, 8)
# 1. Transposed convolution (learned)
up_transposed = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
out1 = up_transposed(x)
print(f"Transposed conv: {x.shape} -> {out1.shape}")
# 2. Bilinear interpolation + conv (fixed + learned)
up_bilinear = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(64, 32, 3, padding=1))
out2 = up_bilinear(x)
print(f"Bilinear + conv: {x.shape} -> {out2.shape}")
# 3. Pixel shuffle (sub-pixel convolution)
# Rearranges (C*r^2, H, W) -> (C, H*r, W*r) where r is upscale factor
up_shuffle = nn.Sequential(
nn.Conv2d(64, 32 * 4, 3, padding=1), # expand channels by r^2=4
nn.PixelShuffle(2)) # rearrange into spatial dims
out3 = up_shuffle(x)
print(f"Pixel shuffle: {x.shape} -> {out3.shape}")
# Parameter comparison
for name, module in [("Transposed", up_transposed), ("Bilinear+conv", up_bilinear),
                     ("Pixel shuffle", up_shuffle)]:
    params = sum(p.numel() for p in module.parameters())
    print(f" {name:>15s}: {params:,} params")
Pixel shuffle (Shi et al., 2016) is particularly clever -- it uses a regular convolution to produce extra channels, then rearranges those channels into spatial dimensions. So a tensor of shape (C * r^2, H, W) becomes (C, Hr, Wr). This is used in super-resolution and some segmentation architectures. It avoids checkerboard artifacts and is computationally efficient.
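The "uneven overlap" behind the checkerboard artifacts of transposed convolutions can be made concrete in one dimension. The plain-Python sketch below counts how many kernel taps contribute to each output position (no padding, dilation 1); the function name is hypothetical, and real artifacts also depend on the learned weights, so treat this as intuition rather than a full explanation.

```python
def overlap_counts(n_in, kernel, stride):
    """Number of kernel taps landing on each output position of a
    1D transposed convolution (no padding, dilation 1)."""
    n_out = (n_in - 1) * stride + kernel
    counts = [0] * n_out
    for i in range(n_in):            # every input value paints one kernel-sized tile
        for k in range(kernel):
            counts[i * stride + k] += 1
    return counts

print("kernel=3, stride=2:", overlap_counts(4, kernel=3, stride=2))  # alternating 1/2
print("kernel=2, stride=2:", overlap_counts(4, kernel=2, stride=2))  # no overlap
print("kernel=4, stride=2:", overlap_counts(4, kernel=4, stride=2))  # even in the interior
```

When the kernel size is not divisible by the stride, some output positions receive twice as many contributions as their neighbors, and in 2D that pattern repeats along both axes. Picking a kernel size divisible by the stride evens out the overlap, which is one common mitigation alongside the interpolate-then-convolve approach.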
Neural style transfer: the unexpected application
This is one of the most beautiful demonstrations of what CNN features actually represent. Neural style transfer (Gatys et al., 2015) takes a content image (your photo) and a style image (a Van Gogh painting), and generates a new image that has the content of your photo rendered in the style of the painting.
The method exploits the feature hierarchy we discussed in episode #45. Recall that early CNN layers detect edges and textures, while deep layers detect objects and scene structure. The insight is that these two levels of representation correspond to style and content, and they can be manipulated independently:
- Content is captured by the raw activations at deep layers. Two images have similar content if a deep layer's activations look similar (same objects, same spatial arrangement).
- Style is captured by the correlations between filter activations across channels, computed as a Gram matrix. Two images have similar style if their textures, brushstrokes, and color patterns are similar -- regardless of what objects are actually depicted.
from torchvision import models
# Extract features at specific layers for style transfer
class StyleContentExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
        # torchvision's VGG uses in-place ReLUs; switch them off so the features
        # we store are not overwritten when the next slice runs
        for layer in vgg:
            if isinstance(layer, nn.ReLU):
                layer.inplace = False
        # Style layers: early layers capture texture/color
        # Content layers: deep layers capture structure
        self.style_layers = [0, 5, 10, 19, 28]  # conv1_1 through conv5_1
        self.content_layers = [21]              # conv4_2 (index 25 would be conv4_4)
        self.target_layers = sorted(set(self.style_layers + self.content_layers))
        self.slices = nn.ModuleList()
        prev = 0
        for idx in self.target_layers:
            self.slices.append(vgg[prev:idx+1])
            prev = idx + 1
        # Freeze -- we never train these weights
        for param in self.parameters():
            param.requires_grad = False

    def gram_matrix(self, x):
        """Compute Gram matrix -- correlations between feature channels."""
        B, C, H, W = x.shape
        features = x.view(B, C, H * W)
        gram = torch.bmm(features, features.transpose(1, 2))
        return gram / (C * H * W)  # normalize

    def forward(self, x):
        style_grams = []
        content_features = []
        # Each slice ends at one target layer, so the two lists line up
        for current_layer, s in zip(self.target_layers, self.slices):
            x = s(x)
            if current_layer in self.style_layers:
                style_grams.append(self.gram_matrix(x))
            if current_layer in self.content_layers:
                content_features.append(x)
        return style_grams, content_features
# Demo
extractor = StyleContentExtractor()
img = torch.randn(1, 3, 224, 224)
style_grams, content_feats = extractor(img)
print(f"Style representations (Gram matrices):")
for i, g in enumerate(style_grams):
print(f" Layer {i}: Gram shape = {g.shape}")
print(f"\nContent representation:")
for i, c in enumerate(content_feats):
print(f" Layer {i}: feature shape = {c.shape}")
The Gram matrix G[i,j] captures how often filter i and filter j activate together. If one filter detects horizontal blue strokes and another detects thick impasto texture, their correlation captures "blue impasto" as a style element. By matching the Gram matrices between a generated image and the style image, you transfer these texture-level patterns without transferring the actual content (objects, layout).
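To see why the Gram matrix ignores layout, here is a minimal sketch (using a random tensor as a hypothetical feature map, not real VGG activations). It shuffles the spatial positions of a feature map and checks that the Gram matrix, normalized the same way as in the extractor above, stays the same:

```python
import torch

def gram_matrix(x):
    """Channel-by-channel correlations, normalized by C * H * W."""
    B, C, H, W = x.shape
    features = x.reshape(B, C, H * W)
    return torch.bmm(features, features.transpose(1, 2)) / (C * H * W)

torch.manual_seed(0)
# ASSUMPTION: a random tensor stands in for a real feature map (8 channels, 16x16)
feats = torch.randn(1, 8, 16, 16)

# Shuffle the spatial positions, using the same permutation for every channel
perm = torch.randperm(16 * 16)
shuffled = feats.reshape(1, 8, -1)[:, :, perm].reshape(1, 8, 16, 16)

g_original = gram_matrix(feats)
g_shuffled = gram_matrix(shuffled)

print(f"Gram shape: {tuple(g_original.shape)}")
print(f"Feature maps identical: {torch.equal(feats, shuffled)}")
print(f"Gram matrices match: {torch.allclose(g_original, g_shuffled, atol=1e-5)}")
```

The layout is scrambled, yet the channel correlations are unchanged. That is exactly the property that lets the style loss match textures without copying where things are in the image.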
The original method is slow -- it optimizes the pixel values of the generated image through hundreds of gradient descent iterations, minimizing a weighted sum of content loss (deep layer activations should match the content image) and style loss (Gram matrices should match the style image). Modern feed-forward style transfer networks (Johnson et al., 2016) train a CNN to perform the transformation in a single forward pass, enabling real-time style transfer on video. The training is slow (you're training a whole network), but inference is just one forward pass -- same as any other CNN.
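The optimization-based approach described above can be sketched in a few lines. To keep the example runnable without downloading weights, a tiny frozen random CNN stands in for VGG-19 here -- that is an assumption for illustration only, as are the random placeholder images and the loss weights. The structure of the loop is the point: the network is frozen and the pixels are the parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# ASSUMPTION: a tiny frozen random CNN stands in for pretrained VGG-19
layers = nn.ModuleList([
    nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU()),
    nn.Sequential(nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(), nn.AvgPool2d(2)),
])
for p in layers.parameters():
    p.requires_grad = False

def gram(x):
    B, C, H, W = x.shape
    f = x.reshape(B, C, H * W)
    return torch.bmm(f, f.transpose(1, 2)) / (C * H * W)

def extract(img):
    """Return (Gram matrix per layer, features of the last layer)."""
    grams, x = [], img
    for layer in layers:
        x = layer(x)
        grams.append(gram(x))
    return grams, x

content_img = torch.rand(1, 3, 32, 32)  # placeholder content image
style_img = torch.rand(1, 3, 32, 32)    # placeholder style image

# Targets are computed once and never change
with torch.no_grad():
    style_targets, _ = extract(style_img)
    _, content_target = extract(content_img)

# Start from noise; the generated image is the only thing being optimized
generated = torch.rand(1, 3, 32, 32).requires_grad_(True)
optimizer = torch.optim.Adam([generated], lr=0.02)
content_weight, style_weight = 1.0, 1e3  # hypothetical weights, tune per image

losses = []
for step in range(50):
    optimizer.zero_grad()
    grams, content = extract(generated)
    content_loss = F.mse_loss(content, content_target)
    style_loss = sum(F.mse_loss(g, t) for g, t in zip(grams, style_targets))
    loss = content_weight * content_loss + style_weight * style_loss
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"Loss at step 0: {losses[0]:.6f}")
print(f"Loss at step 49: {losses[-1]:.6f}")
```

With a real pretrained network you would swap `extract` for the StyleContentExtractor and run far more iterations, which is where the slowness of the original method comes from.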
Practical: feature extraction with pretrained models
For many practical tasks, you don't need full detection or segmentation architectures. You need the features a CNN has learned. Remove the classification head from a pretrained ResNet-18, and the 512-dimensional vector it produces is an incredibly powerful representation of the image's content (deeper variants such as ResNet-50 produce 2048-dimensional vectors instead):
import torch
import torch.nn as nn
from torchvision import models, transforms
# Load pretrained ResNet-18, remove the classification head
backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
feature_extractor.eval()
# Standard ImageNet preprocessing (apply this to real PIL images; unused in this demo)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Extract features (using random data as stand-in for real images)
images = torch.randn(5, 3, 224, 224)
with torch.no_grad():
    # Output is (5, 512, 1, 1) after global average pooling -- drop the spatial dims
    features = feature_extractor(images).squeeze(-1).squeeze(-1)
print(f"5 images -> feature matrix: {features.shape}")  # (5, 512)
# Cosine similarity between image features
cos_sim = torch.nn.functional.cosine_similarity
for i in range(5):
    for j in range(i+1, 5):
        sim = cos_sim(features[i].unsqueeze(0), features[j].unsqueeze(0))
        print(f" Image {i} vs Image {j}: similarity = {sim.item():.3f}")
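The pairwise loop above is fine for 5 images, but for a larger collection you would typically normalize the features once and use a single matrix product. Here is a minimal sketch of that idea as a top-k lookup. The random `database` and `query` tensors are stand-ins for real 512-d features, so treat the names and sizes as assumptions:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# ASSUMPTION: random vectors stand in for real 512-d ResNet-18 features
database = torch.randn(1000, 512)               # hypothetical indexed image features
query = database[42] + 0.1 * torch.randn(512)   # slightly perturbed copy of item 42

# After L2 normalization, a dot product equals cosine similarity
db_norm = F.normalize(database, dim=1)
q_norm = F.normalize(query, dim=0)

similarities = db_norm @ q_norm                 # shape: (1000,)
top_scores, top_ids = similarities.topk(3)

for score, idx in zip(top_scores.tolist(), top_ids.tolist()):
    print(f" item {idx:>4d}: cosine similarity = {score:.3f}")
```

The perturbed copy of item 42 comes back as the best match, while unrelated random vectors score close to zero. The full pairwise matrix is just `db_norm @ db_norm.T`.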
This 512-dimensional feature vector captures the visual content of the image in a form that's immensely useful for downstream tasks:
- Image similarity search: find similar images in a database by computing cosine similarity between feature vectors. This is the core idea behind reverse image search and the visual search features in products like Google Photos and Pinterest.
- Image clustering: group visually similar images using k-means (episode #22) or DBSCAN (episode #23) on the feature vectors.
- Classical ML on images: extract features from a pretrained CNN and feed them to a random forest (episode #18), gradient boosting (episode #19), or SVM (episode #20). When you have 50 labeled images, you're not training a CNN from scratch -- you're extracting pretrained features and feeding them to a classical model.
# The practical bridge: CNN features + classical ML
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
import numpy as np
# Simulate extracting features for 200 images across 5 classes
np.random.seed(42)
n_samples = 200
n_features = 512
X_features = np.random.randn(n_samples, n_features).astype(np.float32)
y_labels = np.random.randint(0, 5, n_samples)
# Train a random forest on CNN features
rf = RandomForestClassifier(n_estimators=100, random_state=42)
scores = cross_val_score(rf, X_features, y_labels, cv=5)
print(f"Random Forest on CNN features: {scores.mean():.1%} accuracy (5-fold CV)")
print(f"(Random data -> ~20% baseline for 5 classes)")
print("\nWith REAL pretrained features, where images of the same class")
print("end up close together, this approach can reach high accuracy")
print("on datasets too small for end-to-end CNN training")
This pattern -- pretrained CNN as feature extractor feeding into classical ML -- is the practical bridge between everything we covered in Arc 1 (episodes #10-#36) and the deep learning techniques we're learning now. When your dataset has 50 images, training a CNN from scratch is out of the question. But extracting a 512-dimensional feature vector per image from a pretrained ResNet-18 and feeding those vectors to the random forest we built in episode #18? That works surprisingly well. It's one of the most underrated techniques in practical ML ;-)
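The clustering use case from the list above follows the same recipe: features in, classical algorithm on top. A minimal sketch, assuming synthetic well-separated blobs in place of real CNN features (real features are messier, so this only shows the mechanics, not what accuracy to expect):

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import adjusted_rand_score

# ASSUMPTION: synthetic blobs stand in for real 512-d CNN features
X, true_groups = make_blobs(n_samples=200, n_features=512, centers=5,
                            cluster_std=1.0, random_state=42)
X = X.astype(np.float32)

# k-means never sees true_groups; it only sees the feature vectors
kmeans = KMeans(n_clusters=5, n_init=10, random_state=42)
cluster_ids = kmeans.fit_predict(X)

# Adjusted Rand index: 1.0 = clusters match the hidden groups, ~0.0 = random
ari = adjusted_rand_score(true_groups, cluster_ids)
print(f"Cluster sizes: {np.bincount(cluster_ids)}")
print(f"Adjusted Rand index vs. hidden groups: {ari:.2f}")
```

No labels are used for fitting here. The hidden groups only serve to score the result afterwards.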
Top-K predictions and confidence calibration
Before we wrap up, there's one more practical classification technique worth covering. In episode #46 we always took argmax to get the single best prediction. But in production, you often want the top-K predictions with their confidence scores:
# Top-K predictions with confidence scores
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)
    probs = torch.softmax(logits, dim=1)
# Top-5 predictions
topk_probs, topk_indices = probs.topk(5, dim=1)
print("Top-5 predictions:")
for i in range(5):
    print(f" Class {topk_indices[0, i].item():>4d}: {topk_probs[0, i].item():.4f} "
          f"({topk_probs[0, i].item()*100:.1f}%)")
# Top-1 vs Top-5 accuracy concept
print("\nTop-1 accuracy: correct if THE top prediction is right")
print("Top-5 accuracy: correct if the true class is in the top 5")
print("ImageNet benchmark: top-1 ~80%, top-5 ~95% for modern models")
Top-5 accuracy is the classic ImageNet metric, usually reported alongside top-1 -- if the correct class appears anywhere in the model's 5 most confident predictions, it counts as correct. For a 1000-class problem with fine-grained distinctions (over a hundred dog breeds, dozens of bird species), top-5 is a more forgiving and practical metric than top-1. In production systems, showing the user "this might be A, B, or C" with confidence percentages is often more useful than committing to a single answer. One caveat: softmax scores are not automatically calibrated probabilities, so a score of 0.9 does not guarantee the model is right 90% of the time.
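Computing top-K accuracy over a labeled batch takes only a few lines. The helper below is a minimal sketch (the function name `topk_accuracy` and the hand-made logits are illustrative, not part of any library):

```python
import torch

def topk_accuracy(logits, targets, k=5):
    """Fraction of samples whose true class is among the k highest-scoring classes."""
    topk_ids = logits.topk(k, dim=1).indices               # (N, k)
    hits = (topk_ids == targets.unsqueeze(1)).any(dim=1)   # (N,)
    return hits.float().mean().item()

# Hand-made logits for 4 samples over 6 classes (illustrative values only)
logits = torch.tensor([
    [5.0, 1.0, 0.5, 0.2, 0.1, 0.0],  # true class 0: ranked 1st
    [1.0, 3.0, 2.0, 0.5, 0.2, 0.1],  # true class 2: ranked 2nd
    [0.1, 0.2, 0.3, 4.0, 3.0, 2.0],  # true class 5: ranked 3rd
    [2.0, 1.5, 1.0, 0.5, 0.3, 0.1],  # true class 5: ranked last
])
targets = torch.tensor([0, 2, 5, 5])

print(f"Top-1 accuracy: {topk_accuracy(logits, targets, k=1):.2f}")  # 0.25
print(f"Top-3 accuracy: {topk_accuracy(logits, targets, k=3):.2f}")  # 0.75
```

Only the first sample is right at top-1, but three of the four have the true class somewhere in the top 3. That gap is why top-K numbers always look better than top-1 on the same model.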
What we covered
- The vision task spectrum progresses from classification (one label) to detection (bounding boxes) to segmentation (per-pixel labels) to instance segmentation (per-pixel labels + object identity);
- Object detection evolved from R-CNN (2014, slow two-stage) through Fast R-CNN to Faster R-CNN (shared backbone + Region Proposal Network with anchor boxes);
- YOLO collapsed detection into a single forward pass -- grid-based prediction with NMS post-processing. Modern YOLO variants are the practical default for real-time detection;
- Image segmentation uses encoder-decoder architectures where the encoder compresses and the decoder recovers spatial resolution. U-Net introduced skip connections between corresponding encoder and decoder levels;
- Upsampling methods: transposed convolutions (learned but can checkerboard), bilinear interpolation + conv (smooth), pixel shuffle (channel-to-space rearrangement);
- Neural style transfer exploits the CNN feature hierarchy -- deep activations capture content, Gram matrices of activations at several layers (conv1_1 through conv5_1) capture style. Optimize a generated image to match both;
- Pretrained CNNs as feature extractors produce powerful 512-d vectors for similarity search, clustering, or feeding into classical ML models from Arc 1 -- the practical bridge between classical and deep learning;
- Top-K predictions with confidence scores are more practical than single argmax in production systems.
We've now completed our coverage of CNNs -- theory, practice, and applications. CNNs exploit spatial structure in images beautifully, using local filters and weight sharing to build hierarchical representations. But images aren't the only type of structured data. What about text, audio, and time series -- data where the order of elements matters, where context stretches across long sequences, and where each element depends on what came before it? That's a fundamentally different kind of structure, and it requires a fundamentally different kind of architecture -- one with memory ;-)
Exercises
Exercise 1: Build a simple object detection post-processor. Write a function detect_and_filter(model, image, confidence_threshold=0.5, iou_threshold=0.5) that takes a pretrained Faster R-CNN model and an image tensor, runs inference, filters detections by confidence, applies NMS to remove duplicates, and returns cleaned (boxes, scores, labels). Test it on a batch of 3 random images and print the number of detections before and after filtering for each.
Exercise 2: Implement a minimal segmentation pipeline. Build the MiniUNet from this episode, create a small synthetic dataset of 200 images (64x64, 3 channels) with random circles as "foreground" and the rest as "background" (2 classes), train the U-Net for 30 epochs with cross-entropy loss, and report the per-pixel accuracy on a held-out set of 50 images. Hint: use torch.zeros for background and draw circles by checking (x - cx)^2 + (y - cy)^2 < r^2.
Exercise 3: Build a feature extractor comparison. Extract features from 100 synthetic images using two different pretrained backbones (ResNet-18, which yields 512-d vectors, and ResNet-50, which yields 2048-d vectors). Compute the cosine similarity matrix for each (100x100 matrix), then compare: are images that are "similar" under ResNet-18 also "similar" under ResNet-50? Compute the correlation between the two similarity matrices. What does this tell you about the universality of pretrained features across architectures?