Learn AI Series (#49) - LSTM and GRU - Solving the Memory Problem
What will I learn
- You will learn how the LSTM's three gates (forget, input, output) control information flow through a sequence;
- the cell state -- a dedicated highway for long-range information;
- why gating solves the vanishing gradient problem that crippled vanilla RNNs;
- the GRU -- a simpler two-gate alternative that often performs just as well;
- bidirectional RNNs -- looking forward and backward through a sequence;
- stacked RNNs -- adding depth to sequence models;
- building a sentiment classifier with LSTM in PyTorch.
Requirements
- A working modern computer running macOS, Windows or Ubuntu;
- An installed Python 3(.11+) distribution;
- The ambition to learn AI and machine learning.
Difficulty
- Beginner
Curriculum (of the Learn AI Series):
- Learn AI Series (#1) - What Machine Learning Actually Is
- Learn AI Series (#2) - Setting Up Your AI Workbench - Python and NumPy
- Learn AI Series (#3) - Your Data Is Just Numbers - How Machines See the World
- Learn AI Series (#4) - Your First Prediction - No Math, Just Intuition
- Learn AI Series (#5) - Patterns in Data - What "Learning" Actually Looks Like
- Learn AI Series (#6) - From Intuition to Math - Why We Need Formulas
- Learn AI Series (#7) - The Training Loop - See It Work Step by Step
- Learn AI Series (#8) - The Math You Actually Need (Part 1) - Linear Algebra
- Learn AI Series (#9) - The Math You Actually Need (Part 2) - Calculus and Probability
- Learn AI Series (#10) - Your First ML Model - Linear Regression From Scratch
- Learn AI Series (#11) - Making Linear Regression Real
- Learn AI Series (#12) - Classification - Logistic Regression From Scratch
- Learn AI Series (#13) - Evaluation - How to Know If Your Model Actually Works
- Learn AI Series (#14) - Data Preparation - The 80% Nobody Talks About
- Learn AI Series (#15) - Feature Engineering and Selection
- Learn AI Series (#16) - Scikit-Learn - The Standard Library of ML
- Learn AI Series (#17) - Decision Trees - How Machines Make Decisions
- Learn AI Series (#18) - Random Forests - Wisdom of Crowds
- Learn AI Series (#19) - Gradient Boosting - The Kaggle Champion
- Learn AI Series (#20) - Support Vector Machines - Drawing the Perfect Boundary
- Learn AI Series (#21) - Mini Project - Predicting Crypto Market Regimes
- Learn AI Series (#22) - K-Means Clustering - Finding Groups
- Learn AI Series (#23) - Advanced Clustering - Beyond K-Means
- Learn AI Series (#24) - Dimensionality Reduction - PCA
- Learn AI Series (#25) - Advanced Dimensionality Reduction - t-SNE and UMAP
- Learn AI Series (#26) - Anomaly Detection - Finding What Doesn't Belong
- Learn AI Series (#27) - Recommendation Systems - "Users Like You Also Liked..."
- Learn AI Series (#28) - Time Series Fundamentals - When Order Matters
- Learn AI Series (#29) - Time Series Forecasting - Predicting What Comes Next
- Learn AI Series (#30) - Natural Language Processing - Text as Data
- Learn AI Series (#31) - Word Embeddings - Meaning in Numbers
- Learn AI Series (#32) - Bayesian Methods - Thinking in Probabilities
- Learn AI Series (#33) - Ensemble Methods Deep Dive - Stacking and Blending
- Learn AI Series (#34) - ML Engineering - From Notebook to Production
- Learn AI Series (#35) - Data Ethics and Bias in ML
- Learn AI Series (#36) - Mini Project - Complete ML Pipeline
- Learn AI Series (#37) - The Perceptron - Where It All Started
- Learn AI Series (#38) - Neural Networks From Scratch - Forward Pass
- Learn AI Series (#39) - Neural Networks From Scratch - Backpropagation
- Learn AI Series (#40) - Training Neural Networks - Practical Challenges
- Learn AI Series (#41) - Optimization Algorithms - SGD, Momentum, Adam
- Learn AI Series (#42) - PyTorch Fundamentals - Tensors and Autograd
- Learn AI Series (#43) - PyTorch Data and Training
- Learn AI Series (#44) - PyTorch nn.Module - Building Real Networks
- Learn AI Series (#45) - Convolutional Neural Networks - Theory
- Learn AI Series (#46) - CNNs in Practice - Classic to Modern Architectures
- Learn AI Series (#47) - CNN Applications - Detection, Segmentation, Style Transfer
- Learn AI Series (#48) - Recurrent Neural Networks - Sequences
- Learn AI Series (#49) - LSTM and GRU - Solving the Memory Problem (this post)
Solutions to Episode #48 Exercises
Exercise 1: Implement a complete RNN forward and backward pass from scratch in NumPy.
```python
import numpy as np

class VanillaRNN:
    def __init__(self, in_d, hid_d, out_d):
        s = 0.01
        self.Wxh = np.random.randn(in_d, hid_d) * s
        self.Whh = np.random.randn(hid_d, hid_d) * s
        self.Why = np.random.randn(hid_d, out_d) * s
        self.bh = np.zeros(hid_d)
        self.by = np.zeros(out_d)
        self.hid_d = hid_d

    def forward(self, inputs):
        self.inputs = inputs
        self.hs = [np.zeros(self.hid_d)]  # h_0
        for x_t in inputs:
            h_prev = self.hs[-1]
            h_new = np.tanh(x_t @ self.Wxh + h_prev @ self.Whh + self.bh)
            self.hs.append(h_new)
        y = self.hs[-1] @ self.Why + self.by
        return y

    def backward(self, d_output):
        dWxh = np.zeros_like(self.Wxh)
        dWhh = np.zeros_like(self.Whh)
        dbh = np.zeros_like(self.bh)
        dby = d_output.copy()
        dWhy = self.hs[-1].reshape(-1, 1) @ d_output.reshape(1, -1)
        dh = d_output @ self.Why.T
        # walk backward through time, accumulating gradients for shared weights
        for t in reversed(range(len(self.inputs))):
            dtanh = (1 - self.hs[t + 1] ** 2) * dh
            dbh += dtanh
            dWxh += self.inputs[t].reshape(-1, 1) @ dtanh.reshape(1, -1)
            dWhh += self.hs[t].reshape(-1, 1) @ dtanh.reshape(1, -1)
            dh = dtanh @ self.Whh.T
        return dWxh, dWhh, dWhy, dbh, dby

np.random.seed(42)
rnn = VanillaRNN(in_d=4, hid_d=8, out_d=3)
inputs = [np.random.randn(4) for _ in range(10)]
target = np.random.randn(3)
y = rnn.forward(inputs)
d_output = 2 * (y - target)
dWxh, dWhh, dWhy, dbh, dby = rnn.backward(d_output)
print(f"Output: {y}")
print(f"dWxh norm: {np.linalg.norm(dWxh):.6f}")
print(f"dWhh norm: {np.linalg.norm(dWhh):.6f}")
print(f"dWhy norm: {np.linalg.norm(dWhy):.6f}")
print(f"dbh norm: {np.linalg.norm(dbh):.6f}")
print(f"dby norm: {np.linalg.norm(dby):.6f}")
```
The backward pass walks through timesteps in reverse, accumulating gradients for the shared weights. The (1 - h^2) factor is the derivative of tanh, which always lies between 0 and 1. Multiplied in at every step together with W_hh (via dh = dtanh @ self.Whh.T), it is where vanishing gradients come from. All gradient norms should be non-zero, confirming that the gradient signal reaches every weight matrix.
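A finite-difference gradient check is a cheap way to confirm that a hand-written backward pass like this one is correct. The sketch below re-implements the same forward/backward logic as two standalone functions (rnn_loss and bptt_dWhh are names introduced here for illustration), assumes a squared-error loss so that dL/dy = 2(y - target) as in the exercise, and compares the analytic dWhh against central differences:

```python
import numpy as np

def rnn_loss(Wxh, Whh, bh, Why, by, inputs, target):
    """Forward pass of a vanilla RNN plus squared-error loss on the final output."""
    h = np.zeros(Whh.shape[0])
    hs = [h]
    for x_t in inputs:
        h = np.tanh(x_t @ Wxh + h @ Whh + bh)
        hs.append(h)
    y = h @ Why + by
    return np.sum((y - target) ** 2), y, hs

def bptt_dWhh(Wxh, Whh, bh, Why, by, inputs, target):
    """Backpropagation through time, returning only the gradient for Whh."""
    _, y, hs = rnn_loss(Wxh, Whh, bh, Why, by, inputs, target)
    dh = (2 * (y - target)) @ Why.T
    dWhh = np.zeros_like(Whh)
    for t in reversed(range(len(inputs))):
        dtanh = (1 - hs[t + 1] ** 2) * dh   # tanh derivative times upstream grad
        dWhh += np.outer(hs[t], dtanh)      # hs[t] is h_{t-1} for this step
        dh = dtanh @ Whh.T                  # pass gradient to the previous step
    return dWhh

rng = np.random.default_rng(0)
in_d, hid_d, out_d = 4, 8, 3
Wxh = rng.normal(0, 0.5, (in_d, hid_d))
Whh = rng.normal(0, 0.5, (hid_d, hid_d))
bh = np.zeros(hid_d)
Why = rng.normal(0, 0.5, (hid_d, out_d))
by = np.zeros(out_d)
inputs = [rng.normal(size=in_d) for _ in range(6)]
target = rng.normal(size=out_d)

analytic = bptt_dWhh(Wxh, Whh, bh, Why, by, inputs, target)

# Central differences: perturb each entry of Whh up and down by eps
eps = 1e-6
numeric = np.zeros_like(Whh)
for i in range(hid_d):
    for j in range(hid_d):
        W_plus = Whh.copy()
        W_plus[i, j] += eps
        W_minus = Whh.copy()
        W_minus[i, j] -= eps
        loss_plus = rnn_loss(Wxh, W_plus, bh, Why, by, inputs, target)[0]
        loss_minus = rnn_loss(Wxh, W_minus, bh, Why, by, inputs, target)[0]
        numeric[i, j] = (loss_plus - loss_minus) / (2 * eps)

rel_err = np.linalg.norm(analytic - numeric) / np.linalg.norm(analytic + numeric)
print(f"relative error: {rel_err:.2e}")
```

A relative error far below 1e-6 means the analytic and numerical gradients agree; a bug in the backward loop typically shows up as an error near 1.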
Exercise 2: Build a sentiment classifier using nn.RNN with synthetic threshold data.
```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(42)
n_samples = 1000
seq_len = 20
X = torch.randn(n_samples, seq_len, 1)
# label is 1 if more than 3 values in the sequence exceed 1.5
y = (X.squeeze(-1) > 1.5).sum(dim=1) > 3
y = y.long()
X_train, X_test = X[:800], X[800:]
y_train, y_test = y[:800], y[800:]
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=64, shuffle=True)

class RNNClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.RNN(1, 32, batch_first=True)
        self.fc = nn.Linear(32, 2)

    def forward(self, x):
        _, h_n = self.rnn(x)
        return self.fc(h_n.squeeze(0))

class FFBaseline(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 2))

    def forward(self, x):
        return self.fc(x.squeeze(-1))

for ModelClass, name in [(RNNClassifier, "RNN"), (FFBaseline, "Feedforward")]:
    model = ModelClass()
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(50):
        model.train()
        for xb, yb in train_loader:
            loss = nn.CrossEntropyLoss()(model(xb), yb)
            opt.zero_grad()
            loss.backward()
            opt.step()
    model.eval()
    with torch.no_grad():
        acc = (model(X_test).argmax(1) == y_test).float().mean()
    print(f"{name}: test accuracy = {acc:.1%}")
```
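The feedforward baseline can actually perform well here because the task doesn't require remembering order -- it's just counting values above a threshold. Keep in mind that the classes are heavily imbalanced: only about 4% of sequences have more than three values above 1.5, so always predicting "negative" already scores around 96%. The RNN works fine too, but sequential processing gives it no advantage on this task. Note that a fixed-length feedforward net sees each timestep as a separate input feature, so it could also solve position-based tasks like the first-element classification in exercise 3. The RNN's real edge shows up when sequence lengths vary, or when the same pattern can occur at any position in the sequence.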
Exercise 3: Sequence length experiment measuring vanishing gradient impact on memory.
```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

def test_memory_span(seq_len, n_train=2000, n_test=500, hidden_dim=64):
    torch.manual_seed(42)
    X = torch.randn(n_train + n_test, seq_len, 1)
    y = (X[:, 0, 0] > 0).long()  # label depends ONLY on first element
    X_tr, X_te = X[:n_train], X[n_train:]
    y_tr, y_te = y[:n_train], y[n_train:]
    loader = DataLoader(TensorDataset(X_tr, y_tr), batch_size=64, shuffle=True)

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.rnn = nn.RNN(1, hidden_dim, batch_first=True)
            self.fc = nn.Linear(hidden_dim, 2)

        def forward(self, x):
            _, h_n = self.rnn(x)
            return self.fc(h_n.squeeze(0))

    model = Model()
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(40):
        model.train()
        for xb, yb in loader:
            loss = nn.CrossEntropyLoss()(model(xb), yb)
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            opt.step()
    model.eval()
    with torch.no_grad():
        acc = (model(X_te).argmax(1) == y_te).float().mean().item()
    return acc

print("Vanilla RNN memory span test:")
print(f"{'Length':>8s} {'Accuracy':>10s}")
for length in [10, 25, 50, 100, 200]:
    acc = test_memory_span(length)
    print(f"{length:>8d} {acc:>10.1%}")
```
You should see accuracy drop sharply somewhere between 25 and 100 timesteps. At length 10, the RNN remembers the first element easily. At length 200, it's basically guessing (50% accuracy). This is the vanishing gradient in action -- the gradient from the loss at the final timestep can't reach the first timestep through 200 multiplications by W_hh and tanh derivatives. This is precisely why we need LSTMs.
On to today's episode
Last episode we built vanilla RNNs from scratch and ran straight into the wall: the vanishing gradient problem. After about 10-20 timesteps, gradients decay to nothing and the network can't learn long-range dependencies. The memory span experiment in exercise 3 showed this empirically -- accuracy on a simple "remember the first element" task collapses as sequence length grows.
This is a fundamental limitation, not a tuning problem. You can't fix it with better learning rates or smarter initialization. The math is against you: multiplying through the same weight matrix hundreds of times either shrinks your gradient to zero or blows it to infinity. Gradient clipping handles the explosion, but nothing can clip a gradient back into existence once it's vanished.
The solution came from Sepp Hochreiter and Jurgen Schmidhuber in 1997: the Long Short-Term Memory network, or LSTM. The core insight is almost embarrassingly simple -- if information keeps getting destroyed by repeated multiplication through weight matrices, give it a protected path that bypasses those multiplications. A highway for information that can carry it unchanged across hundreds of timesteps, unless the network explicitly decides to modify it.
Here we go!
The cell state: a highway for information
The LSTM introduces a new concept that vanilla RNNs don't have: the cell state (usually written as c_t). Think of it as a conveyor belt running above the hidden state. Information placed on this belt travels forward through time with minimal interference. The cell state is modified only through carefully controlled gates -- learned sigmoid layers that decide what to add, remove, or expose.
The vanilla RNN had one path for information: the hidden state, which gets squashed through tanh at every single timestep. The LSTM has two parallel paths: the cell state (mostly linear, protected) and the hidden state (processed, exposed to the output). The cell state is the long-term memory. The hidden state is the working memory.
This separation is what makes LSTMs work. It's the same principle as ResNet skip connections (episode #46) -- provide a gradient highway that doesn't multiply through weight matrices at every step. During backpropagation, gradients can flow through the cell state with much less decay, allowing the network to learn dependencies spanning hundreds of timesteps.
```python
import torch

# Comparing information flow: vanilla RNN vs LSTM concept
seq_len = 50
hidden_dim = 16

# Vanilla RNN: information passes through tanh at EVERY step
h = torch.ones(hidden_dim) * 0.5
W = torch.eye(hidden_dim) * 0.95  # slightly contracting
rnn_norms = []
for t in range(seq_len):
    h = torch.tanh(h @ W)
    rnn_norms.append(h.norm().item())

# LSTM cell state: information passes through multiplication by ~1
c = torch.ones(hidden_dim) * 0.5
forget_gate = torch.ones(hidden_dim) * 0.98  # almost fully open
lstm_norms = []
for t in range(seq_len):
    c = forget_gate * c  # linear operation, no tanh squashing
    lstm_norms.append(c.norm().item())

print(f"Information retention after {seq_len} steps:")
print(f"  Vanilla RNN: {rnn_norms[-1]:.6f} (started at {rnn_norms[0]:.6f})")
print(f"  LSTM cell:   {lstm_norms[-1]:.6f} (started at {lstm_norms[0]:.6f})")
print(f"  RNN retained:  {rnn_norms[-1]/rnn_norms[0]*100:.1f}%")
print(f"  LSTM retained: {lstm_norms[-1]/lstm_norms[0]*100:.1f}%")
```
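The difference is stark. The vanilla RNN's hidden state gets crushed through tanh 50 times -- even with a near-identity weight matrix, the signal decays rapidly, to below 10% of where it started. The LSTM's cell state just gets multiplied by a number close to 1, so it keeps far more: with a forget gate of 0.98, the norm ends at 0.98^49, roughly 37% of its starting value. (In a real LSTM the forget gate is recomputed at every step from the input, and it can sit even closer to 1 for information worth keeping.) This is the gradient highway in action ;-)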
The three gates
Three gates control information flow in and out of the cell state. Each gate is a sigmoid layer (output between 0 and 1) that acts as a soft switch -- 0 means "block everything," 1 means "let everything through."
Forget gate (f_t): Looks at the current input and previous hidden state, and decides what to throw away from the cell state. Reading the word "she" after the subject "John" might trigger the forget gate to erase the stored masculine gender information. The forget gate outputs a vector of values between 0 and 1, one per cell state dimension, and the cell state is multiplied element-wise by this vector. Anything multiplied by 0 is forgotten. Anything multiplied by 1 is preserved.
Input gate (i_t): Decides what new information to write to the cell state. It has two parts: a sigmoid layer that decides which dimensions to update, and a tanh layer that creates a vector of candidate values (g_t). The product of these two is added to the cell state. This is how new information enters long-term memory -- selectively, not wholesale.
Output gate (o_t): Decides what parts of the cell state to expose as the hidden state output. The cell state passes through tanh (squashing to [-1, 1]) and is multiplied by the output gate's sigmoid. Not everything the LSTM remembers needs to be relevant right now. The output gate filters the memory to produce only what's useful for the current timestep's prediction.
The full update equations:
f_t = sigmoid(W_f @ [h_{t-1}, x_t] + b_f) # forget gate
i_t = sigmoid(W_i @ [h_{t-1}, x_t] + b_i) # input gate
g_t = tanh(W_g @ [h_{t-1}, x_t] + b_g) # candidate values
o_t = sigmoid(W_o @ [h_{t-1}, x_t] + b_o) # output gate
c_t = f_t * c_{t-1} + i_t * g_t # cell state update
h_t = o_t * tanh(c_t) # hidden state output
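To make the equations concrete, here is one step for a single cell dimension. The pre-activation values are arbitrary numbers chosen only for illustration: a forget gate near 1 keeps most of c_{t-1}, and the input gate lets in only a fraction of the candidate.

```python
import math

def sigmoid(x):
    return 1 / (1 + math.exp(-x))

# Hypothetical pre-activations for a 1-dimensional cell (not learned values)
c_prev = 2.0
f = sigmoid(3.0)    # ~0.953: keep most of the old memory
i = sigmoid(-1.0)   # ~0.269: write only a little new information
g = math.tanh(0.5)  # ~0.462: candidate value
o = sigmoid(0.0)    # 0.5: expose half of the squashed memory

c = f * c_prev + i * g      # cell state update
h = o * math.tanh(c)        # hidden state output
print(f"f={f:.3f} i={i:.3f} g={g:.3f} o={o:.3f}")
print(f"c_t = {f:.3f}*{c_prev} + {i:.3f}*{g:.3f} = {c:.3f}")
print(f"h_t = {o:.3f}*tanh({c:.3f}) = {h:.3f}")
```

The cell state barely moves (from 2.0 to about 2.03) because the forget gate preserves the old value and the small write adds only a little on top.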
```python
import torch
import torch.nn as nn

class LSTMCell(nn.Module):
    def __init__(self, in_d, hid_d):
        super().__init__()
        # one Linear produces all four gate pre-activations at once
        self.gates = nn.Linear(in_d + hid_d, 4 * hid_d)
        self.hid_d = hid_d

    def forward(self, x_t, h_prev, c_prev):
        # concatenation order ([x_t, h_prev] here, [h_{t-1}, x_t] in the
        # equations) is arbitrary, as long as it is used consistently
        combined = torch.cat([x_t, h_prev], dim=-1)
        gates = self.gates(combined)
        f, i, g, o = gates.chunk(4, dim=-1)
        f = torch.sigmoid(f)  # forget gate
        i = torch.sigmoid(i)  # input gate
        g = torch.tanh(g)     # candidate values
        o = torch.sigmoid(o)  # output gate
        c = f * c_prev + i * g
        h = o * torch.tanh(c)
        return h, c

# Test our LSTM cell
cell = LSTMCell(in_d=10, hid_d=32)
h = torch.zeros(1, 32)
c = torch.zeros(1, 32)
print("Step-by-step LSTM processing:")
for t in range(5):
    x_t = torch.randn(1, 10)
    h, c = cell(x_t, h, c)
    print(f"  t={t}: h norm={h.norm():.4f}, c norm={c.norm():.4f}")
```
Notice the implementation trick: all four gate computations (forget, input, candidate, output) use the same concatenated input -- so we compute them as a single matrix multiplication and split the result. This is a standard optimization you'll see in every LSTM implementation. One big matmul is faster than four separate ones on GPU hardware.
The cell state update c = f * c_prev + i * g is the critical line. The forget gate f controls how much of the old cell state survives. The input gate i controls how much of the new candidate g enters. When f is close to 1 and i is close to 0, the cell state passes through almost unchanged -- the gradient highway is open, and information from hundreds of timesteps ago reaches the current step with minimal decay.
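Because the gate order inside the fused matrix is only a convention, one way to sanity-check the LSTMCell above is to copy its weights into PyTorch's built-in nn.LSTMCell and compare outputs. This sketch assumes PyTorch's documented gate layout (i, f, g, o) for weight_ih and weight_hh. Our cell orders its chunks (f, i, g, o), so the rows must be reordered, and since our cell has a single bias vector, bias_hh is zeroed:

```python
import torch
import torch.nn as nn

class LSTMCell(nn.Module):
    # Same fused-gate cell as above: one Linear, chunks ordered f, i, g, o
    def __init__(self, in_d, hid_d):
        super().__init__()
        self.gates = nn.Linear(in_d + hid_d, 4 * hid_d)
        self.hid_d = hid_d

    def forward(self, x_t, h_prev, c_prev):
        gates = self.gates(torch.cat([x_t, h_prev], dim=-1))
        f, i, g, o = gates.chunk(4, dim=-1)
        c = torch.sigmoid(f) * c_prev + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(c)
        return h, c

torch.manual_seed(0)
in_d, hid_d = 10, 32
ours = LSTMCell(in_d, hid_d)
ref = nn.LSTMCell(in_d, hid_d)

# Reorder our (f, i, g, o) row blocks into PyTorch's (i, f, g, o) layout
W, b = ours.gates.weight.detach(), ours.gates.bias.detach()
Wf, Wi, Wg, Wo = W.chunk(4, dim=0)
bf, bi, bg, bo = b.chunk(4)
W_ref = torch.cat([Wi, Wf, Wg, Wo], dim=0)  # shape (4*hid_d, in_d + hid_d)
with torch.no_grad():
    ref.weight_ih.copy_(W_ref[:, :in_d])     # columns that multiply x_t
    ref.weight_hh.copy_(W_ref[:, in_d:])     # columns that multiply h_prev
    ref.bias_ih.copy_(torch.cat([bi, bf, bg, bo]))
    ref.bias_hh.zero_()                      # ours has only one bias vector

x = torch.randn(5, 3, in_d)                 # 5 timesteps, batch of 3
h1, c1 = torch.zeros(3, hid_d), torch.zeros(3, hid_d)
h2, c2 = torch.zeros(3, hid_d), torch.zeros(3, hid_d)
with torch.no_grad():
    for x_t in x:
        h1, c1 = ours(x_t, h1, c1)
        h2, c2 = ref(x_t, (h2, c2))
max_diff = (h1 - h2).abs().max().item()
print(f"max |h_ours - h_ref| = {max_diff:.2e}")
```

A difference at the level of float32 rounding means the hand-written cell computes exactly the standard LSTM update.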
Why gates solve vanishing gradients
In a vanilla RNN, the gradient from timestep t to timestep k flows through (t - k) multiplications by W_hh and tanh derivatives. Each multiplication typically shrinks the gradient, and after 50+ steps it's negligible.
In an LSTM, the gradient can flow through the cell state path. The cell state update is c_t = f_t * c_{t-1} + i_t * g_t. Along this direct path (treating the gate values as given), the gradient of c_t with respect to c_{t-1} is simply f_t, applied element-wise. The gates themselves also depend on h_{t-1}, which adds further terms, but those don't remove this direct route. If the forget gate is close to 1 (which it often is for information the network wants to remember), the gradient passes through with almost no decay. No weight matrix multiplication, no tanh squashing -- just multiplication by a number close to 1.
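The claim that the cell-state gradient is a product of forget-gate values can be checked directly with autograd. In this sketch the gate values are fixed, made-up numbers rather than outputs of learned layers, which isolates the direct cell-state path. It runs 50 steps of the update and compares dc_T/dc_0 with the product of the f_t:

```python
import torch

torch.manual_seed(0)
T, hid_d = 50, 4
c0 = torch.randn(hid_d, requires_grad=True)

# Hypothetical per-step gate values, held fixed (not computed from inputs)
forget = torch.full((T, hid_d), 0.98)
forget[:, 0] = 0.5                           # dimension 0 "chooses" to forget
inp = torch.rand(T, hid_d)                   # input gate values in [0, 1)
cand = torch.tanh(torch.randn(T, hid_d))     # candidate values in (-1, 1)

c = c0
for t in range(T):
    c = forget[t] * c + inp[t] * cand[t]     # c_t = f_t * c_{t-1} + i_t * g_t
c.sum().backward()

expected = forget.prod(dim=0)                # product of f_t over all steps
print("dc_T/dc_0 per dimension:", c0.grad)
print("product of forget gates: ", expected)
```

Dimension 0, with f = 0.5, ends up with a gradient of about 1e-15, while the dimensions held at 0.98 keep about 0.98^50, roughly 0.36. The gradient dies exactly where the network forgets.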
```python
# Empirical demonstration: gradient flow through LSTM vs vanilla RNN
import torch
import torch.nn as nn

def measure_gradient_flow(model_type, seq_len, hidden_dim=32):
    if model_type == "rnn":
        model = nn.RNN(1, hidden_dim, batch_first=True)
    else:
        model = nn.LSTM(1, hidden_dim, batch_first=True)
    x = torch.randn(1, seq_len, 1, requires_grad=True)
    output, _ = model(x)
    # Loss on final timestep
    loss = output[0, -1, :].sum()
    loss.backward()
    # Gradient at the input tells us how much the first timestep
    # influences the final output
    grad_first = x.grad[0, 0, :].abs().mean().item()
    grad_last = x.grad[0, -1, :].abs().mean().item()
    return grad_first, grad_last

print(f"{'Model':>6s} {'Length':>8s} {'Grad(t=0)':>12s} {'Grad(t=T)':>12s} {'Ratio':>8s}")
for length in [10, 50, 100, 200]:
    for model_type in ["rnn", "lstm"]:
        g_first, g_last = measure_gradient_flow(model_type, length)
        ratio = g_first / g_last if g_last > 0 else 0
        print(f"{model_type:>6s} {length:>8d} {g_first:>12.6f} {g_last:>12.6f} {ratio:>8.4f}")
    print()
```
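This is the key insight: the forget gate gives the network a learned shortcut for gradient flow. It's not that the gradient never decays -- it decays when the network decides to forget (f close to 0). But for information that matters, the network learns to keep the forget gate open, creating a gradient path with minimal loss across the entire sequence. Note that the measurement above uses untrained networks: with PyTorch's default initialization the gate pre-activations start near zero, so f starts around 0.5 and even the untrained LSTM's gradient shrinks over long sequences. What gating adds is the ability for training to push f toward 1 where it matters; a vanilla RNN has no comparable per-dimension control.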
In practice, LSTM gradients are stable enough to learn dependencies spanning 100-200+ timesteps. Not infinite -- there's still some decay, and the gates themselves introduce some gradient loss -- but orders of magnitude better than vanilla RNNs.
GRU: a simpler alternative
In 2014, Kyunghyun Cho and colleagues proposed the Gated Recurrent Unit (GRU), which achieves similar performance to the LSTM with a simpler architecture. The GRU has two gates instead of three, and no separate cell state -- it merges the cell state and hidden state into a single vector.
Reset gate (r_t): Controls how much of the previous hidden state to ignore when computing the candidate. A reset gate of 0 means "pretend you've never seen anything before" -- useful when the sequence has an abrupt topic change. A reset gate of 1 means "use all previous context."
Update gate (z_t): Combines the forget and input gates into a single decision. It's a value between 0 and 1 that controls the interpolation between the previous hidden state and the new candidate: h_t = z * h_{t-1} + (1 - z) * candidate. When z is close to 1, the hidden state is preserved (like an LSTM forget gate of 1 with input gate of 0). When z is close to 0, the hidden state is replaced with the new candidate.
The elegance is in that complementary structure: whatever fraction of the old state you keep (z), you fill the rest (1 - z) with new information. The LSTM decides independently how much to forget and how much to add -- the GRU ties these two decisions together. Fewer parameters, fewer decisions, and in practice the constraint rarely hurts.
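A three-dimensional toy example makes the complementary structure visible: each dimension of z picks its own mix of old state and new candidate. The values are arbitrary, chosen to show the two extremes and a 50/50 blend.

```python
import torch

h_prev = torch.tensor([1.0, -0.5, 0.8])
candidate = torch.tensor([0.2, 0.9, -0.3])
z = torch.tensor([1.0, 0.0, 0.5])  # keep old, take new, blend evenly

# Whatever fraction z keeps from h_prev, (1 - z) is filled from the candidate
h_t = z * h_prev + (1 - z) * candidate
print(h_t)  # approximately [1.0, 0.9, 0.25]
```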
```python
import torch
import torch.nn as nn

class GRUCell(nn.Module):
    def __init__(self, in_d, hid_d):
        super().__init__()
        self.rz = nn.Linear(in_d + hid_d, 2 * hid_d)  # reset and update gates, fused
        self.n = nn.Linear(in_d + hid_d, hid_d)        # candidate state
        self.hid_d = hid_d

    def forward(self, x_t, h_prev):
        rz = torch.sigmoid(self.rz(torch.cat([x_t, h_prev], -1)))
        r, z = rz.chunk(2, dim=-1)
        # reset gate scales h_prev before it enters the candidate
        # (PyTorch's nn.GRU applies r after the hidden-state matmul, a slight variant)
        candidate = torch.tanh(self.n(torch.cat([x_t, r * h_prev], -1)))
        # z near 1 keeps the old state, z near 0 replaces it with the candidate
        return z * h_prev + (1 - z) * candidate

# Compare parameter counts for these custom cells (one bias vector per gate)
in_d, hid_d = 64, 128
lstm_params = 4 * (in_d + hid_d) * hid_d + 4 * hid_d  # 4 gates
gru_params = 3 * (in_d + hid_d) * hid_d + 3 * hid_d   # 3 gates (rz = 2, n = 1)
print(f"LSTM parameters: {lstm_params:,}")
print(f"GRU parameters:  {gru_params:,}")
print(f"GRU has {(1 - gru_params/lstm_params)*100:.0f}% fewer parameters")
```
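PyTorch's built-in layers count slightly differently from the hand-rolled cells above: per the nn.LSTM and nn.GRU documentation, each layer stores two bias vectors (bias_ih and bias_hh), so every gate carries 2 * hid_d biases instead of hid_d. The 3:4 ratio, and therefore the 25% saving, is unchanged. A quick check:

```python
import torch.nn as nn

in_d, hid_d = 64, 128
lstm = nn.LSTM(in_d, hid_d)
gru = nn.GRU(in_d, hid_d)

def count(module):
    return sum(p.numel() for p in module.parameters())

# weight_ih + weight_hh + bias_ih + bias_hh, for 4 (LSTM) or 3 (GRU) gates
expected_lstm = 4 * (in_d * hid_d + hid_d * hid_d + 2 * hid_d)
expected_gru = 3 * (in_d * hid_d + hid_d * hid_d + 2 * hid_d)
print(f"nn.LSTM: {count(lstm):,} (formula: {expected_lstm:,})")
print(f"nn.GRU:  {count(gru):,} (formula: {expected_gru:,})")
print(f"GRU has {(1 - count(gru) / count(lstm)) * 100:.0f}% fewer parameters")
```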
LSTM vs GRU is one of those debates that has been studied exhaustively with no clear winner. On most benchmarks, they perform within a percentage point of each other. GRUs train slightly faster (fewer parameters) and work well on smaller datasets. LSTMs have a slight edge on tasks requiring very fine-grained control over what to remember and forget -- like tasks where you need to remember one specific fact from the start of a long sequence while overwriting everything else. In practice, try both and pick whichever works better for your specific task. The difference is rarely significant enough to agonize over.
Bidirectional RNNs
A standard RNN reads the sequence left to right. But for many tasks, the future context matters just as much as the past. In the sentence "The bank of the river was ___," the word "river" (which comes after "bank") tells you this is about a waterway, not a financial institution. A left-to-right RNN can't use that context when processing "bank."
A bidirectional RNN runs two separate RNNs: one forward (left to right) and one backward (right to left). At each timestep, the outputs of both directions are concatenated. This gives each position access to both past and future context.
```python
import torch
import torch.nn as nn

# Bidirectional LSTM in PyTorch
lstm_uni = nn.LSTM(input_size=64, hidden_size=128, num_layers=2,
                   batch_first=True, bidirectional=False)
lstm_bi = nn.LSTM(input_size=64, hidden_size=128, num_layers=2,
                  batch_first=True, bidirectional=True)
x = torch.randn(8, 50, 64)  # batch=8, seq_len=50, features=64
out_uni, (h_uni, c_uni) = lstm_uni(x)
out_bi, (h_bi, c_bi) = lstm_bi(x)
print("Unidirectional LSTM:")
print(f"  Output: {out_uni.shape}")  # (8, 50, 128)
print(f"  Hidden: {h_uni.shape}")    # (2, 8, 128) -- 2 layers
print(f"  Params: {sum(p.numel() for p in lstm_uni.parameters()):,}")
print("\nBidirectional LSTM:")
print(f"  Output: {out_bi.shape}")   # (8, 50, 256) -- forward+backward concat
print(f"  Hidden: {h_bi.shape}")     # (4, 8, 128) -- 2 layers x 2 directions
print(f"  Params: {sum(p.numel() for p in lstm_bi.parameters()):,}")
```
In PyTorch, bidirectional is a single flag: bidirectional=True. The output hidden dimension doubles (forward and backward concatenated). The cost is that you need the full sequence upfront -- you can't use bidirectional RNNs for real-time streaming or autoregressive generation, since the backward pass requires the future to already be known.
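One detail that trips people up is where each direction's "final" state lives in the output tensor. In PyTorch, out[:, t, :hidden] holds the forward state at step t and out[:, t, hidden:] the backward state at step t. Since the backward RNN finishes at position 0, its final state sits at out[:, 0, hidden:], not at the last position. A small check with arbitrary sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hid = 16
bilstm = nn.LSTM(input_size=8, hidden_size=hid, batch_first=True, bidirectional=True)
x = torch.randn(2, 7, 8)  # batch=2, seq_len=7, features=8
out, (h_n, c_n) = bilstm(x)

# Forward half at the LAST position equals the final forward hidden state
fwd_match = torch.allclose(out[:, -1, :hid], h_n[0])
# Backward half at the FIRST position equals the final backward hidden state
bwd_match = torch.allclose(out[:, 0, hid:], h_n[1])
print(fwd_match, bwd_match)  # both True
```

This is why taking out[:, -1, :] from a bidirectional layer is usually a mistake for classification: its backward half has seen only the last token.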
Bidirectional LSTMs are the backbone of many NLP systems. Named entity recognition ("is this word a person, place, or organization?") benefits enormously from seeing context on both sides. Part-of-speech tagging similarly needs forward and backward context. For tasks where you process a complete input and produce labels -- rather than generating one token at a time -- bidirectional is almost always better than unidirectional.
Stacked RNNs
Just as we can stack convolutional layers to build hierarchical feature extractors (episodes #45-46), we can stack RNN layers. The output sequence of one RNN layer becomes the input sequence of the next. Each layer captures increasingly abstract patterns: the first layer might learn character-level or word-level patterns, the second layer might learn phrase-level patterns, and so on.
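The wiring is literal: a num_layers=2 LSTM computes the same thing as two single-layer LSTMs chained together. This sketch checks that by copying PyTorch's per-layer parameters (weight_ih_l0, weight_ih_l1, and so on) into two separate modules. Dropout is left at its default of 0 so the comparison is exact:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_d, hid_d = 8, 16
stacked = nn.LSTM(in_d, hid_d, num_layers=2, batch_first=True)
layer1 = nn.LSTM(in_d, hid_d, batch_first=True)
layer2 = nn.LSTM(hid_d, hid_d, batch_first=True)

# Copy layer-k parameters of the stacked module into the k-th single-layer LSTM
with torch.no_grad():
    for k, single in enumerate([layer1, layer2]):
        for name in ["weight_ih", "weight_hh", "bias_ih", "bias_hh"]:
            getattr(single, f"{name}_l0").copy_(getattr(stacked, f"{name}_l{k}"))

x = torch.randn(3, 12, in_d)
with torch.no_grad():
    out_stacked, _ = stacked(x)
    out1, _ = layer1(x)     # layer 1 reads the raw input sequence
    out2, _ = layer2(out1)  # layer 2 reads layer 1's output sequence
same = torch.allclose(out_stacked, out2, atol=1e-6)
print(same)
```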
```python
import torch
import torch.nn as nn

# Stacking layers and using dropout between them
lstm_deep = nn.LSTM(input_size=64, hidden_size=128, num_layers=3,
                    batch_first=True, dropout=0.3)
x = torch.randn(4, 100, 64)  # batch=4, seq_len=100, features=64
lstm_deep.train()  # dropout only active during training
out, (h_n, c_n) = lstm_deep(x)
print("3-layer LSTM:")
print(f"  Output: {out.shape}")  # (4, 100, 128) -- last layer only
print(f"  Hidden: {h_n.shape}")  # (3, 4, 128) -- one per layer
print(f"  Cell:   {c_n.shape}")  # (3, 4, 128) -- one per layer
print(f"  Params: {sum(p.numel() for p in lstm_deep.parameters()):,}")
# Access hidden states from each layer
for layer in range(3):
    print(f"  Layer {layer} final hidden norm: {h_n[layer].norm():.4f}")
```
Two to three layers is typical. Beyond that, training becomes difficult (even with LSTMs) and the returns diminish. If you need more capacity, increase the hidden size rather than the depth. The most famous deep stacked LSTM was Google's Neural Machine Translation system (2016) with 8 layers -- and even they needed residual connections between layers to train it successfully. The dropout=0.3 argument applies dropout between layers (but NOT within the recurrent computation), which is the standard regularization technique for stacked RNNs.
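For stacks deeper than two or three layers, residual connections between layers help. A minimal sketch of that idea, assuming single-layer nn.LSTM modules chained with skip connections. ResidualLSTMStack is a hypothetical class written for illustration, loosely in the spirit of GNMT rather than its actual design, and the input projection is an assumption added so the residual sum has matching widths:

```python
import torch
import torch.nn as nn

class ResidualLSTMStack(nn.Module):
    """Hypothetical sketch: single-layer LSTMs with skip connections between layers."""

    def __init__(self, in_d, hid_d, n_layers=4, dropout=0.3):
        super().__init__()
        self.proj = nn.Linear(in_d, hid_d)  # match widths so x + lstm(x) is defined
        self.layers = nn.ModuleList(
            [nn.LSTM(hid_d, hid_d, batch_first=True) for _ in range(n_layers)]
        )
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.proj(x)
        for lstm in self.layers:
            out, _ = lstm(x)
            x = x + self.drop(out)  # residual: each layer learns a correction
        return x

model = ResidualLSTMStack(in_d=64, hid_d=128, n_layers=4)
out = model(torch.randn(4, 100, 64))
print(out.shape)  # torch.Size([4, 100, 128])
```

The skip path gives gradients a route around each recurrent layer, the same trick the cell state uses across time, applied here across depth.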
Practical: sentiment classification with LSTM
Let's build a complete sentiment classifier. This is the workhorse architecture that dominated text classification from roughly 2015 to 2019, before transformers took over. We'll use PyTorch's built-in LSTM with bidirectional processing and dropout:
```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

class SentimentLSTM(nn.Module):
    def __init__(self, vocab_sz, emb_d=64, hid_d=128, n_layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_sz, emb_d, padding_idx=0)
        self.lstm = nn.LSTM(emb_d, hid_d, n_layers,
                            batch_first=True, bidirectional=True, dropout=0.3)
        self.fc = nn.Linear(hid_d * 2, 1)  # *2 for bidirectional
        self.drop = nn.Dropout(0.5)

    def forward(self, x):
        emb = self.drop(self.embed(x))
        out, (h_n, _) = self.lstm(emb)
        # Concat last forward and backward hidden states
        hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)
        return self.fc(self.drop(hidden)).squeeze(1)

# Synthetic demo -- in practice you'd use real text data
torch.manual_seed(42)
vocab_sz = 5000
n_samples = 1000
seq_len = 100
X = torch.randint(1, vocab_sz, (n_samples, seq_len))  # 0 reserved for padding
y = torch.randint(0, 2, (n_samples,)).float()
X_train, X_test = X[:800], X[800:]
y_train, y_test = y[:800], y[800:]
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=32, shuffle=True)

model = SentimentLSTM(vocab_sz)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(10):
    model.train()
    for xb, yb in train_loader:
        logits = model(xb)
        loss = nn.BCEWithLogitsLoss()(logits, yb)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
    if epoch % 3 == 0:
        model.eval()
        with torch.no_grad():
            preds = (model(X_test) > 0).float()
            acc = (preds == y_test).float().mean()
        print(f"Epoch {epoch}: test acc = {acc:.1%}")

print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}")
print("Architecture: Embedding -> BiLSTM(2 layers) -> Dropout -> Linear")
```
A few design choices worth noting. We apply dropout to the embedding output, a common regularization step. Be aware that nn.Dropout zeroes individual embedding dimensions independently; dropping entire word representations at once requires a mask shared across each word vector (often called word dropout). The padding_idx=0 argument tells the embedding layer to initialize the vector for index 0 to all zeros and to exclude it from gradient updates, which is useful when you pad sequences to the same length (the padding tokens shouldn't contribute any meaning).
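Both points can be checked in a few lines. The first part confirms that the padding_idx row starts at zero and receives no gradient. The second contrasts nn.Dropout, which zeroes individual entries, with a word-level variant; word_dropout is a hypothetical helper written for this illustration, not a PyTorch function:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed = nn.Embedding(10, 4, padding_idx=0)
tokens = torch.tensor([[5, 3, 0, 0]])  # two real tokens, two pads
embed(tokens).sum().backward()
print(embed.weight[0])       # all zeros
print(embed.weight.grad[0])  # all zeros: the pad row gets no update

def word_dropout(emb, p=0.3, training=True):
    """Hypothetical helper: zero whole token vectors instead of single entries."""
    if not training or p == 0:
        return emb
    # one keep/drop decision per (batch, position), broadcast over emb_d
    keep = torch.rand(emb.shape[0], emb.shape[1], 1, device=emb.device) > p
    return emb * keep / (1 - p)  # rescale like inverted dropout

emb = torch.ones(2, 6, 4)            # batch=2, seq_len=6, emb_d=4
elem = nn.Dropout(0.3)(emb)          # zeros scattered inside word vectors
word = word_dropout(emb, p=0.3)      # each token vector all-zero or all-kept
elem_frac = (elem == 0).float().mean(dim=-1)
word_frac = (word == 0).float().mean(dim=-1)
print(elem_frac)  # per-token fractions can be anything in {0, 0.25, 0.5, ...}
print(word_frac)  # only 0.0 or 1.0
```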
For the final classification, we concatenate the last hidden state from the forward direction (h_n[-2]) with the last hidden state from the backward direction (h_n[-1]). The forward state has seen the entire sequence left to right; the backward state has seen it right to left. Together they give the classifier full bidirectional context. We apply dropout again before the final linear layer -- this is the single most effective regularization for recurrent models in my experience.
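One caveat applies once real data is padded: with right-padding, the forward direction's final state h_n[-2] is taken after the pad tokens, and the backward direction starts on them. PyTorch's pack_padded_sequence lets the LSTM skip the padding. This sketch uses made-up token ids and checks that the packed run reproduces the final states of the unpadded sequence:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
embed = nn.Embedding(20, 8, padding_idx=0)
lstm = nn.LSTM(8, 16, batch_first=True, bidirectional=True)

# Two sequences right-padded with 0 to length 6 (true lengths 6 and 3)
x = torch.tensor([[4, 7, 2, 9, 5, 3],
                  [8, 1, 6, 0, 0, 0]])
lengths = torch.tensor([6, 3])  # must live on the CPU

with torch.no_grad():
    emb = embed(x)
    packed = pack_padded_sequence(emb, lengths, batch_first=True, enforce_sorted=False)
    _, (h_packed, _) = lstm(packed)          # h_n returned in original batch order
    _, (h_padded, _) = lstm(emb)             # naive run also processes the pads
    _, (h_short, _) = lstm(embed(x[1:, :3])) # sequence 2 without any padding

# h has shape (2 directions, batch, 16); compare sequence 2 in each run
match_packed = torch.allclose(h_packed[:, 1], h_short[:, 0], atol=1e-6)
match_padded = torch.allclose(h_padded[:, 1], h_short[:, 0], atol=1e-6)
print(match_packed, match_padded)  # packing matches, the naive run does not
```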
On random data this won't converge to anything meaningful (there's no pattern to learn). On real sentiment data -- movie reviews, product ratings, tweets -- this architecture routinely achieves 85-90% accuracy with a few thousand labeled examples. For small to medium datasets or when you need a lightweight model, it remains a solid choice even now ;-)
LSTM vs GRU: head-to-head comparison
Let's put them side by side on the same task -- the "remember the first element" problem from episode #48's exercise 3, but now with long sequences where vanilla RNNs fail:
```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

def compare_architectures(seq_len=100, hidden_dim=64, epochs=30):
    torch.manual_seed(42)
    n_train, n_test = 2000, 500
    X = torch.randn(n_train + n_test, seq_len, 1)
    y = (X[:, 0, 0] > 0).long()  # label depends on FIRST element only
    X_tr, X_te = X[:n_train], X[n_train:]
    y_tr, y_te = y[:n_train], y[n_train:]
    loader = DataLoader(TensorDataset(X_tr, y_tr), batch_size=64, shuffle=True)
    results = {}
    for name, rnn_class in [("RNN", nn.RNN), ("LSTM", nn.LSTM), ("GRU", nn.GRU)]:
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.rnn = rnn_class(1, hidden_dim, batch_first=True)
                self.fc = nn.Linear(hidden_dim, 2)

            def forward(self, x):
                out, h = self.rnn(x)
                if isinstance(h, tuple):
                    h = h[0]  # LSTM returns (h, c)
                return self.fc(h.squeeze(0))

        model = Model()
        opt = torch.optim.Adam(model.parameters(), lr=1e-3)
        for epoch in range(epochs):
            model.train()
            for xb, yb in loader:
                loss = nn.CrossEntropyLoss()(model(xb), yb)
                opt.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
                opt.step()
        model.eval()
        with torch.no_grad():
            acc = (model(X_te).argmax(1) == y_te).float().mean().item()
        results[name] = acc
    return results

print("Remember-the-first-element task (sequence length = 100):")
results = compare_architectures(seq_len=100)
for name, acc in results.items():
    marker = " <-- fails" if acc < 0.6 else ""
    print(f"  {name:>4s}: {acc:.1%}{marker}")
```
The vanilla RNN should struggle (close to 50% -- random guessing), while both LSTM and GRU should achieve 90%+ accuracy. On a sequence of length 100, the RNN simply cannot propagate the gradient from the loss back to the first timestep. The LSTM's cell state and the GRU's update gate both provide the gradient highway that makes learning possible across this distance.
When to use what
With all of that in place, here's a practical decision framework for sequence models in 2026:
Use LSTM when: you need fine-grained memory control, you have plenty of data and compute, the task requires remembering very specific facts over long distances (like a specific number from the beginning of a document), or you're working in a domain where LSTM is the established baseline.
Use GRU when: you want faster training (fewer parameters), your dataset is small to medium, you're prototyping and need quick iterations, or the task doesn't require the LSTM's extra gating complexity.
Use neither (use transformers) when: you have large datasets, you need to capture very long-range dependencies (1000+ tokens), you need parallelism during training, or you're working on any NLP task where a pretrained transformer model exists. We'll build transformers starting in a few episodes.
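Still use LSTM/GRU when: you need streaming / real-time processing (an RNN carries a fixed-size state, so each new step costs the same no matter how long the stream has run, whereas a transformer's per-step cost and memory grow with the context it attends to), you're deploying on edge devices with tight memory budgets, or you're processing genuinely sequential data like continuous sensor streams where each new reading must be processed immediately.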
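The streaming property is easy to demonstrate. Feeding readings one at a time while carrying the hidden state forward gives the same outputs as processing the whole sequence at once, and the only thing kept between steps is a tensor of fixed size. A sketch with nn.GRU and random stand-in sensor data:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gru = nn.GRU(input_size=3, hidden_size=8, batch_first=True)
stream = torch.randn(1, 30, 3)  # 30 readings from a 3-channel sensor (random stand-in)

with torch.no_grad():
    full_out, _ = gru(stream)   # process the whole sequence at once
    h = None                    # None means a zero initial state
    step_outs = []
    for t in range(stream.shape[1]):
        x_t = stream[:, t:t + 1, :]  # one new reading, shape (1, 1, 3)
        out_t, h = gru(x_t, h)       # carry the fixed-size state forward
        step_outs.append(out_t)
    step_out = torch.cat(step_outs, dim=1)

same = torch.allclose(full_out, step_out, atol=1e-6)
print(same, h.shape)  # state stays (num_layers, batch, hidden) = (1, 1, 8)
```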
The short version
- The LSTM adds a cell state (long-term memory highway) alongside the hidden state (working memory);
- Three gates control the cell state: forget (what to erase), input (what to write), output (what to expose);
- Gradients flow through the cell state path with minimal decay -- the forget gate value is the gradient multiplier, not a full weight matrix;
- The GRU simplifies to two gates (reset, update) with no separate cell state -- performance is comparable to LSTM on most tasks;
- Bidirectional RNNs run forward and backward passes, giving each position access to full sequence context (but can't be used for generation);
- Stacking 2-3 LSTM/GRU layers captures hierarchical patterns -- more layers need residual connections;
- The standard text classification recipe: embedding -> bidirectional stacked LSTM -> dropout -> linear.
We've now covered the complete evolution of recurrent architectures: vanilla RNNs (episode #48) and gated variants (today). These models dominated sequence processing for years, but they share a fundamental limitation: they process sequences one step at a time. For a 500-word sentence, the RNN/LSTM/GRU must run 500 sequential steps -- no parallelism possible. Training on large datasets is slow because you can't fully utilize modern GPU hardware designed for parallel computation. There's an architectural idea that solves this -- processing all positions simultaneously instead of sequentially, using a mechanism that lets each position "attend" to every other position directly. That's coming up soon ;-)
Exercises
Exercise 1: Implement the LSTMCell class from this episode and use it to process a sequence of 50 timesteps with input_dim=8 and hidden_dim=32. Track the cell state norm and hidden state norm at each timestep and print them at steps 0, 10, 20, 30, 40, 49. Then do the same with a vanilla RNN cell (just the h = tanh(x @ Wxh + h @ Whh) update). Compare how the norms evolve -- the LSTM should maintain more stable norms over time because of the gated cell state.
Exercise 2: Build a sequence copying task to demonstrate LSTM's long-range memory. Generate sequences of length 40 where the first 5 elements are random integers 1-9, followed by 30 zeros (padding), followed by 5 positions where the model must reproduce the first 5 elements. Train an LSTM and a vanilla RNN on this task. Report accuracy for each -- the LSTM should handle the 30-step memory gap while the vanilla RNN struggles.
Exercise 3: Compare unidirectional vs bidirectional LSTM on a synthetic tagging task. Generate sequences of length 20 where each element is a random float. The label for each position is 1 if both its left neighbor AND right neighbor are positive, else 0 (positions 0 and 19 are always 0). Train both unidirectional and bidirectional LSTMs. The bidirectional model should outperform because the label depends on future context that the unidirectional model can't see.