Learn AI Series (#60) - Training Large Language Models
What will I learn
- You will learn what training data goes into LLMs -- Common Crawl, books, code, conversations;
- data cleaning and deduplication at scale -- why data quality trumps quantity;
- distributed training strategies: data parallelism, model parallelism, pipeline parallelism;
- mixed precision training and gradient accumulation -- practical tricks that enable scale;
- the compute budget: FLOPs, tokens, and Chinchilla scaling laws;
- what it actually costs to train a frontier LLM.
Requirements
- A working modern computer running macOS, Windows or Ubuntu;
- An installed Python 3(.11+) distribution;
- The ambition to learn AI and machine learning.
Difficulty
- Beginner
Curriculum (of the Learn AI Series):
- Learn AI Series (#1) - What Machine Learning Actually Is
- Learn AI Series (#2) - Setting Up Your AI Workbench - Python and NumPy
- Learn AI Series (#3) - Your Data Is Just Numbers - How Machines See the World
- Learn AI Series (#4) - Your First Prediction - No Math, Just Intuition
- Learn AI Series (#5) - Patterns in Data - What "Learning" Actually Looks Like
- Learn AI Series (#6) - From Intuition to Math - Why We Need Formulas
- Learn AI Series (#7) - The Training Loop - See It Work Step by Step
- Learn AI Series (#8) - The Math You Actually Need (Part 1) - Linear Algebra
- Learn AI Series (#9) - The Math You Actually Need (Part 2) - Calculus and Probability
- Learn AI Series (#10) - Your First ML Model - Linear Regression From Scratch
- Learn AI Series (#11) - Making Linear Regression Real
- Learn AI Series (#12) - Classification - Logistic Regression From Scratch
- Learn AI Series (#13) - Evaluation - How to Know If Your Model Actually Works
- Learn AI Series (#14) - Data Preparation - The 80% Nobody Talks About
- Learn AI Series (#15) - Feature Engineering and Selection
- Learn AI Series (#16) - Scikit-Learn - The Standard Library of ML
- Learn AI Series (#17) - Decision Trees - How Machines Make Decisions
- Learn AI Series (#18) - Random Forests - Wisdom of Crowds
- Learn AI Series (#19) - Gradient Boosting - The Kaggle Champion
- Learn AI Series (#20) - Support Vector Machines - Drawing the Perfect Boundary
- Learn AI Series (#21) - Mini Project - Predicting Crypto Market Regimes
- Learn AI Series (#22) - K-Means Clustering - Finding Groups
- Learn AI Series (#23) - Advanced Clustering - Beyond K-Means
- Learn AI Series (#24) - Dimensionality Reduction - PCA
- Learn AI Series (#25) - Advanced Dimensionality Reduction - t-SNE and UMAP
- Learn AI Series (#26) - Anomaly Detection - Finding What Doesn't Belong
- Learn AI Series (#27) - Recommendation Systems - "Users Like You Also Liked..."
- Learn AI Series (#28) - Time Series Fundamentals - When Order Matters
- Learn AI Series (#29) - Time Series Forecasting - Predicting What Comes Next
- Learn AI Series (#30) - Natural Language Processing - Text as Data
- Learn AI Series (#31) - Word Embeddings - Meaning in Numbers
- Learn AI Series (#32) - Bayesian Methods - Thinking in Probabilities
- Learn AI Series (#33) - Ensemble Methods Deep Dive - Stacking and Blending
- Learn AI Series (#34) - ML Engineering - From Notebook to Production
- Learn AI Series (#35) - Data Ethics and Bias in ML
- Learn AI Series (#36) - Mini Project - Complete ML Pipeline
- Learn AI Series (#37) - The Perceptron - Where It All Started
- Learn AI Series (#38) - Neural Networks From Scratch - Forward Pass
- Learn AI Series (#39) - Neural Networks From Scratch - Backpropagation
- Learn AI Series (#40) - Training Neural Networks - Practical Challenges
- Learn AI Series (#41) - Optimization Algorithms - SGD, Momentum, Adam
- Learn AI Series (#42) - PyTorch Fundamentals - Tensors and Autograd
- Learn AI Series (#43) - PyTorch Data and Training
- Learn AI Series (#44) - PyTorch nn.Module - Building Real Networks
- Learn AI Series (#45) - Convolutional Neural Networks - Theory
- Learn AI Series (#46) - CNNs in Practice - Classic to Modern Architectures
- Learn AI Series (#47) - CNN Applications - Detection, Segmentation, Style Transfer
- Learn AI Series (#48) - Recurrent Neural Networks - Sequences
- Learn AI Series (#49) - LSTM and GRU - Solving the Memory Problem
- Learn AI Series (#50) - Sequence-to-Sequence Models
- Learn AI Series (#51) - Attention Mechanisms
- Learn AI Series (#52) - The Transformer Architecture (Part 1)
- Learn AI Series (#53) - The Transformer Architecture (Part 2)
- Learn AI Series (#54) - Vision Transformers
- Learn AI Series (#55) - Generative Adversarial Networks
- Learn AI Series (#56) - Mini Project - Building a Transformer From Scratch
- Learn AI Series (#57) - Language Modeling - Predicting the Next Word
- Learn AI Series (#58) - GPT Architecture - Decoder-Only Transformers
- Learn AI Series (#59) - BERT and Encoder Models
- Learn AI Series (#60) - Training Large Language Models (this post)
Solutions to Episode #59 Exercises
Exercise 1: Build a complete BERT-style masked language model and train it on a small text corpus.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import math


class BERTEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model, max_len, n_segments=2):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.seg_emb = nn.Embedding(n_segments, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(0.1)

    def forward(self, tokens, segments):
        pos = torch.arange(tokens.size(1), device=tokens.device)
        x = self.tok_emb(tokens) + self.pos_emb(pos) + self.seg_emb(segments)
        return self.drop(self.norm(x))


class BERTModel(nn.Module):
    def __init__(self, vocab_size, d_model=128, n_heads=4,
                 n_layers=4, d_ff=512, max_len=128):
        super().__init__()
        self.embedding = BERTEmbedding(vocab_size, d_model, max_len)
        layer = nn.TransformerEncoderLayer(
            d_model, n_heads, d_ff, batch_first=True, activation='gelu')
        self.encoder = nn.TransformerEncoder(layer, n_layers)
        self.mlm_head = nn.Sequential(
            nn.Linear(d_model, d_model), nn.GELU(),
            nn.LayerNorm(d_model),
            nn.Linear(d_model, vocab_size))

    def forward(self, tokens, segments):
        x = self.embedding(tokens, segments)
        x = self.encoder(x)
        return self.mlm_head(x)


# Build a simple word-level vocabulary
corpus = """the cat sat on the mat the dog ran in the park
a bird flew over the fence and landed on the tree
the big cat chased the small mouse around the house
the dog and the cat played together in the yard all day
a fish swam in the pond while the bird watched from above"""
words = corpus.split()
vocab_words = sorted(set(words)) + ["[MASK]", "[PAD]"]
w2i = {w: i for i, w in enumerate(vocab_words)}
i2w = {i: w for w, i in w2i.items()}
vocab_size = len(vocab_words)
mask_id = w2i["[MASK]"]


def apply_bert_masking(tokens, mask_prob=0.15):
    """BERT-style masking: pick ~15% of non-pad positions as targets;
    of those, 80% -> [MASK], 10% -> random token, 10% -> left unchanged."""
    masked = tokens.clone()
    labels = torch.full_like(tokens, -100)  # -100 = ignored by cross_entropy
    for i in range(len(tokens)):
        if tokens[i] == w2i["[PAD]"]:
            continue
        if random.random() < mask_prob:
            labels[i] = tokens[i]
            r = random.random()
            if r < 0.8:
                masked[i] = mask_id
            elif r < 0.9:
                masked[i] = random.randint(0, vocab_size - 1)
    return masked, labels


# Create overlapping training sequences of length 32
seq_len = 32
data_ids = [w2i[w] for w in words]
seqs = []
for i in range(0, len(data_ids) - seq_len, seq_len // 2):
    seqs.append(data_ids[i:i+seq_len])
seqs_t = torch.tensor(seqs, dtype=torch.long)

model = BERTModel(vocab_size, d_model=128, n_heads=4, n_layers=4,
                  d_ff=512, max_len=128)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
for epoch in range(50):
    total_loss = 0
    for seq in seqs_t:
        masked, labels = apply_bert_masking(seq)
        if (labels == -100).all():
            continue  # nothing was masked: the loss would be NaN, so skip
        segs = torch.zeros_like(masked)
        logits = model(masked.unsqueeze(0), segs.unsqueeze(0))
        loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))
        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss += loss.item()
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}: loss={total_loss/len(seqs_t):.3f}")

# Test: mask a word and check predictions
test = [w2i[w] for w in "the cat sat on the".split()]
test += [w2i["[PAD]"]] * (seq_len - len(test))
test_t = torch.tensor(test)
test_m = test_t.clone()
test_m[1] = mask_id  # mask "cat"
model.eval()
with torch.no_grad():
    logits = model(test_m.unsqueeze(0), torch.zeros(1, seq_len, dtype=torch.long))
top5 = logits[0, 1].topk(5)
print(f"\nMasked 'cat' at position 1. Top-5 predictions:")
for score, idx in zip(top5.values, top5.indices):
    print(f" {i2w[idx.item()]:>10}: {score.item():.2f}")
```
On this small corpus, the model should predict words that appear in similar contexts to "cat" (like "dog", "bird", "big") because those are the tokens that occur in the same "the ___ sat/ran/chased" patterns. The top-5 won't be perfect -- the corpus is tiny -- but you should see semantically plausible guesses rather than random words.
Exercise 2: Compare bidirectional vs causal representations for fill-in-the-blank.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SmallEncoder(nn.Module):
    def __init__(self, vocab_size, d_model=128, n_heads=4,
                 n_layers=4, d_ff=512, max_len=64, use_causal=False):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model, n_heads, d_ff, batch_first=True, activation='gelu')
        self.encoder = nn.TransformerEncoder(layer, n_layers)
        self.head = nn.Linear(d_model, vocab_size)
        self.use_causal = use_causal
        self.max_len = max_len

    def forward(self, x):
        T = x.size(1)
        pos = torch.arange(T, device=x.device)
        h = self.tok_emb(x) + self.pos_emb(pos)
        if self.use_causal:
            # Each position may only attend to itself and earlier positions
            mask = nn.Transformer.generate_square_subsequent_mask(T, device=x.device)
            h = self.encoder(h, mask=mask)
        else:
            h = self.encoder(h)
        return self.head(h)


# Reuse the corpus, vocabulary, seqs_t and apply_bert_masking from Exercise 1
bidir = SmallEncoder(vocab_size, use_causal=False)
causal = SmallEncoder(vocab_size, use_causal=True)

# Train both with MLM on the same data for the same number of epochs
for model_name, model_obj in [("Bidirectional", bidir), ("Causal", causal)]:
    opt = torch.optim.AdamW(model_obj.parameters(), lr=3e-4)
    for epoch in range(50):
        for seq in seqs_t:
            masked, labels = apply_bert_masking(seq)
            if (labels == -100).all():
                continue  # nothing was masked: the loss would be NaN, so skip
            logits = model_obj(masked.unsqueeze(0))
            loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))
            opt.zero_grad()
            loss.backward()
            opt.step()

# Compare on a few test sentences, masking every position in turn
test_sents = [
    "the cat sat on the mat".split(),
    "the dog ran in the park".split(),
    "a bird flew over the fence".split(),
]
for model_name, model_obj in [("Bidir", bidir), ("Causal", causal)]:
    model_obj.eval()
    correct, total = 0, 0
    for sent in test_sents:
        for mask_pos in range(len(sent)):
            ids = [w2i.get(w, 0) for w in sent]
            ids += [w2i["[PAD]"]] * (seq_len - len(ids))
            ids_t = torch.tensor(ids)
            ids_m = ids_t.clone()
            ids_m[mask_pos] = mask_id
            with torch.no_grad():
                logits = model_obj(ids_m.unsqueeze(0))
            top5 = logits[0, mask_pos].topk(5).indices.tolist()
            if w2i[sent[mask_pos]] in top5:
                correct += 1
            total += 1
    print(f"{model_name}: {correct}/{total} correct in top-5 "
          f"({100*correct/total:.0f}%)")
```
The bidirectional model should outperform the causal model at this fill-in-the-blank task, because it can use context from both sides of the masked position. The causal model at early positions (position 0 or 1) has almost no context to work with, while the bidirectional model always uses the full sentence.
Exercise 3: Fine-tuning for classification with frozen vs unfrozen encoder.
```python
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

# Synthetic classification data
pos_words = {"good", "great", "excellent", "wonderful", "fantastic"}
neg_words = {"bad", "terrible", "awful", "horrible", "worst"}

# Vocabulary: words from Exercise 1 + sentiment words + special tokens
all_words = list(set(words)) + list(pos_words) + list(neg_words) + ["[CLS]", "[MASK]", "[PAD]"]
all_words = sorted(set(all_words))
w2i_clf = {w: i for i, w in enumerate(all_words)}
vocab_clf = len(all_words)


def make_data(n=200):
    X, y = [], []
    templates = ["the {} cat sat", "a {} day in the park",
                 "the {} dog ran", "a {} bird flew"]
    for _ in range(n):
        if random.random() < 0.5:
            word = random.choice(list(pos_words))
            label = 1
        else:
            word = random.choice(list(neg_words))
            label = 0
        tmpl = random.choice(templates)
        sent = ["[CLS]"] + tmpl.format(word).split()
        ids = [w2i_clf.get(w, 0) for w in sent]
        ids += [w2i_clf["[PAD]"]] * (16 - len(ids))
        X.append(ids[:16])
        y.append(label)
    return torch.tensor(X), torch.tensor(y)


X_train, y_train = make_data(200)
X_test, y_test = make_data(50)


class Classifier(nn.Module):
    def __init__(self, encoder, d_model=128, n_classes=2):
        super().__init__()
        self.encoder = encoder
        self.head = nn.Linear(d_model, n_classes)

    def forward(self, x):
        h = self.encoder.tok_emb(x) + self.encoder.pos_emb(
            torch.arange(x.size(1), device=x.device))
        h = self.encoder.encoder(h)
        cls = h[:, 0, :]  # classify from the [CLS] position's representation
        return self.head(cls)


# NOTE: for this demo the "pre-trained" encoders are freshly initialized
# SmallEncoder instances (from Exercise 2); a real run would load trained weights.

# Full fine-tuning (encoder trainable)
clf_full = Classifier(SmallEncoder(vocab_clf), n_classes=2)
opt = torch.optim.AdamW(clf_full.parameters(), lr=1e-3)
print("Full fine-tuning:")
for epoch in range(20):
    logits = clf_full(X_train)
    loss = F.cross_entropy(logits, y_train)
    opt.zero_grad()
    loss.backward()
    opt.step()
    with torch.no_grad():
        acc = (clf_full(X_test).argmax(1) == y_test).float().mean()
    print(f" Epoch {epoch+1}: loss={loss.item():.3f}, acc={acc:.3f}")

# Frozen encoder (only the head is trainable)
clf_frozen = Classifier(SmallEncoder(vocab_clf), n_classes=2)
for p in clf_frozen.encoder.parameters():
    p.requires_grad = False
opt2 = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, clf_frozen.parameters()), lr=1e-3)
print("\nFrozen encoder (head only):")
for epoch in range(20):
    logits = clf_frozen(X_train)
    loss = F.cross_entropy(logits, y_train)
    opt2.zero_grad()
    loss.backward()
    opt2.step()
    with torch.no_grad():
        acc = (clf_frozen(X_test).argmax(1) == y_test).float().mean()
    print(f" Epoch {epoch+1}: loss={loss.item():.3f}, acc={acc:.3f}")
```
Full fine-tuning typically reaches higher accuracy faster because the encoder adapts its representations to the sentiment classification task. The frozen variant can only train the linear head on top of whatever the encoder already produces, and since this demo skips pre-training, those are essentially random features, so the head has far less to work with. On a real pre-trained BERT the gap is smaller (the pre-trained representations are already excellent), but full fine-tuning usually still converges faster.
On to today's episode
Here we go! In episode #58 we dissected the GPT architecture -- the decoder-only transformer that turned next-token prediction into the dominant paradigm. In episode #59 we explored the other side of the coin with BERT and encoder models, the pre-training + fine-tuning paradigm, and how bidirectional attention beats causal attention for understanding tasks. We now have a solid grasp of what these models look like on the inside.
But here's the question that separates theory from practice: how do you actually train one of these things? Not a toy model on Shakespeare text -- a real model with 70 billion parameters or more, trained on trillions of tokens, across thousands of GPUs running for weeks.
The answer involves industrial-scale data engineering, distributed computing across GPU clusters the size of small data centers, and enough electricity to power a small town. This episode covers the full pipeline: what goes in, how training works at scale, and what it costs. It's the engineering story behind the architectures we've been studying, and in my opinion one of the most fascinating aspects of modern AI -- because it reveals just how much brute-force infrastructure goes into making a model that can write poetry ;-)
The data: what LLMs eat
A language model is only as good as its training data. We discussed this back in episode #14 (data preparation) and episode #35 (data ethics), but the scale here is completely different. The typical training corpus for a frontier model in 2026 draws from several sources, mixed in carefully calibrated ratios:
Common Crawl (~60% of most training mixes): raw web scrapes, petabytes of text from billions of web pages. The internet's text, unfiltered. This is the volume play -- it provides breadth across topics, languages, and writing styles.
Books (~10-15%): Project Gutenberg, digitized books, and commercial book datasets. Dense, well-written text with long-range coherence. Books teach the model sustained narrative, argument structure, and factual knowledge that web pages often lack.
Code (~5-15%): GitHub repositories, Stack Overflow, documentation. Code is highly structured and forces the model to learn precise, syntactically exact generation. Including code has been reported to improve reasoning even on non-code tasks; the exact mechanism is still debated, but a common hypothesis is that code rewards precise, step-by-step logic. That is a fascinating result if you think about it -- training on Python functions can make a model better at tasks that have nothing to do with programming.
Wikipedia (~3-5%): encyclopedic knowledge, factual writing, structured information. High quality per token but limited volume.
Scientific papers (~1-3%): ArXiv, PubMed. Technical writing, mathematical reasoning, domain expertise.
Conversations (~1-5%): Reddit threads, forums, dialogue datasets. Social language patterns, question-answer format, informal communication.
```python
# Typical data mix for a frontier LLM (approximate percentages)
sources = {
    "Common Crawl (web)": 60,
    "Books": 12,
    "Code (GitHub/SO)": 10,
    "Wikipedia": 4,
    "Scientific papers": 2,
    "Conversations/forums": 3,
    "Curated datasets": 5,
    "Other (news, legal)": 4,
}
total_tokens_T = 2.0  # 2 trillion tokens (LLaMA-2 scale)

print(f"{'Source':<25} {'Mix %':>6} {'Tokens':>17}")
print("-" * 50)
for source, pct in sources.items():
    tokens = total_tokens_T * pct / 100
    print(f"{source:<25} {pct:>5}% {tokens:>8.2f}T tokens")
print(f"\n{'Total':<25} {sum(sources.values()):>5}% {total_tokens_T:>8.2f}T tokens")

# Compare corpus sizes across model generations
models = [
    ("GPT-2 (2019)", 40, "WebText"),
    ("GPT-3 (2020)", 300, "CommonCrawl + WebText + Books + Wiki"),
    ("LLaMA-2 (2023)", 2000, "Mixed (CC + Books + Code + Wiki + ArXiv)"),
    ("LLaMA-3 (2024)", 15000, "Mixed (expanded, multilingual)"),
]
print(f"\n{'Model':<20} {'Tokens (B)':>12} {'Source':<40}")
print("-" * 74)
for name, tokens_b, src in models:
    print(f"{name:<20} {tokens_b:>11,}B {src:<40}")
```
The jump from 300 billion tokens (GPT-3) to 15 trillion tokens (LLaMA-3) in just four years tells you how fast the data requirements have scaled. And these are tokens, not words -- with BPE tokenization (episode #57), each token is roughly 3/4 of a word, so 15 trillion tokens is about 11 trillion words. For context, the entire English Wikipedia is roughly 4 billion words. LLaMA-3's training data is about 2,750 times the size of Wikipedia. Wowzers.
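The arithmetic in that paragraph is easy to reproduce. Here is a quick sketch that assumes 0.75 words per token and 4 billion words for English Wikipedia. Both are rough averages, not exact figures. Without the rounding to 11 trillion words, the ratio comes out a little above the 2,750 quoted in the paragraph.

```python
# Back-of-envelope: how big is a 15T-token corpus in words and "Wikipedias"?
tokens = 15e12            # LLaMA-3 scale training tokens
words_per_token = 0.75    # rough English average for BPE tokenizers (assumption)
wikipedia_words = 4e9     # English Wikipedia, roughly (assumption)

words = tokens * words_per_token
print(f"Approximate words:      {words / 1e12:.2f} trillion")
print(f"Multiples of Wikipedia: {words / wikipedia_words:,.0f}x")
```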
Data quality: the 80% nobody talks about
Raw Common Crawl is awful. It's full of navigation menus, cookie consent banners, SEO spam, duplicate pages, adult content, and machine-generated text. Using it directly would produce a model that generates "Click here to accept cookies. Subscribe to our newsletter!" instead of useful text.
Data cleaning is the unsexy but critical step. A typical pipeline looks something like this:
```python
def clean_document(text, min_length=200, max_line_ratio=0.3,
                   min_unique_ratio=0.2, min_ascii_ratio=0.9):
    """Quality filtering for web-crawled documents.
    Returns the text if it passes all checks, or None if rejected."""
    # 1. Length check -- short pages are usually menus/footers
    if len(text) < min_length:
        return None
    # 2. Line ratio -- too many linebreaks per character = navigation/list pages
    n_lines = text.count('\n')
    if n_lines > 0 and n_lines / len(text) > max_line_ratio:
        return None
    # 3. Repetition check -- spam/template pages repeat themselves
    words = text.split()
    if len(words) == 0:
        return None
    unique_ratio = len(set(words)) / len(words)
    if unique_ratio < min_unique_ratio:
        return None
    # 4. Boilerplate check -- several stock phrases = likely a template page
    boilerplate = [
        "cookie", "privacy policy", "terms of service",
        "subscribe to our", "click here", "all rights reserved",
    ]
    lower = text.lower()
    boilerplate_count = sum(1 for b in boilerplate if b in lower)
    if boilerplate_count >= 3:
        return None
    # 5. Language heuristic (simplified -- real pipelines use a proper
    #    language-ID model such as fasttext): for English-focused models,
    #    reject text with too little ASCII (0.9 is an illustrative threshold)
    ascii_ratio = sum(1 for c in text if ord(c) < 128) / len(text)
    if ascii_ratio < min_ascii_ratio:
        return None
    return text


# Example: batch processing
docs = [
    # Repeated 20x: fails the unique-word (repetition) check
    "The transformer architecture revolutionized NLP by enabling parallel " * 20,
    # Too short (navigation bar)
    "Click here | Home | About | Contact | Privacy | Terms | Cookie Policy",
    # Too short (spam)
    "Buy cheap laptops!!! Best deals!! Click now!! Limited offer!!!",
    # Passes all checks
    "Natural language processing has made significant advances in recent years. "
    "The introduction of attention mechanisms in 2014 allowed models to focus "
    "on relevant parts of the input sequence, leading to dramatic improvements "
    "in machine translation, text summarization, and question answering.",
]
for i, doc in enumerate(docs):
    result = clean_document(doc)
    status = "KEPT" if result else "REJECTED"
    print(f"Doc {i}: {status} (first 60 chars: '{doc[:60]}...')")
```
Deduplication
Deduplication is particularly important and deserves its own section. The web is full of duplicated content: the same news article syndicated across 50 sites, copied documentation, scraped blog posts. Training on duplicates wastes compute and can cause the model to memorize and regurgitate specific passages (which is both a privacy concern and a quality problem).
There are two levels of deduplication:
Exact deduplication: hash each document and remove exact matches. Fast but misses near-duplicates (same article with a different date header, or with one paragraph added).
Near-duplicate detection: use MinHash signatures, combined with locality-sensitive hashing (LSH), to identify documents that share most of their content even if they're not byte-for-byte identical. This is computationally expensive at scale but essential.
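Exact deduplication takes only a few lines. In this sketch, each document gets a SHA-256 hash of a lightly normalized copy, and only the first occurrence of each hash is kept. The normalization step (lowercasing and collapsing whitespace) is an illustrative assumption, not a fixed standard. Note that the near-duplicate with the added timestamp survives, which is exactly the weakness described above.

```python
import hashlib


def exact_dedup(docs):
    """Keep the first occurrence of each document, dropping exact duplicates.
    Light normalization (lowercase, collapsed whitespace) is an assumption,
    so trivially re-spaced copies are caught too."""
    seen = set()
    kept = []
    for doc in docs:
        normalized = " ".join(doc.lower().split())
        digest = hashlib.sha256(normalized.encode("utf-8")).hexdigest()
        if digest not in seen:
            seen.add(digest)
            kept.append(doc)
    return kept


docs = [
    "Breaking: markets rally after rate decision.",
    "Breaking:  markets rally after rate decision.",  # extra space: same after normalizing
    "Breaking: markets rally after rate decision. Updated 10:42",  # near-dup, NOT caught
    "An unrelated article about gardening.",
]
unique = exact_dedup(docs)
print(f"{len(docs)} docs -> {len(unique)} after exact dedup")
for d in unique:
    print("  ", d)
```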
```python
def compute_shingles(text, k=5):
    """Create k-word shingles (overlapping windows of k words)."""
    words = text.lower().split()
    shingles = set()
    for i in range(len(words) - k + 1):
        shingle = ' '.join(words[i:i+k])
        shingles.add(shingle)
    return shingles


def jaccard_similarity(set_a, set_b):
    """Jaccard similarity between two sets: |A & B| / |A | B|."""
    if not set_a or not set_b:
        return 0.0
    intersection = len(set_a & set_b)
    union = len(set_a | set_b)
    return intersection / union


# Demonstrate near-duplicate detection
docs = [
    "The cat sat on the mat and looked at the window",
    "The cat sat on the mat and looked at the door",  # near-duplicate
    "A completely different document about neural networks",
    "The cat sat on the mat and looked at the window today",  # near-dup
]
shingle_sets = [compute_shingles(d, k=3) for d in docs]

print("Jaccard similarity matrix:")
for i, s1 in enumerate(shingle_sets):
    sims = [jaccard_similarity(s1, s2) for s2 in shingle_sets]
    print(f"  Doc {i}: {['%.2f' % s for s in sims]}")

print("\nNear-duplicates (Jaccard > 0.5):")
for i in range(len(docs)):
    for j in range(i+1, len(docs)):
        sim = jaccard_similarity(shingle_sets[i], shingle_sets[j])
        if sim > 0.5:
            print(f"  Doc {i} <-> Doc {j}: {sim:.2f}")
```
At the scale of Common Crawl (hundreds of billions of documents), even MinHash needs careful engineering. The standard approach is to compute MinHash signatures (much smaller than full shingle sets), store them in LSH (Locality-Sensitive Hashing) buckets, and only compare documents that land in the same bucket. This reduces the comparison space from O(N^2) to something tractable. Having said that, deduplication at this scale is still one of the most compute-intensive parts of the data pipeline -- not the model training itself.
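Here is a small pure-Python sketch of that MinHash + LSH scheme, reusing the toy documents from the Jaccard example. It makes several illustrative choices that do not come from any production pipeline: seeded blake2b hashes stand in for random permutations, and it uses 64 hash functions split into 16 bands of 4 rows. Two documents become candidates only if at least one of their bands matches exactly. Only those candidates would then get a full comparison.

```python
import hashlib
from collections import defaultdict


def shingles(text, k=3):
    """Word k-shingles: the set of overlapping k-word windows."""
    words = text.lower().split()
    return {" ".join(words[i:i + k]) for i in range(len(words) - k + 1)}


def minhash_signature(shingle_set, num_hashes=64):
    """For each of num_hashes seeded hash functions, keep the minimum hash value
    over all shingles (assumes a non-empty set). Seeded blake2b hashes stand in
    for the random permutations of textbook MinHash."""
    signature = []
    for seed in range(num_hashes):
        signature.append(min(
            int.from_bytes(hashlib.blake2b(f"{seed}:{s}".encode(), digest_size=8).digest(), "big")
            for s in shingle_set))
    return signature


def estimated_jaccard(sig_a, sig_b):
    """Fraction of positions where two signatures agree ~ Jaccard similarity."""
    return sum(a == b for a, b in zip(sig_a, sig_b)) / len(sig_a)


def lsh_candidate_pairs(signatures, bands=16):
    """LSH banding: split each signature into bands; documents that share an
    identical band land in the same bucket and become candidate pairs."""
    rows = len(next(iter(signatures.values()))) // bands
    buckets = defaultdict(set)
    for doc_id, sig in signatures.items():
        for b in range(bands):
            buckets[(b, tuple(sig[b * rows:(b + 1) * rows]))].add(doc_id)
    pairs = set()
    for ids in buckets.values():
        ids = sorted(ids)
        for i in range(len(ids)):
            for j in range(i + 1, len(ids)):
                pairs.add((ids[i], ids[j]))
    return pairs


docs = {
    0: "The cat sat on the mat and looked at the window",
    1: "The cat sat on the mat and looked at the door",
    2: "A completely different document about neural networks",
    3: "The cat sat on the mat and looked at the window today",
}
sigs = {i: minhash_signature(shingles(d)) for i, d in docs.items()}
pairs = lsh_candidate_pairs(sigs)
# Only candidate pairs would get the (expensive) full comparison
for i, j in sorted(pairs):
    print(f"Candidate pair {i} <-> {j}: estimated Jaccard "
          f"{estimated_jaccard(sigs[i], sigs[j]):.2f}")
```

The signatures are fixed-size (64 integers) no matter how long a document is. That is what makes it feasible to store them for billions of pages.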
Data mixtures: the secret sauce
The ratio of different data sources matters enormously. Meta's original LLaMA paper, for example, lists its mix as roughly 67% CommonCrawl, 15% C4 (another filtered web crawl), 4.5% each of GitHub code, Wikipedia, and books, 2.5% ArXiv, and 2% StackExchange, and shifting ratios like these is widely reported to change downstream benchmark performance.
```python
# Impact of data mix on model behavior (conceptual)
mixes = {
    "Web-heavy (80/10/5/5)": {
        "writing_style": "Casual, blog-like, sometimes incoherent",
        "reasoning": "Moderate",
        "code": "Poor",
        "knowledge": "Broad but shallow",
    },
    "Balanced (60/15/10/5)": {
        "writing_style": "Natural, versatile",
        "reasoning": "Good",
        "code": "Good",
        "knowledge": "Broad and moderate depth",
    },
    "Code-heavy (40/10/35/5)": {
        "writing_style": "Technical, sometimes overly structured",
        "reasoning": "Excellent (code forces logic)",
        "code": "Excellent",
        "knowledge": "Narrower, tech-focused",
    },
}

print(f"{'Mix':<25} {'Style':<40} {'Reasoning':<30} {'Code':<10}")
print("-" * 108)
for mix_name, traits in mixes.items():
    print(f"{mix_name:<25} {traits['writing_style']:<40} "
          f"{traits['reasoning']:<30} {traits['code']:<10}")
```
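In practice a mix like this is usually applied at sampling time: each next document (or batch) is drawn from a source with probability proportional to its weight. Here is a minimal sketch that assumes per-document sampling with Python's `random.choices` and reuses the illustrative percentages from the data-mix table earlier. Real pipelines typically sample from pre-tokenized, shuffled shards, which this ignores.

```python
import random
from collections import Counter

# Mixture weights in percent (illustrative, from the earlier data-mix table)
mix = {"web": 60, "books": 12, "code": 10, "wikipedia": 4,
       "papers": 2, "conversations": 3, "curated": 5, "other": 4}


def sample_sources(mix, n_docs, seed=0):
    """Pick a source for each training document in proportion to its weight."""
    rng = random.Random(seed)
    names = list(mix)
    weights = [mix[name] for name in names]
    return rng.choices(names, weights=weights, k=n_docs)


n_docs = 100_000
counts = Counter(sample_sources(mix, n_docs))
for name, count in counts.most_common():
    print(f"{name:<14} target {mix[name]:>3}%  sampled {100 * count / n_docs:5.1f}%")
```

Changing the mix then means changing one dictionary. It does not mean rebuilding the dataset, which is one reason sampling-time mixing is convenient for experimentation.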
Modern approaches also use data quality classifiers: train a small model to distinguish "high quality" text (Wikipedia, published books) from "low quality" text (random web pages), then use this classifier to weight or filter the training data. Weighting in particular lets you upsample the best documents without throwing away the rest of the corpus. The Phi series of models from Microsoft took this to an extreme -- training on "textbook quality" data produced surprisingly capable small models, suggesting that data quality can partially compensate for model size.
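A quality classifier can be surprisingly simple. This sketch uses a multinomial naive Bayes model over word counts as a stand-in for whatever classifier a real pipeline uses. It is trained on a handful of made-up "high" and "low" quality examples, so the numbers it produces are not meaningful beyond this toy. The resulting probability could serve as a filter threshold or as a sampling weight.

```python
import math
from collections import Counter


def train_nb(docs, labels, alpha=1.0):
    """Multinomial naive Bayes on word counts. labels: 1 = high, 0 = low quality."""
    counts = {0: Counter(), 1: Counter()}
    for doc, y in zip(docs, labels):
        counts[y].update(doc.lower().split())
    priors = Counter(labels)
    vocab = set(counts[0]) | set(counts[1])
    return counts, priors, vocab, alpha


def quality_score(model, doc):
    """P(high quality | doc) under the naive Bayes model (Laplace smoothing)."""
    counts, priors, vocab, alpha = model
    n_total = sum(priors.values())
    log_p = {}
    for y in (0, 1):
        denom = sum(counts[y].values()) + alpha * len(vocab)
        log_p[y] = math.log(priors[y] / n_total)
        for w in doc.lower().split():
            log_p[y] += math.log((counts[y][w] + alpha) / denom)
    # Normalize the two log-probabilities into P(y=1 | doc), stably
    m = max(log_p.values())
    p_high = math.exp(log_p[1] - m)
    p_low = math.exp(log_p[0] - m)
    return p_high / (p_high + p_low)


# Tiny made-up training set (hypothetical examples)
high = ["the theory of evolution explains how species change over time",
        "photosynthesis converts light energy into chemical energy in plants",
        "the treaty was signed in 1648 and ended the thirty years war"]
low = ["click here to win a free phone now",
       "best deals buy now limited offer click",
       "subscribe now click here for free stuff"]
qmodel = train_nb(high + low, [1] * len(high) + [0] * len(low))

for doc in ["energy in plants changes over time",
            "click here for free deals now"]:
    score = quality_score(qmodel, doc)
    # A pipeline could drop docs below a threshold, or sample in proportion to score
    print(f"{score:.3f}  {doc}")
```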
Distributed training: the engineering challenge
Here's where we leave the world of model.fit() on a single GPU and enter the world of cluster-scale engineering. A 70B parameter model in float32 requires 280GB just to store the parameters. Add optimizer states (Adam stores two extra values per parameter -- the first and second moment estimates we covered in episode #41) and you need roughly 840GB. A single H100 GPU has 80GB of memory. You need at least 11 GPUs just to hold the model, before processing a single token.
Three strategies for distributing training across many GPUs:
Data parallelism
The simplest approach: copy the entire model to each GPU, split the batch across GPUs, compute gradients independently, then synchronize (average) gradients before updating weights.
```python
# Conceptual data parallelism:
# each GPU processes a different mini-batch,
# and all GPUs hold a complete copy of the model.

# PyTorch DistributedDataParallel (the standard approach):
#   model = nn.parallel.DistributedDataParallel(model)
# or with Fully Sharded Data Parallelism (shards params, grads, optimizer state):
#   model = FSDP(model)

# The training loop looks identical to single-GPU:
#   for batch in dataloader:
#       loss = model(batch)
#       loss.backward()    # DDP all-reduces gradients during backward
#       optimizer.step()

# What happens under the hood:
# 1. Each GPU gets a different slice of the batch
# 2. Each GPU computes forward + backward independently
# 3. all_reduce(gradients) -- average gradients across all GPUs
#    (DDP overlaps this with the backward pass)
# 4. optimizer.step() -- identical update on all GPUs (they stay in sync)

# Effective batch size = per_gpu_batch * n_gpus
# 8 GPUs x 32 per GPU = effective batch size of 256
n_gpus = 8
batch_per_gpu = 32
effective_batch = n_gpus * batch_per_gpu
print(f"Data parallelism with {n_gpus} GPUs:")
print(f"  Per-GPU batch: {batch_per_gpu}")
print(f"  Effective batch: {effective_batch}")
print(f"  Speedup: ~{n_gpus}x (near-linear if communication is fast)")
print(f"  Limitation: model must fit on ONE GPU (plain DDP)")
```
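Plain data parallelism (DDP) scales the batch size, but every GPU still holds a full copy of the model, so it doesn't help if the model doesn't fit on one GPU. Sharded variants like FSDP split the parameters, gradients, and optimizer states across GPUs to relax that limit, but the largest models also need model parallelism.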
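Why does averaging gradients give the same update as one big batch? With equal-sized shards and a mean-reduced loss, the average of the per-shard gradients equals the full-batch gradient. The small single-process simulation below checks this. It uses no real GPUs and no `torch.distributed`: the "all-reduce" is just a `mean` over a stacked tensor, and the four "GPUs" are loop iterations.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(16, 1)
x = torch.randn(64, 16)
y = torch.randn(64, 1)
n_gpus = 4  # simulated workers

# Reference: gradient of the mean loss over the full batch on one device
model.zero_grad()
F.mse_loss(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Data parallel: each "GPU" gets an equal slice and computes its own gradient
local_grads = []
for x_shard, y_shard in zip(x.chunk(n_gpus), y.chunk(n_gpus)):
    model.zero_grad()
    F.mse_loss(model(x_shard), y_shard).backward()
    local_grads.append(model.weight.grad.clone())

# all_reduce(mean): average the local gradients across workers
avg_grad = torch.stack(local_grads).mean(dim=0)
print(f"Max difference vs full-batch gradient: {(avg_grad - full_grad).abs().max():.2e}")
```

The difference is at floating-point noise level. This is why every replica can apply the same averaged gradient and stay in sync.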
Tensor parallelism (model parallelism)
Split individual layers across GPUs. For a linear layer with a 12,288 x 49,152 weight matrix, you can split the output dimension across 4 GPUs -- each computes a quarter of the output, then they concatenate the results.
This works within a single transformer layer. Each attention head (or group of heads) can run on a different GPU, and the feed-forward layer can be split across the column or row dimension.
```python
# Conceptual tensor parallelism for a linear layer:
# split the weight matrix across GPUs along the output dimension
import torch


def tensor_parallel_linear(x, weight_shards):
    """Simulate (column-wise) tensor parallelism for a linear layer.
    The weight is split into shards along the output dimension,
    one shard per (simulated) GPU."""
    outputs = []
    for shard in weight_shards:
        # Each GPU computes its slice of the output
        out_shard = x @ shard.T
        outputs.append(out_shard)
    # All-gather: concatenate the partial outputs from all GPUs
    return torch.cat(outputs, dim=-1)


# Example: 768-dim input, 3072-dim output, split across 4 GPUs
d_in, d_out, n_gpus = 768, 3072, 4
weight = torch.randn(d_out, d_in)
shard_size = d_out // n_gpus  # 768 output dims per GPU
shards = [weight[i*shard_size:(i+1)*shard_size, :] for i in range(n_gpus)]
x = torch.randn(2, 10, d_in)  # batch=2, seq=10
result = tensor_parallel_linear(x, shards)
print(f"Input: {x.shape}")
print(f"Each shard: {shards[0].shape}")
print(f"Output: {result.shape}")
print(f"Memory per GPU: {shards[0].numel() * 4 / 1e6:.1f}MB "
      f"(vs {weight.numel() * 4 / 1e6:.1f}MB total)")
print(f"Matches unsplit layer: {torch.allclose(result, x @ weight.T, atol=1e-3)}")
```
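The column and row splits mentioned earlier combine neatly in a transformer feed-forward block. This is the pattern popularized by Megatron-LM. The first layer is split by output columns, so each GPU can apply GELU to its own slice locally. The second layer is split by the matching input rows, so each GPU produces a partial sum of the final output, and a single all-reduce combines them. The sketch below is a single-process simulation with toy sizes chosen for illustration. The "GPUs" are loop iterations and the all-reduce is a `sum`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model, d_ff, n_gpus = 64, 256, 4
x = torch.randn(2, 10, d_model)
w1 = torch.randn(d_ff, d_model) / d_model ** 0.5  # up-projection
w2 = torch.randn(d_model, d_ff) / d_ff ** 0.5     # down-projection

# Reference: the unsplit feed-forward block
reference = F.gelu(x @ w1.T) @ w2.T

# Column-parallel first layer: each GPU owns a slice of w1's output rows,
# so GELU can be applied locally (it is elementwise, no communication needed).
# Row-parallel second layer: each GPU owns the matching slice of w2's input
# columns and produces a partial sum of the full output.
chunk = d_ff // n_gpus
partial_sums = []
for g in range(n_gpus):
    cols = slice(g * chunk, (g + 1) * chunk)
    h_local = F.gelu(x @ w1[cols].T)              # (2, 10, chunk)
    partial_sums.append(h_local @ w2[:, cols].T)  # (2, 10, d_model), partial

# all_reduce(sum): one communication step combines the partial outputs
output = torch.stack(partial_sums).sum(dim=0)
print(f"Output shape: {output.shape}")
print(f"Matches unsplit block: {torch.allclose(output, reference, atol=1e-4)}")
```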
Pipeline parallelism
Split the model by layers. GPU 0 runs layers 1-8, GPU 1 runs layers 9-16, GPU 2 runs layers 17-24, and so on. Data flows through the pipeline: GPU 0 processes the input, sends activations to GPU 1, which sends to GPU 2.
The naive approach has terrible utilization -- while GPU 2 processes batch 1, GPUs 0 and 1 sit idle. Micro-batching fixes this: split each batch into small micro-batches that flow through the pipeline in sequence, keeping all GPUs busy most of the time.
```python
# Pipeline parallelism simulation (forward passes only)
def simulate_pipeline(n_stages, n_microbatches):
    """Visualize a simple fill-and-drain pipeline schedule.
    Shows which micro-batch each stage processes at each timestep,
    assuming every stage takes one timestep per micro-batch."""
    total_time = n_stages + n_microbatches - 1
    print(f"Pipeline: {n_stages} stages, {n_microbatches} micro-batches")
    print(f"{'Time':>6}", end="")
    for s in range(n_stages):
        print(f" GPU_{s}", end="")
    print()
    active_slots = 0
    total_slots = 0
    for t in range(total_time):
        print(f"{t:>6}", end="")
        for s in range(n_stages):
            mb = t - s  # which micro-batch this stage processes
            total_slots += 1
            if 0 <= mb < n_microbatches:
                print(f" mb_{mb:<2}", end="")
                active_slots += 1
            else:
                print("  idle", end="")
        print()
    util = active_slots / total_slots
    print(f"\nUtilization: {util:.1%}")
    print(f"Bubble overhead: {1 - util:.1%}")


# 4 stages, 8 micro-batches
simulate_pipeline(n_stages=4, n_microbatches=8)
print()
# Compare: fewer micro-batches = worse utilization
simulate_pipeline(n_stages=4, n_microbatches=2)
```
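The utilization printed by the simulation follows a simple closed form. With p stages and m micro-batches, the fill-and-drain schedule takes m + p - 1 timesteps, and each stage is busy for m of them. The idle ("bubble") fraction is therefore (p - 1) / (m + p - 1). This sketch covers only the simple forward schedule simulated above, not the interleaved schedules that production frameworks use.

```python
def bubble_fraction(n_stages, n_microbatches):
    """Idle fraction of a simple fill-and-drain (GPipe-style) schedule:
    (p - 1) / (m + p - 1) for p stages and m micro-batches."""
    return (n_stages - 1) / (n_microbatches + n_stages - 1)


for m in (2, 4, 8, 16, 32, 64):
    print(f"4 stages, {m:>2} micro-batches -> bubble {bubble_fraction(4, m):6.1%}")
```

For 4 stages this gives the same bubble overheads as the simulation (3/11 for 8 micro-batches, 3/5 for 2). It also shows the bubble shrinking steadily as micro-batches are added, which is the tradeoff against activation memory described next.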
More micro-batches means higher GPU utilization, but also more inter-GPU communication and higher memory usage (you need to store intermediate activations for all in-flight micro-batches). It's an engineering tradeoff, and getting it right is a major part of what frameworks like Megatron-LM (NVIDIA) and DeepSpeed (Microsoft) do.
Real-world: all three combined
Production training runs use all three strategies simultaneously. Meta reported training the original 65B-parameter LLaMA on 2048 A100 GPUs, and runs at that scale combine several forms of parallelism at once. The communication between GPUs is as important as the computation. GPUs communicate through NVLink within a node (900 GB/s for H100) and InfiniBand between nodes (400 Gb/s per link -- gigabits, so about 50 GB/s). The collective operations (all-reduce, all-gather) must be carefully overlapped with computation to avoid idle time.
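To see why communication matters, here is a rough estimate of how long it takes just to average the gradients once. It uses the textbook cost of a ring all-reduce, in which each GPU sends and receives about 2(N-1)/N times the data. It also rests on simplifying assumptions: the quoted bandwidths are treated as fully usable, latency is ignored, and one 8-GPU data-parallel group is assumed to all-reduce the full fp16 gradients of a 70B model. In real runs, tensor and pipeline parallelism mean each group only moves a slice of the gradients.

```python
def ring_allreduce_seconds(n_bytes, n_gpus, link_gbytes_per_s):
    """Idealized ring all-reduce: each GPU sends (and receives) about
    2 * (N - 1) / N * n_bytes. Ignores latency, overlap and protocol overhead."""
    bytes_per_gpu = 2 * (n_gpus - 1) / n_gpus * n_bytes
    return bytes_per_gpu / (link_gbytes_per_s * 1e9)


grad_bytes = 70e9 * 2  # fp16 gradients of a 70B model: 140 GB
links = [("NVLink, H100 (900 GB/s)", 900),
         ("InfiniBand (400 Gb/s = 50 GB/s)", 50)]
for name, gbytes_per_s in links:
    t = ring_allreduce_seconds(grad_bytes, n_gpus=8, link_gbytes_per_s=gbytes_per_s)
    print(f"{name:<34} ~{t:.2f} s per full-gradient all-reduce")
```

Even under these optimistic assumptions, the gap between intra-node and inter-node links is close to 20x. That gap is why frameworks keep the chattiest parallelism (tensor parallelism) inside a node and overlap gradient communication with the backward pass.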
```python
# Memory budget for training a 70B model (mixed precision + Adam)
import math

params_B = 70  # billion parameters; 1B params x 1 byte = 1 GB

# Model weights: fp16 working copy for forward/backward + fp32 master copy
weights_fp16_GB = params_B * 2  # 2 bytes per param = 140 GB
weights_fp32_GB = params_B * 4  # 4 bytes per param = 280 GB
# Adam optimizer states (fp32)
adam_m_GB = params_B * 4  # first moment: 280 GB
adam_v_GB = params_B * 4  # second moment: 280 GB
# Gradients (fp16)
grads_GB = params_B * 2  # 140 GB

# 2 + 4 + 4 + 4 + 2 = 16 bytes per parameter
total_GB = (weights_fp16_GB + weights_fp32_GB + adam_m_GB
            + adam_v_GB + grads_GB)
h100_mem = 80  # GB per H100
min_gpus = math.ceil(total_GB / h100_mem)

print(f"70B parameter model memory budget:")
print(f"  FP16 working weights:   {weights_fp16_GB:>6.0f} GB")
print(f"  FP32 master weights:    {weights_fp32_GB:>6.0f} GB")
print(f"  Adam first moment:      {adam_m_GB:>6.0f} GB")
print(f"  Adam second moment:     {adam_v_GB:>6.0f} GB")
print(f"  FP16 gradients:         {grads_GB:>6.0f} GB")
print(f"  Total (no activations): {total_GB:>6.0f} GB")
print(f"  H100 memory:            {h100_mem:>6} GB")
print(f"  Minimum GPUs needed:    {min_gpus}")
print(f"\n  (Activations add another 100-300+ GB depending on batch/seq)")
print(f"  Real deployments use 2048+ GPUs for parallelism + redundancy")
```
Mixed precision training
Two practical tricks enable training at scale. The first is mixed precision: using float16 or bfloat16 instead of float32 for most of the computation. Half-precision values take half the memory, which matters most for activations. Modern GPUs also have dedicated tensor cores for fp16/bf16 operations, which give roughly double the throughput.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# PyTorch automatic mixed precision (AMP): the standard pattern
device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.bfloat16  # bf16 autocast works on both CPU and CUDA

model = nn.Linear(1024, 1024).to(device)  # any model; weights are created in fp32
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# Loss scaling is only needed for fp16. With bf16 the scaler is disabled,
# and scale()/step()/update() simply pass through.
scaler = torch.amp.GradScaler(enabled=(amp_dtype == torch.float16))

# Training step with AMP
x = torch.randn(32, 1024, device=device)
target = torch.randn(32, 1024, device=device)
optimizer.zero_grad()

# Forward pass: autocast runs matmuls in bf16 and keeps precision-sensitive
# ops (such as the loss) in fp32
with torch.amp.autocast(device_type=device, dtype=amp_dtype):
    output = model(x)
    loss = F.mse_loss(output, target)

# Backward pass (with fp16, the scaler prevents gradient underflow)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

print(f"Output dtype during autocast: {output.dtype}")
print(f"Loss dtype during autocast:   {loss.dtype}")
print(f"Model weight dtype:           {model.weight.dtype}")
print("  -> Weights stay fp32 (master copy for optimizer)")
print("  -> Matmuls in forward/backward run in bf16 (speed + memory)")
```
The trick is to keep a float32 master copy of the weights for the optimizer update, which preserves precision where it matters. Meanwhile, the forward and backward passes run in float16/bfloat16, which buys speed and memory where precision matters less. The gradient scaler guards against underflow in fp16 gradients. It multiplies the loss by a large factor before the backward pass, so tiny gradients that would otherwise round to zero in fp16 stay representable. It then divides the gradients by the same factor before the optimizer step.
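To see why the scaler matters, here is a stdlib-only sketch. It uses Python's `struct` module, whose `'e'` format code packs an IEEE 754 half-precision (fp16) value. `to_fp16` is a hypothetical helper that round-trips a float through fp16. The scale of 65536 is an example value; it happens to be the default initial scale of PyTorch's GradScaler.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest fp16 value.
    Raises OverflowError if x is too large for fp16."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

true_grad = 1e-8    # a tiny gradient, below fp16's smallest subnormal (~6e-8)
scale = 65536.0     # example loss-scale factor

unscaled = to_fp16(true_grad)          # stored directly in fp16: rounds to zero
scaled = to_fp16(true_grad * scale)    # loss (hence gradient) scaled first
recovered = scaled / scale             # unscaled in higher precision before the update

print(f"fp16 without scaling: {unscaled}")
print(f"fp16 with scaling:    {scaled}")
print(f"recovered gradient:   {recovered:.3e}")

# A scale that is too large overflows fp16 instead (struct raises
# OverflowError where GPU arithmetic would produce inf). PyTorch's
# GradScaler reacts to inf/NaN gradients by skipping that optimizer
# step and lowering the scale.
try:
    to_fp16(10.0 * scale)   # 655,360 is above fp16's maximum of 65,504
except OverflowError:
    print("overflow: skip this step and reduce the scale")
```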
bfloat16 is preferred over float16 for LLM training. It has the same number of exponent bits as float32 (8) but far fewer mantissa bits (7 vs 23). So it covers the same range of magnitudes as fp32, with only about 2-3 significant decimal digits of precision. Float16 has only 5 exponent bits: its largest value is 65,504 and its smallest normal value is about 6.1e-5. Large activations can therefore overflow, and small gradients can underflow. Underflow is what loss scaling fixes, which is why you need loss scaling with fp16 but usually not with bf16.
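You can check the range difference without a GPU. fp16 is available through `struct`'s `'e'` format. For bfloat16, the sketch below emulates it by keeping the top 16 bits of a float32 (bf16 uses exactly that bit layout), rounding to nearest-even on the dropped bits. `round_fp16` and `round_bf16` are hypothetical helpers for illustration, and NaN inputs are not handled.

```python
import struct

def round_fp16(x):
    """Round to IEEE fp16; return inf on overflow (as GPU arithmetic would)."""
    try:
        return struct.unpack('<e', struct.pack('<e', x))[0]
    except OverflowError:
        return float('inf') if x > 0 else float('-inf')

def round_bf16(x):
    """Emulate bfloat16: round a float32 to its top 16 bits
    (round half to even). Sketch only: NaN is not handled."""
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    rounding = 0x7FFF + ((bits >> 16) & 1)   # ties go to the even result
    bits = ((bits + rounding) >> 16) << 16   # drop the low 16 mantissa bits
    return struct.unpack('<f', struct.pack('<I', bits & 0xFFFFFFFF))[0]

for value in [70000.0, 1.0e-8, 1.001]:
    print(f"{value:>10g}  fp16: {round_fp16(value):<12g} "
          f"bf16: {round_bf16(value):g}")
```

fp16 keeps more mantissa bits: 1.001 survives as about 1.00098, while bf16 rounds it to 1.0. However, fp16 cannot hold 70,000 or 1e-8 at all. For training stability, that range matters more than the last few bits of precision.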
Gradient accumulation
The second practical trick: when the batch size that fits in GPU memory is too small for stable training, accumulate gradients across multiple forward-backward passes before updating:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Gradient accumulation: simulate large batches without more memory
model = nn.Linear(512, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

actual_batch = 8          # what fits in GPU memory
accumulation_steps = 16   # accumulate this many batches
effective_batch = actual_batch * accumulation_steps  # = 128

print(f"Actual batch per step: {actual_batch}")
print(f"Accumulation steps: {accumulation_steps}")
print(f"Effective batch size: {effective_batch}")

optimizer.zero_grad()
for step in range(accumulation_steps):
    x = torch.randn(actual_batch, 512)
    target = torch.randint(0, 10, (actual_batch,))
    logits = model(x)
    loss = F.cross_entropy(logits, target)
    # Scale loss by accumulation steps so the average gradient
    # matches what you'd get from a single large batch
    (loss / accumulation_steps).backward()

# Only update weights after accumulating all gradients
optimizer.step()
optimizer.zero_grad()

print(f"\nGradients accumulated from {accumulation_steps} forward passes")
print(f"Single optimizer.step() with effective batch of {effective_batch}")
```
This gives an effective batch size of actual_batch x accumulation_steps without requiring more GPU memory. The tradeoff: training takes proportionally longer per weight update because you're doing multiple forward-backward passes before each step. But for LLM training where you need effective batch sizes of 2048+ sequences, this is essential -- you simply cannot fit that many sequences on a single GPU (or even a single node of 8 GPUs).
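Dividing the loss by `accumulation_steps` is what makes the accumulated gradient equal the large-batch gradient. Here is a dependency-free check on a one-parameter linear model, y_hat = w * x, with mean squared error. This toy setup is chosen for illustration and assumes equal-sized micro-batches.

```python
import random

def mse_grad(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over one batch."""
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)

random.seed(0)
w = 0.5
xs = [random.uniform(-1, 1) for _ in range(128)]
ys = [3.0 * x + random.gauss(0, 0.1) for x in xs]

# One big batch of 128
full_grad = mse_grad(w, xs, ys)

# 16 micro-batches of 8, each gradient divided by the number of steps
actual_batch, accumulation_steps = 8, 16
accumulated = 0.0
for step in range(accumulation_steps):
    lo, hi = step * actual_batch, (step + 1) * actual_batch
    accumulated += mse_grad(w, xs[lo:hi], ys[lo:hi]) / accumulation_steps

print(f"full-batch gradient:  {full_grad:.10f}")
print(f"accumulated gradient: {accumulated:.10f}")
```

Without the division, the accumulated gradient would be 16x too large; with plain SGD, that amounts to silently multiplying the learning rate by 16. With unequal micro-batch sizes (for example a short final batch), or with losses averaged per token over variable-length sequences, the simple division is only approximately right.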
The compute budget: what training costs
Training compute for LLMs is measured in FLOPs (floating point operations). A rough estimate for training a transformer is C = 6 x N x D, where N is the number of parameters and D is the number of tokens processed. The factor 6 splits into two parts. The forward pass costs ~2N FLOPs per token. The backward pass costs ~4N FLOPs per token, since backprop is roughly twice the cost of the forward pass.
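The 2N and 4N figures come from counting multiply-adds in the matrix multiplications. A layer with a d_in x d_out weight matrix costs about d_in*d_out multiplies plus as many adds in the forward pass. The backward pass needs two matmuls of the same size: one for the gradient with respect to the layer's inputs and one for the gradient with respect to its weights.

Here is a sketch for a stack of dense layers processing one token. It ignores biases and attention's sequence-length-dependent terms, which the 6ND estimate also neglects. The layer sizes are made up, and the function name is hypothetical.

```python
def dense_stack_flops(layer_dims):
    """Per-token FLOPs for a stack of dense layers (no biases).
    layer_dims: [d0, d1, ..., dk] -> weight matrices d0xd1, d1xd2, ...
    Returns (n_params, forward_flops, backward_flops)."""
    n_params = forward = backward = 0
    for d_in, d_out in zip(layer_dims[:-1], layer_dims[1:]):
        weights = d_in * d_out
        n_params += weights
        forward += 2 * weights    # y = x @ W: one multiply + one add per weight
        backward += 2 * weights   # dL/dx = dL/dy @ W.T
        backward += 2 * weights   # dL/dW = x.T @ dL/dy
    return n_params, forward, backward

n, fwd, bwd = dense_stack_flops([4096, 16384, 4096, 16384, 4096])
print(f"params:   {n:,}")
print(f"forward:  {fwd / n:.0f} x N FLOPs per token")
print(f"backward: {bwd / n:.0f} x N FLOPs per token")
print(f"total:    {(fwd + bwd) / n:.0f} x N  -> C = 6 * N * D over D tokens")
```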
```python
def estimate_training_cost(n_params_B, n_tokens_T,
                           gpu_tflops=1000, gpu_utilization=0.5,
                           gpu_cost_per_hour=3.0):
    """Estimate training time and cost for an LLM.
    n_params_B: billions of parameters
    n_tokens_T: trillions of tokens
    gpu_tflops: peak TFLOPS of the GPU (H100 ~ 1000 bf16)
    gpu_utilization: fraction of peak achieved (typically 40-55%)
    gpu_cost_per_hour: rental price in USD per GPU-hour
    """
    # Total FLOPs = 6 * N * D
    N = n_params_B * 1e9
    D = n_tokens_T * 1e12
    total_flops = 6 * N * D
    # Effective FLOPS per GPU
    effective_tflops = gpu_tflops * gpu_utilization
    effective_flops = effective_tflops * 1e12  # convert to FLOPS
    # Single-GPU time
    single_gpu_seconds = total_flops / effective_flops
    single_gpu_days = single_gpu_seconds / 86400
    # Multi-GPU with near-linear scaling: wall time shrinks, but the
    # total GPU-hours (and therefore the cost) stay the same
    for n_gpus in [256, 1024, 2048, 4096]:
        wall_days = single_gpu_days / n_gpus
        wall_hours = wall_days * 24
        cost = n_gpus * wall_hours * gpu_cost_per_hour
        print(f"  {n_gpus:>5} GPUs: {wall_days:>6.1f} days, "
              f"${cost:>12,.0f}")
    return total_flops

print("=== 7B model on 2T tokens ===")
flops = estimate_training_cost(7, 2)
print(f"  Total FLOPs: {flops:.1e}\n")

print("=== 70B model on 2T tokens ===")
flops = estimate_training_cost(70, 2)
print(f"  Total FLOPs: {flops:.1e}\n")

print("=== 405B model on 15T tokens (LLaMA-3 scale) ===")
flops = estimate_training_cost(405, 15)
print(f"  Total FLOPs: {flops:.1e}")
```
At cloud pricing (~$3/GPU-hour for H100), training a 70B model on 2 trillion tokens costs roughly $1.5 million. LLaMA-3 405B trained on 15 trillion tokens -- that's in the $50-100 million range. Frontier closed-source models (GPT-4 class) are estimated at $100 million or more per training run. The compute budget is the primary constraint, and it's why only a handful of organizations can train frontier models.
Chinchilla scaling laws
In 2022, Hoffmann et al. at DeepMind published a refinement to the scaling laws we discussed in episode #57. The key finding: previous models like GPT-3 were undertrained -- they used too many parameters relative to the amount of training data.
The Chinchilla-optimal ratio: for a given compute budget, parameters and training tokens should be scaled up in equal proportion. If you double one, you double the other. A model with N parameters should be trained on roughly 20 x N tokens.
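Combining the two rules of thumb gives a quick way to size a compute-optimal run. With C = 6ND and D ≈ 20N, we get C ≈ 120N², so N ≈ sqrt(C / 120) and D ≈ 20N. Here is a sketch; note that the 20 tokens-per-parameter ratio is the approximation above, not an exact constant.

```python
import math

def chinchilla_optimal(compute_flops, tokens_per_param=20):
    """Approximate compute-optimal (params, tokens) for a FLOP budget,
    using C = 6*N*D and D = tokens_per_param * N."""
    n_params = math.sqrt(compute_flops / (6 * tokens_per_param))
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens

for budget in [1e21, 1e22, 5.88e23, 1e24, 1e25]:
    n, d = chinchilla_optimal(budget)
    print(f"C = {budget:.2e} FLOPs -> N ~ {n / 1e9:6.1f}B params, "
          f"D ~ {d / 1e12:6.2f}T tokens")
```

The 5.88e23 budget is 6 x 70B x 1.4T, Chinchilla's own configuration, and the function recovers 70B parameters and 1.4T tokens from it. Each 10x increase in compute raises both N and D by about sqrt(10) ≈ 3.2x.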
```python
# Chinchilla-optimal vs actual training (parameters and tokens in billions)
models = [
    # (name, params (B), training tokens (B), Chinchilla-optimal tokens (B))
    ("GPT-3",      175,   300, 175 * 20),
    ("Chinchilla",  70,  1400,  70 * 20),
    ("LLaMA-1",     65,  1400,  65 * 20),
    ("LLaMA-2",     70,  2000,  70 * 20),
    ("LLaMA-3 8B",   8, 15000,   8 * 20),
]

print(f"{'Model':<16} {'Params(B)':>10} {'Tokens(B)':>10} "
      f"{'Optimal(B)':>11} {'Status':>14}")
print("-" * 65)
for name, params, tokens, optimal in models:
    ratio = tokens / optimal
    if ratio < 0.5:
        status = "UNDERTRAINED"
    elif ratio > 2.0:
        status = "OVERTRAINED*"
    else:
        status = "~Optimal"
    print(f"{name:<16} {params:>10} {tokens:>10,} "
          f"{optimal:>11,} {status:>14}")

print("\n* 'Overtrained' by Chinchilla standards, but intentional:")
print("  smaller models trained on MORE data are cheaper to deploy.")
print("  For a model served to millions of users, lifetime inference")
print("  cost can far exceed the one-time training cost.")
```
This insight led to models like LLaMA (65B parameters, 1.4 trillion tokens) that matched or beat the much larger GPT-3 (175B parameters, 300 billion tokens). GPT-3 was massively undertrained by Chinchilla standards -- it had the capacity to absorb far more data than it was given.
But there's a post-Chinchilla twist. Labs now intentionally overtrain small models. Meta's LLaMA-3 8B, for example, was trained on about 15 trillion tokens, nearly 100x the Chinchilla-optimal ~160 billion. Why? Because training is a one-time cost, but inference (serving the model to millions of users) is ongoing. A smaller model that costs a bit more to train but runs much cheaper at inference time saves money in the long run. The optimal balance depends on your deployment scenario, not just the training compute budget.
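We can put this trade-off in the same FLOP units as C = 6ND. Using the forward-pass estimate of about 2N FLOPs per token for inference, lifetime compute is roughly 6·N·D_train + 2·N·D_served. The sketch below compares two hypothetical models. It deliberately ignores quality differences and real-world serving overheads, and the 50-trillion-token serving volume is an invented example number.

```python
def lifetime_flops(n_params, train_tokens, served_tokens):
    """Rough training + inference compute: 6*N*D for training,
    2*N per token for inference (forward pass only)."""
    train = 6 * n_params * train_tokens
    serve = 2 * n_params * served_tokens
    return train, serve

served = 50e12   # hypothetical: tokens served over the model's lifetime

candidates = [
    ("70B, Chinchilla-optimal (1.4T tokens)", 70e9, 1.4e12),
    ("8B, overtrained (15T tokens)", 8e9, 15e12),
]
for name, n, d in candidates:
    train, serve = lifetime_flops(n, d, served)
    print(f"{name:<40} train {train:.2e}  serve {serve:.2e}  "
          f"total {train + serve:.2e}")
```

In this made-up scenario, the 8B model costs slightly more to train than the 70B one (7.2e23 vs 5.88e23 FLOPs). But every served token costs almost 9x less, so the smaller model wins once serving volume is large. Whether the two would be comparable in quality is a separate question that this sketch does not answer.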
Training stability: what goes wrong
Training a large model for weeks or months on thousands of GPUs means any instability can waste millions of dollars. Here's what can go wrong and what you do about it:
Loss spikes: the loss suddenly jumps and may or may not recover. Common causes include bad batches of data (a batch full of garbage text that produces extreme gradients), numerical instability in attention computations, or learning rate issues. The standard mitigation: checkpoint frequently (every few hundred steps), detect spikes automatically, and restart from the last good checkpoint with a slightly lower learning rate.
Gradient explosions: gradients become extremely large, causing weight updates that destroy the model's learned representations. Gradient clipping (capping the gradient norm at a maximum value, typically 1.0) is standard practice. We touched on this back in episode #40 when we discussed training challenges for neural networks -- it's the same principle, just at a much larger scale.
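Clipping by global norm treats all gradients as one long vector. If its L2 norm exceeds `max_norm`, every gradient is multiplied by `max_norm / norm`, which shrinks the step without changing its direction. That matches the behaviour of PyTorch's `torch.nn.utils.clip_grad_norm_`, which also adds a small epsilon and returns the norm measured before clipping. The pure-Python version below is an illustrative sketch with a hypothetical function name.

```python
import math

def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """grads: list of lists (one list per parameter tensor).
    Returns (clipped grads, norm before clipping)."""
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    coef = min(1.0, max_norm / (total_norm + eps))   # never scale gradients up
    clipped = [[g * coef for g in tensor] for tensor in grads]
    return clipped, total_norm

grads = [[30.0, -40.0], [0.0, 120.0]]   # an "exploding" step: norm = 130
clipped, norm_before = clip_by_global_norm(grads)
norm_after = math.sqrt(sum(g * g for t in clipped for g in t))
print(f"norm before: {norm_before:.1f}, after: {norm_after:.4f}")
print(f"direction kept: {clipped[0][0] / clipped[0][1]:.3f} == {30 / -40:.3f}")
```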
Hardware failures: over a multi-week run on 2000+ GPUs, hardware failures are guaranteed. GPUs fail, network links drop, nodes crash. The training framework must handle this gracefully -- checkpoint regularly, detect failures, redistribute work across remaining healthy nodes, and continue from the last checkpoint without losing too much progress.
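Here is a simulation of the checkpoint-and-resume pattern, with several simplifying assumptions:
- The "model state" is just a step counter and a running value.
- Checkpoints are JSON files written atomically: the sketch writes to a temporary file, then calls `os.replace`, so on POSIX filesystems a crash mid-write never leaves a half-written checkpoint.
- Failures are injected at fixed steps.

`train_with_restarts` and the other helpers are hypothetical names. Real frameworks save model, optimizer, and scheduler state, among other things.

```python
import json
import os
import tempfile

def save_checkpoint(path, state):
    """Write to a temp file, then rename over the old checkpoint."""
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump(state, f)
    os.replace(tmp, path)   # atomic rename on POSIX filesystems

def load_checkpoint(path):
    if not os.path.exists(path):
        return {"step": 0, "value": 0.0}
    with open(path) as f:
        return json.load(f)

def train_with_restarts(total_steps, ckpt_every, fail_at, ckpt_dir):
    """Fake training loop that 'crashes' at the steps in fail_at and resumes
    from the last checkpoint. Returns (final state, steps recomputed)."""
    path = os.path.join(ckpt_dir, "ckpt.json")
    pending_failures = sorted(fail_at)
    redone = 0
    while True:
        state = load_checkpoint(path)        # (re)start from last checkpoint
        step = state["step"]
        try:
            while step < total_steps:
                if pending_failures and step == pending_failures[0]:
                    pending_failures.pop(0)
                    raise RuntimeError(f"node failure at step {step}")
                state["value"] += 1.0        # stand-in for one optimizer step
                step += 1
                state["step"] = step
                if step % ckpt_every == 0:
                    save_checkpoint(path, state)
            return state, redone
        except RuntimeError as err:
            resume_from = load_checkpoint(path)["step"]
            redone += step - resume_from
            print(f"{err}; resuming from step {resume_from} "
                  f"({step - resume_from} steps lost)")

with tempfile.TemporaryDirectory() as d:
    final, redone = train_with_restarts(5000, ckpt_every=500,
                                        fail_at=[1250, 3900], ckpt_dir=d)
print(f"finished at step {final['step']}, value {final['value']}, "
      f"{redone} steps recomputed")
```

The work lost per failure is always less than one checkpoint interval. That is the trade-off behind "checkpoint regularly": shorter intervals lose less progress but spend more time writing checkpoints.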
```python
import torch
import torch.nn.functional as F

# Training step with stability measures (conceptual sketch)
def stable_training_step(model, batch, optimizer, scaler,
                         max_grad_norm=1.0, loss_spike_threshold=5.0,
                         last_loss=None):
    """One training step with stability guardrails.
    Returns (loss to track, skipped)."""
    optimizer.zero_grad()
    device_type = batch['input_ids'].device.type  # 'cuda' or 'cpu'

    # Forward pass in mixed precision
    with torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):
        logits = model(batch['input_ids'])
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            batch['labels'].reshape(-1)
        )

    # Check for NaN/Inf first (a NaN loss would slip past the ratio test)
    if not torch.isfinite(loss):
        print("ERROR: NaN/Inf loss detected! Reverting to checkpoint.")
        return last_loss, True

    # Check for loss spike relative to the previous step
    if last_loss is not None:
        ratio = loss.item() / (last_loss + 1e-8)
        if ratio > loss_spike_threshold:
            print(f"WARNING: Loss spike detected! "
                  f"{last_loss:.3f} -> {loss.item():.3f} "
                  f"(ratio: {ratio:.1f}x)")
            # Skip this step, don't update weights
            return last_loss, True  # signal: spike detected

    # Backward pass
    scaler.scale(loss).backward()

    # Gradient clipping (BEFORE optimizer step; unscale first so the
    # norm is measured in real gradient units)
    scaler.unscale_(optimizer)
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), max_grad_norm)

    # Check for gradient explosion (grad_norm is the norm before clipping)
    if not torch.isfinite(grad_norm) or grad_norm > 100 * max_grad_norm:
        print(f"WARNING: Gradient explosion! norm={grad_norm.item():.1f}")
        optimizer.zero_grad()
        scaler.update()  # reset the scaler's per-step state after unscale_()
        return last_loss, True

    scaler.step(optimizer)
    scaler.update()
    return loss.item(), False

# Real training runs also:
# - Checkpoint every N steps (typically every 500-2000)
# - Log loss to W&B/TensorBoard for monitoring
# - Auto-resume from checkpoint on node failure
# - Warm up learning rate over first 1-5% of training
# - Use cosine LR schedule (episode #41)
print("Stability measures for LLM training:")
print("  1. Gradient clipping (max_norm=1.0)")
print("  2. Loss spike detection + skip")
print("  3. NaN/Inf detection + checkpoint rollback")
print("  4. Frequent checkpointing (every ~1000 steps)")
print("  5. Learning rate warmup + cosine schedule")
print("  6. Hardware failure detection + auto-resume")
```
The learning rate schedule
One more practical detail is critical at this scale: the learning rate schedule. LLM training almost always starts with a warmup phase, most commonly followed by cosine decay. We covered these schedules back in episode #41, but the stakes are much higher here.
```python
import math

def llm_lr_schedule(step, warmup_steps=2000, max_steps=100000,
                    max_lr=3e-4, min_lr=3e-5):
    """Standard LLM learning rate schedule:
    linear warmup -> cosine decay to min_lr."""
    if step < warmup_steps:
        # Linear warmup
        return max_lr * step / warmup_steps
    elif step > max_steps:
        return min_lr
    else:
        # Cosine decay
        progress = (step - warmup_steps) / (max_steps - warmup_steps)
        return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

# Print schedule at key points
print(f"{'Step':>8} {'LR':>12} {'Phase':>15}")
print("-" * 38)
checkpoints = [0, 500, 1000, 2000, 5000, 10000, 25000, 50000, 75000, 100000]
for step in checkpoints:
    lr = llm_lr_schedule(step)
    if step < 2000:
        phase = "warmup"
    elif step < 100000:
        phase = "cosine decay"
    else:
        phase = "min LR"
    print(f"{step:>8} {lr:>12.6f} {phase:>15}")
```
The warmup phase is essential. Starting with a high learning rate on randomly initialized weights causes immediate instability. Ramping up gradually over the first 1-5% of training lets the model find a reasonable region of the loss landscape before it starts taking large steps. After warmup, the cosine decay gradually reduces the learning rate, which is crucial for fine-grained convergence in the later stages of training.
What to remember from this one
- LLM training data comes from web crawls, books, code, Wikipedia, and conversations -- mixed in carefully calibrated ratios that dramatically affect model behaviour;
- Data cleaning (deduplication, quality filtering, content filtering) is as important as model architecture -- garbage in, garbage out, but at trillion-token scale;
- Distributed training spreads work across thousands of GPUs by combining data parallelism (split batches), tensor parallelism (split each layer's weight matrices), and pipeline parallelism (split the stack of layers by depth);
- Mixed precision (bf16/fp16 compute, fp32 master weights) roughly halves activation memory and can roughly double throughput on GPUs with tensor cores. bfloat16 is preferred because it has the same exponent range as fp32;
- Gradient accumulation simulates large batch sizes without extra GPU memory -- essential when effective batch sizes of 2048+ sequences are needed;
- Training a 70B model on 2T tokens costs ~$1.5M; frontier models cost $50-100M+. The C = 6ND formula lets you estimate compute requirements before spending money;
- Chinchilla scaling laws showed that parameters and training tokens should scale together -- GPT-3 was massively undertrained. Post-Chinchilla, small models are intentionally overtrained because inference cost (ongoing) can dominate training cost (one-time);
- Training stability is a major engineering challenge: loss spikes, gradient explosions, hardware failures, and numerical issues can waste weeks of compute. Frequent checkpointing, gradient clipping, and automatic recovery are mandatory.
Exercises
Exercise 1: Build a data quality classifier that distinguishes "high quality" from "low quality" text. Create a synthetic dataset. "High quality" samples are well-formed English sentences in Wikipedia style; generate these yourself with clear subject-verb-object structure, proper capitalization, and meaningful content. "Low quality" samples are random word salads, repetitive text, text with excessive punctuation, and boilerplate patterns. Build a simple feature-based classifier: a linear layer on top of hand-crafted text statistics (average word length, unique word ratio, punctuation density, uppercase ratio, line break ratio). Train it on 500 samples (250 per class) for 30 epochs. Report accuracy on a held-out test set of 100 samples. Then use the classifier to score 10 new text snippets and print the quality score for each.
Exercise 2: Implement gradient accumulation with a measurable effect. Train two copies of the same small model (a 2-layer transformer language model from episode #56, d_model=128, vocab_size=100) on the same synthetic text data. Train model A with actual batch_size=128, no accumulation. Train model B with actual batch_size=16 and accumulation_steps=8 (effective batch=128). Train both for 500 effective steps (model B does 8x more forward passes). Compare: (a) final loss (should be very similar -- same effective batch), (b) total forward passes, (c) peak memory usage (measure with torch.cuda.max_memory_allocated() if on GPU, or estimate based on batch size). Print a comparison table at the end.
Exercise 3: Simulate pipeline parallelism scheduling. Write a function pipeline_schedule(n_stages, n_microbatches) that prints a time-step diagram showing which micro-batch each pipeline stage processes at each timestep (like the visualization in this episode). Compute and report: (a) total timesteps, (b) total active GPU-timesteps, (c) total idle GPU-timesteps, (d) utilization percentage. Run it for (4 stages, 4 microbatches), (4 stages, 8 microbatches), (4 stages, 16 microbatches), and (8 stages, 16 microbatches). Print all four results in a comparison table. Verify that utilization = n_microbatches / (n_stages + n_microbatches - 1) holds for each case.