Learn AI Series (#17) - Decision Trees - How Machines Make Decisions
What will I learn
- You will learn how decision trees split data by asking yes/no questions about features -- a radically different approach from the linear models in episodes #10-12;
- Gini impurity and information gain -- the math behind how trees choose the best split at each step;
- how to build a complete decision tree classifier from scratch in pure NumPy, using recursive splitting;
- how to visualize and interpret a tree's learned structure (something you cannot do nearly as clearly with linear models);
- why unlimited trees memorize training data and how pruning parameters control overfitting;
- decision trees for regression -- same idea, different splitting criterion;
- scikit-learn's DecisionTreeClassifier and DecisionTreeRegressor with the Pipeline tools from episode #16;
- feature importances -- how trees tell you which features matter most;
- when trees beat linear models and when they don't -- and why that matters for choosing your next model.
Requirements
- A working modern computer running macOS, Windows or Ubuntu;
- An installed Python 3(.11+) distribution;
- The ambition to learn AI and machine learning.
Difficulty
- Beginner
Curriculum (of the Learn AI Series):
- Learn AI Series (#1) - What Machine Learning Actually Is
- Learn AI Series (#2) - Setting Up Your AI Workbench - Python and NumPy
- Learn AI Series (#3) - Your Data Is Just Numbers - How Machines See the World
- Learn AI Series (#4) - Your First Prediction - No Math, Just Intuition
- Learn AI Series (#5) - Patterns in Data - What "Learning" Actually Looks Like
- Learn AI Series (#6) - From Intuition to Math - Why We Need Formulas
- Learn AI Series (#7) - The Training Loop - See It Work Step by Step
- Learn AI Series (#8) - The Math You Actually Need (Part 1) - Linear Algebra
- Learn AI Series (#9) - The Math You Actually Need (Part 2) - Calculus and Probability
- Learn AI Series (#10) - Your First ML Model - Linear Regression From Scratch
- Learn AI Series (#11) - Making Linear Regression Real
- Learn AI Series (#12) - Classification - Logistic Regression From Scratch
- Learn AI Series (#13) - Evaluation - How to Know If Your Model Actually Works
- Learn AI Series (#14) - Data Preparation - The 80% Nobody Talks About
- Learn AI Series (#15) - Feature Engineering and Selection
- Learn AI Series (#16) - Scikit-Learn - The Standard Library of ML
- Learn AI Series (#17) - Decision Trees - How Machines Make Decisions (this post)
Learn AI Series (#17) - Decision Trees - How Machines Make Decisions
At the end of episode #16 I said we were moving into new territory. We'd covered linear models thoroughly -- regression from scratch in #10, regularization and the normal equation in #11, logistic regression in #12, evaluation in #13, data preparation in #14, feature engineering in #15, and the full scikit-learn toolkit in #16. Seven episodes building and then professionally tooling the linear model family. And I said the next part of the journey takes us into models that think very differently.
Today we arrive at the first of those fundamentally different models. Decision trees don't draw lines through your data. They don't compute weighted sums. They don't need gradient descent. They don't even need your features to be scaled (remember how much time we spent on StandardScaler in episodes #11, #14 and #16?). Instead, they ask a series of yes/no questions about your features and use the answers to arrive at a prediction.
"Is the apartment bigger than 80 square meters?" Yes. "Is it on a floor higher than 5?" No. "Is the building older than 20 years?" Yes. Prediction: EUR 195,000.
That's it. That's a decision tree. A nested sequence of if/else conditions that split your data into increasingly specific groups until each group is pure enough (or small enough) to make a prediction. It sounds almost too simple to be useful. And yet trees are the foundation of some of the most powerful ML algorithms that exist today -- algorithms that consistently win data science competitions and power production systems at scale. Understanding the single tree is the key to understanding that entire family ;-)
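To make that concrete, here's roughly what the apartment example above looks like if you write it out by hand as Python -- a minimal sketch, where the thresholds and the EUR 195,000 leaf come from the example and the other leaf values are made up purely for illustration (a trained tree learns all of these from data):

def predict_price(sqm, floor, age):
    # Hand-written sketch of the questions from the example above
    if sqm > 80:                # "Is the apartment bigger than 80 square meters?"
        if floor > 5:           # "Is it on a floor higher than 5?"
            return 230_000      # illustrative leaf value
        else:
            if age > 20:        # "Is the building older than 20 years?"
                return 195_000  # the prediction from the example
            else:
                return 210_000  # illustrative leaf value
    else:
        return 120_000          # illustrative leaf value

print(predict_price(sqm=95, floor=3, age=25))  # Yes, No, Yes -> 195000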
Let's build one from scratch.
The intuition: playing twenty questions with data
Think of the children's game "Twenty Questions." You're trying to guess what someone is thinking of by asking yes/no questions. A good player doesn't ask random questions -- they ask questions that eliminate the most possibilities each time. "Is it alive?" eliminates roughly half the universe. "Is it a blue thing made in Sweden?" eliminates almost nothing for most targets.
Decision trees do exactly this with data. At each step, the tree finds the question (which feature? which threshold?) that best separates the data into groups with similar outcomes. The "best" question is the one that produces the most homogeneous groups -- groups where most samples share the same label.
Let's see this concretely with our apartment dataset:
import numpy as np
np.random.seed(42)
# Apartment data: classify as high or low price
n = 200
sqm = np.random.uniform(30, 150, n)
floor = np.random.randint(0, 10, n).astype(float)
age = np.random.uniform(0, 50, n)
price = 2500 * sqm + 500 * floor - 200 * age + np.random.randn(n) * 20000
high_price = (price > np.median(price)).astype(int)
X = np.column_stack([sqm, floor, age])
feature_names = ["sqm", "floor", "age"]
# What's a good first question?
for feat_idx, feat_name in enumerate(feature_names):
    median_val = np.median(X[:, feat_idx])
    left_mask = X[:, feat_idx] <= median_val
    left_purity = np.mean(high_price[left_mask])
    right_purity = np.mean(high_price[~left_mask])
    print(f"Split on {feat_name} <= {median_val:.1f}:")
    print(f"  Left:  {left_mask.sum()} samples, {left_purity:.0%} high price")
    print(f"  Right: {(~left_mask).sum()} samples, {right_purity:.0%} high price")
    print()
The best split is the one that creates the most lopsided groups -- ideally one group is mostly "high price" and the other is mostly "low price." A split that produces two 50/50 groups is useless; it didn't help us separate anything. A split that produces one 90% group and one 10% group is great -- after that split, we're much more certain about the prediction in each branch.
But how do we measure "how good" a split is, mathematically? We need a metric for impurity.
Gini impurity: measuring the mix
A pure node contains only one class -- all high price, or all low price. An impure node is a mix. We need a number that says "how mixed is this group?" and that equals zero when the group is perfectly pure.
Gini impurity does exactly that:
Gini(node) = 1 - sum(p_i^2)
Where p_i is the proportion of class i in the node. For binary classification with classes 0 and 1:
Gini = 1 - p_0^2 - p_1^2
If all samples are class 0: Gini = 1 - 1^2 - 0^2 = 0. Pure.
If the split is 50/50: Gini = 1 - 0.5^2 - 0.5^2 = 0.5. Maximum impurity for binary.
If 90% are class 0: Gini = 1 - 0.9^2 - 0.1^2 = 0.18. Mostly pure.
Let's implement it:
def gini_impurity(y):
    """Compute Gini impurity of a set of labels."""
    if len(y) == 0:
        return 0
    classes, counts = np.unique(y, return_counts=True)
    proportions = counts / len(y)
    return 1 - np.sum(proportions ** 2)
# Examples
print("Gini impurity examples:")
print(f" All class 0 [0,0,0,0]: {gini_impurity([0,0,0,0]):.3f}")
print(f" All class 1 [1,1,1,1]: {gini_impurity([1,1,1,1]):.3f}")
print(f" Perfect 50/50 [0,0,1,1]: {gini_impurity([0,0,1,1]):.3f}")
print(f" Mostly one [0,0,0,1]: {gini_impurity([0,0,0,1]):.3f}")
print(f" Three classes [0,1,2]: {gini_impurity([0,1,2]):.3f}")
Gini = 0 means the node is perfectly pure. The closer to 0.5 (for binary), the more mixed. This is the metric we'll use to evaluate how good a split is -- we want to reduce Gini impurity as much as possible with each split.
There's another popular criterion called entropy (from information theory) that works similarly. Entropy is - sum(p_i * log2(p_i)) and measures "information content" of the mix. In practice, Gini and entropy produce nearly identical trees. Scikit-learn uses Gini by default, so we'll stick with that.
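If you want to see that similarity for yourself, here's a quick sketch (reusing the gini_impurity function from above). Both criteria are zero for a pure node and peak at a 50/50 mix -- they just live on different scales:

def entropy(y):
    # Entropy criterion: -sum(p_i * log2(p_i)), for comparison with Gini
    _, counts = np.unique(y, return_counts=True)
    p = counts / len(y)
    return -np.sum(p * np.log2(p))

for labels in ([0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 1]):
    print(f"{labels}: gini={gini_impurity(labels):.3f}, entropy={entropy(labels):.3f}")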
Finding the best split: information gain
For each feature and each possible threshold, we evaluate the Gini impurity of the two child nodes (left = samples where feature <= threshold, right = the rest). The split with the highest information gain wins.
Information gain = parent Gini minus the weighted average Gini of the children:
gain = Gini(parent) - (n_left/n * Gini(left) + n_right/n * Gini(right))
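Before implementing the full search, here's a small worked example using the gini_impurity function from above (the numbers are made up for illustration): a parent node with 10 samples split 5/5, and a candidate split that sends 4 samples left and 6 right.

parent = np.array([0]*5 + [1]*5)       # 10 samples, 5 of each class -> Gini 0.5
left   = np.array([0, 0, 0, 0])        # 4 samples, pure             -> Gini 0.0
right  = np.array([0, 1, 1, 1, 1, 1])  # 6 samples, mostly class 1   -> Gini ~0.278

weighted = (len(left) * gini_impurity(left) + len(right) * gini_impurity(right)) / len(parent)
gain = gini_impurity(parent) - weighted
print(f"gain = 0.5 - {weighted:.3f} = {gain:.3f}")  # roughly 0.333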
Higher gain means the split did a better job of separating the classes. Let's implement it:
def find_best_split(X, y):
    """Find the feature and threshold that maximize information gain."""
    best_gain = -1
    best_feature = None
    best_threshold = None
    parent_gini = gini_impurity(y)
    n = len(y)
    for feature_idx in range(X.shape[1]):
        # Try every unique value as a threshold
        thresholds = np.unique(X[:, feature_idx])
        for threshold in thresholds:
            left_mask = X[:, feature_idx] <= threshold
            right_mask = ~left_mask
            if left_mask.sum() == 0 or right_mask.sum() == 0:
                continue
            # Weighted average Gini of children
            left_gini = gini_impurity(y[left_mask])
            right_gini = gini_impurity(y[right_mask])
            n_left = left_mask.sum()
            n_right = right_mask.sum()
            weighted_child_gini = (n_left * left_gini + n_right * right_gini) / n
            gain = parent_gini - weighted_child_gini
            if gain > best_gain:
                best_gain = gain
                best_feature = feature_idx
                best_threshold = threshold
    return best_feature, best_threshold, best_gain
feat, thresh, gain = find_best_split(X, high_price)
print(f"Best first split: {feature_names[feat]} <= {thresh:.1f}")
print(f"Information gain: {gain:.4f}")
print(f"Parent Gini: {gini_impurity(high_price):.4f}")
Run that and you'll see the tree picks the feature and threshold that maximize information gain. For our apartment data, it'll almost certainly pick sqm as the first split, because apartment size is the strongest predictor of price. That makes intuitive sense -- and the math confirms it.
Having said that, the algorithm doesn't know "square meters means apartment size." It doesn't know anything about apartments. It just tries every feature at every threshold and picks the one that separates the labels best. That's the beauty (and the limitation) of this approach -- no domain knowledge required, but also no domain knowledge used.
Building the tree recursively
A single split isn't a tree. It's a stump. A real decision tree repeats the splitting process on each child node: find the best split, divide the data, find the best split in each half, divide again, and so on. This is a naturally recursive process -- each node either makes a split (internal node) or makes a prediction (leaf node).
We stop splitting when one of these conditions is met:
- The node is perfectly pure (all samples have the same label)
- We've reached the maximum depth allowed
- The node has fewer samples than the minimum required for a split
- No split produces positive information gain
class TreeNode:
    """A single node in a decision tree."""
    def __init__(self, feature=None, threshold=None,
                 left=None, right=None, value=None):
        self.feature = feature      # which feature to split on
        self.threshold = threshold  # threshold for the split
        self.left = left            # left child (feature <= threshold)
        self.right = right          # right child (feature > threshold)
        self.value = value          # prediction (only for leaf nodes)
def build_tree(X, y, depth=0, max_depth=5, min_samples=2):
    """Recursively build a decision tree."""
    # Base case 1: max depth reached
    # Base case 2: too few samples to split
    # Base case 3: node is already pure
    if (depth >= max_depth or
            len(y) < min_samples or
            len(np.unique(y)) == 1):
        # Make a leaf: predict the majority class
        classes, counts = np.unique(y, return_counts=True)
        return TreeNode(value=classes[counts.argmax()])
    # Find the best split
    feature, threshold, gain = find_best_split(X, y)
    # Base case 4: no useful split exists
    if gain <= 0:
        classes, counts = np.unique(y, return_counts=True)
        return TreeNode(value=classes[counts.argmax()])
    # Split the data
    left_mask = X[:, feature] <= threshold
    # Recurse on each child
    left_child = build_tree(
        X[left_mask], y[left_mask],
        depth + 1, max_depth, min_samples
    )
    right_child = build_tree(
        X[~left_mask], y[~left_mask],
        depth + 1, max_depth, min_samples
    )
    return TreeNode(
        feature=feature, threshold=threshold,
        left=left_child, right=right_child
    )
# Build it!
tree = build_tree(X, high_price, max_depth=4)
print("Tree built successfully!")
That's the entire training algorithm for a decision tree classifier. No gradient descent. No learning rate. No iterative optimization. Just recursive splitting based on information gain. Compare that to the training loop we built in episode #7 and used through episodes #10-12 -- the tree has no iterative training loop at all. It finds the optimal split at each node in a single pass through the data.
If you've been through the Learn Python Series (specifically the recursion episode, if you followed it), this recursive structure should feel familiar. Each call to build_tree either returns a leaf or creates an internal node that points to two child subtrees. The base cases prevent infinite recursion.
Making predictions
Prediction is even simpler than training. Start at the root. Check the condition. Go left or right. Repeat until you reach a leaf. Return the leaf's value.
def predict_one(node, x):
    """Predict a single sample by traversing the tree."""
    # Leaf node: return the prediction
    if node.value is not None:
        return node.value
    # Internal node: go left or right
    if x[node.feature] <= node.threshold:
        return predict_one(node.left, x)
    return predict_one(node.right, x)

def predict(node, X):
    """Predict for an array of samples."""
    return np.array([predict_one(node, x) for x in X])
# Test on training data
train_preds = predict(tree, X)
train_accuracy = np.mean(train_preds == high_price)
print(f"Training accuracy: {train_accuracy:.1%}")
Notice how fast prediction is. For each sample, you traverse at most max_depth nodes. With max_depth=4, that's at most 4 comparisons per sample. Compare that to a linear model which computes a weighted sum of ALL features -- for 1000 features, that's 1000 multiplications and additions. Trees are extremely efficient at prediction time. This matters in production systems where you need to classify millions of requests per second.
Visualizing the learned tree
One of the biggest advantages of decision trees over linear models (and ESPECIALLY over neural networks, which we'll encounter later in this series) is interpretability. You can literally print the tree and read it like a set of business rules. Try doing that with a neural network ;-)
def print_tree(node, depth=0, prefix="Root"):
    """Print the tree structure in a human-readable format."""
    indent = "  " * depth
    if node.value is not None:
        print(f"{indent}{prefix} -> Predict: {'HIGH' if node.value else 'LOW'}")
        return
    fname = feature_names[node.feature]
    print(f"{indent}{prefix}: {fname} <= {node.threshold:.1f}?")
    print_tree(node.left, depth + 1, "Yes")
    print_tree(node.right, depth + 1, "No ")
print("\nLearned decision tree:")
print("=" * 50)
print_tree(tree)
You can follow the logic: "If sqm > 90 AND floor > 3 AND age < 15, predict HIGH price." Try explaining a logistic regression model that clearly. You CAN inspect the coefficients (we did that in episodes #10-12), but "this feature has weight 2,347" is a lot less intuitive than "if square meters exceed 90, go left."
This interpretability is a real practical advantage. In domains like healthcare, finance, and law, people need to understand WHY a model made a decision. A doctor can't tell a patient "the neural network says you need surgery." But "the model found that your blood pressure is above 160 AND your cholesterol is above 240 AND you're over 60, which puts you in the high-risk group" -- that's a conversation a doctor can have. And that's exactly what a decision tree provides.
The overfitting trap: unlimited depth
Here's where things get interesting -- and where the "too good to be true" alarm should go off. Let me show you something:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
X, high_price, test_size=0.2, random_state=42
)
# Build an UNLIMITED tree (no max_depth)
tree_unlimited = build_tree(X_train, y_train, max_depth=100, min_samples=1)
train_acc = np.mean(predict(tree_unlimited, X_train) == y_train)
test_acc = np.mean(predict(tree_unlimited, X_test) == y_test)
print("Unlimited tree:")
print(f" Train accuracy: {train_acc:.1%}")
print(f" Test accuracy: {test_acc:.1%}")
# Build a PRUNED tree (max_depth=3)
tree_pruned = build_tree(X_train, y_train, max_depth=3, min_samples=5)
train_acc_p = np.mean(predict(tree_pruned, X_train) == y_train)
test_acc_p = np.mean(predict(tree_pruned, X_test) == y_test)
print("\nPruned tree (max_depth=3, min_samples=5):")
print(f" Train accuracy: {train_acc_p:.1%}")
print(f" Test accuracy: {test_acc_p:.1%}")
The unlimited tree gets 100% on training data. Perfect. Every single training sample classified correctly. Sounds amazing, right?
It's not. The tree has memorized the training data. It created enough leaves to give each training sample its own personal leaf. Every noise pattern, every random fluctuation, every outlier -- all captured and encoded as if they were real signal. On the test data, where those specific noise patterns don't repeat, the performance drops.
This is EXACTLY the overfitting problem we first saw in episode #11 with high-degree polynomials. A degree-11 polynomial perfectly passes through 12 data points but wobbles wildly between them. An unlimited tree perfectly classifies all training points but creates bizarre decision boundaries in the gaps. Same problem, different algorithm.
The pruned tree sacrifices some training accuracy -- it can't perfectly classify every training sample with only 3 levels of depth. But it generalizes better because it captures the broad patterns without memorizing the noise. That's the bias-variance tradeoff in action, same concept from episode #11, and it's going to follow us through every algorithm in this series.
Controlling overfitting: the pruning toolkit
The beauty of decision trees is that controlling their complexity is intuitive. You have several knobs to turn, and each one limits the tree's ability to memorize:
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import cross_val_score
# Let's see how max_depth affects generalization
print("Effect of max_depth on generalization:\n")
print(f"{'depth':>10s} {'CV mean':>10s} {'CV std':>10s}")
print("-" * 34)
for depth in [1, 2, 3, 4, 5, 8, 12, None]:
    clf = DecisionTreeClassifier(max_depth=depth, random_state=42)
    scores = cross_val_score(clf, X, high_price, cv=5)
    label = str(depth) if depth else "unlimited"
    print(f"{label:>10s} {scores.mean():>10.3f} {scores.std():>10.3f}")
Watch the CV accuracy as depth increases. It usually climbs initially (more depth = more capacity to learn real patterns), peaks somewhere around depth 3-5, then either plateaus or drops as the tree starts overfitting. The "best" depth is wherever CV accuracy is highest -- and cross_val_score from episode #16 makes finding it trivial.
Beyond max_depth, sklearn's DecisionTreeClassifier gives you several other pruning parameters:
# The full pruning toolkit
print("\nPruning parameter comparison:\n")
configs = {
"No limits": DecisionTreeClassifier(random_state=42),
"max_depth=4": DecisionTreeClassifier(max_depth=4, random_state=42),
"min_samples_split=10": DecisionTreeClassifier(min_samples_split=10, random_state=42),
"min_samples_leaf=5": DecisionTreeClassifier(min_samples_leaf=5, random_state=42),
"max_leaf_nodes=10": DecisionTreeClassifier(max_leaf_nodes=10, random_state=42),
"Combined": DecisionTreeClassifier(
max_depth=5, min_samples_leaf=5,
random_state=42),
}
print(f"{'Config':>25s} {'CV mean':>10s} {'CV std':>10s}")
print("-" * 49)
for name, clf in configs.items():
    scores = cross_val_score(clf, X, high_price, cv=5)
    print(f"{name:>25s} {scores.mean():>10.3f} {scores.std():>10.3f}")
Here's what each parameter does:
- max_depth: hard limit on tree depth. The most important single parameter.
- min_samples_split: minimum samples needed to attempt a split. Higher = more conservative.
- min_samples_leaf: minimum samples in each leaf. Prevents leaves with just 1-2 samples (which are almost certainly noise).
- max_leaf_nodes: directly limits how many predictions the tree can make.
- max_features: random subset of features to consider at each split. This is a preview of something we'll use heavily very soon -- limiting the features each split can see introduces useful randomness.
In practice, I usually start with max_depth and min_samples_leaf and tune them with cross-validation. The GridSearchCV from episode #16 makes this automatic:
from sklearn.model_selection import GridSearchCV
param_grid = {
'max_depth': [2, 3, 4, 5, 6, 8, 10],
'min_samples_leaf': [1, 3, 5, 10, 15],
}
grid = GridSearchCV(
DecisionTreeClassifier(random_state=42),
param_grid, cv=5, scoring='accuracy',
return_train_score=True
)
grid.fit(X_train, y_train)
print(f"Best parameters: {grid.best_params_}")
print(f"Best CV accuracy: {grid.best_score_:.3f}")
print(f"Test accuracy: {grid.score(X_test, y_test):.3f}")
Same workflow from episode #16. fit/predict/score -- the consistent sklearn API means switching from LogisticRegression to DecisionTreeClassifier is literally changing one line. Everything else -- the pipeline, the cross-validation, the grid search -- stays the same.
Decision trees for regression
Everything we've built so far has been about classification -- predicting categories. But trees also work for continuous targets. The only difference is the splitting criterion: instead of minimizing Gini impurity (which measures class mix), we minimize variance within each leaf.
The intuition is the same. A "pure" regression leaf is one where all samples have similar target values (low variance). The best split is the one that creates children with the lowest weighted average variance.
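To see how that maps onto the code we already wrote, here's a minimal sketch of a variance-based impurity (the price numbers are made up for illustration). Swap it in for gini_impurity in find_best_split, and predict the mean of each leaf instead of the majority class, and build_tree becomes a regression tree -- an illustration of the idea, not sklearn's exact implementation (which uses the equivalent squared-error criterion):

def variance_impurity(y):
    # "Impurity" for regression: how spread out the target values are
    if len(y) == 0:
        return 0.0
    return np.var(y)

# A split that groups similar prices together produces low child variance
left_prices   = np.array([195_000, 210_000, 190_000])
right_prices  = np.array([420_000, 450_000, 430_000])
parent_prices = np.concatenate([left_prices, right_prices])
print(f"parent variance: {variance_impurity(parent_prices):,.0f}")
print(f"left variance:   {variance_impurity(left_prices):,.0f}")
print(f"right variance:  {variance_impurity(right_prices):,.0f}")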
from sklearn.tree import DecisionTreeRegressor
# Regression: predict actual price (continuous)
price_continuous = 2500 * sqm + 500 * floor - 200 * age + np.random.randn(n) * 20000
X_train_r, X_test_r, y_train_r, y_test_r = train_test_split(
X, price_continuous, test_size=0.2, random_state=42
)
# Compare different depths
print("Regression tree -- effect of depth:\n")
print(f"{'depth':>10s} {'Train RMSE':>12s} {'Test RMSE':>12s} {'Test R-sq':>10s}")
print("-" * 48)
for depth in [2, 3, 4, 5, 8, None]:
    reg = DecisionTreeRegressor(max_depth=depth, random_state=42)
    reg.fit(X_train_r, y_train_r)
    train_pred = reg.predict(X_train_r)
    test_pred = reg.predict(X_test_r)
    train_rmse = np.sqrt(np.mean((y_train_r - train_pred) ** 2))
    test_rmse = np.sqrt(np.mean((y_test_r - test_pred) ** 2))
    r2 = reg.score(X_test_r, y_test_r)
    label = str(depth) if depth else "unlimited"
    print(f"{label:>10s} {train_rmse:>12,.0f} {test_rmse:>12,.0f} {r2:>10.4f}")
Same overfitting pattern as classification. The unlimited tree gets near-zero training RMSE (memorization) but mediocre test RMSE. The pruned trees find the sweet spot. Notice how the regression tree's RMSE compares to the linear regression we built in episodes #10-11 -- on this particular dataset (where the true relationship IS mostly linear), the linear model might actually win. That's expected and important to understand.
Feature importances: what the tree learned
One of the most practical outputs of a decision tree is feature importances -- a measure of how much each feature contributed to the model's decisions. Features used in early splits (near the root) or used frequently across many splits are more important than features used in deep leaves or not used at all.
# Train a tree and inspect feature importances
clf_imp = DecisionTreeClassifier(max_depth=5, random_state=42)
clf_imp.fit(X_train, y_train)
print("Feature importances:\n")
importances = clf_imp.feature_importances_
for name, imp in sorted(zip(feature_names, importances),
                        key=lambda x: -x[1]):
    bar = "#" * int(imp * 40)
    print(f" {name:>6s}: {imp:.3f} {bar}")
These importances tell you which features the tree relied on most. For our apartment data, sqm should dominate because it's the strongest predictor. This is related to the permutation importance technique we built in episode #15 -- the tree's built-in importances are faster to compute but can be misleading for correlated features. We'll see more robust importance measures as we progress through this series.
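If you want to cross-check the built-in importances, scikit-learn also ships a permutation importance helper (sklearn.inspection.permutation_importance), which shuffles one feature at a time and measures how much the score drops. A quick sketch on the tree we just trained:

from sklearn.inspection import permutation_importance

# Shuffle each feature on the test set and see how much accuracy suffers
perm = permutation_importance(clf_imp, X_test, y_test, n_repeats=20, random_state=42)
for name, mean_drop in zip(feature_names, perm.importances_mean):
    print(f" {name:>6s}: {mean_drop:.3f}")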
A practical use: if you have 50 features and the tree says only 5 have non-zero importance, you now have a short list of the features that actually matter. That's automated feature selection -- the tree tells you which features to keep. Compare this to the manual correlation analysis and model-based selection from episode #15. Trees give you this for free as a side effect of training.
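If you want to automate that short-listing step, sklearn's SelectFromModel wrapper keeps only the features whose importance clears a threshold -- a sketch below, with an illustrative "mean" threshold:

from sklearn.feature_selection import SelectFromModel

# Keep features whose tree importance exceeds the mean importance
selector = SelectFromModel(
    DecisionTreeClassifier(max_depth=5, random_state=42),
    threshold="mean"
)
selector.fit(X_train, y_train)
kept = [name for name, keep in zip(feature_names, selector.get_support()) if keep]
print(f"Features kept: {kept}")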
The full sklearn workflow with Pipelines
Let me bring everything together into the Pipeline workflow we established in episode #16. Even though trees don't NEED feature scaling (they split on thresholds, not on magnitudes), you might still want preprocessing in your pipeline for other reasons -- imputation, encoding, or to compare trees against linear models in the same pipeline structure.
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report
# A complete tree workflow
pipe = Pipeline([
('model', DecisionTreeClassifier(random_state=42)),
])
# Grid search for the best tree
param_grid = {
'model__max_depth': [3, 4, 5, 6, 8],
'model__min_samples_leaf': [1, 3, 5, 10],
'model__criterion': ['gini', 'entropy'],
}
search = GridSearchCV(pipe, param_grid, cv=5, scoring='f1')
search.fit(X_train, y_train)
print(f"Best params: {search.best_params_}")
print(f"Best CV F1: {search.best_score_:.3f}")
print(f"\n--- Test Set Results ---\n")
y_pred = search.predict(X_test)
print(classification_report(
y_test, y_pred,
target_names=['low price', 'high price']
))
Notice I used scoring='f1' instead of accuracy. Remember episode #13 -- accuracy can be misleading, especially with imbalanced classes. F1 balances precision and recall. The tools from episode #13, the pipeline from #16, the grid search from #16 -- they all carry forward. You don't restart when you learn a new algorithm. You plug it into the same framework.
Trees vs linear models: when each wins
Decision trees and linear models have complementary strengths. Understanding when to use which is an important practical skill that goes beyond just knowing how each algorithm works.
Trees win when:
- The relationship between features and target is nonlinear and discontinuous. "If income > 50k AND age < 30, then buy" -- that's an if/else rule, not a line. Trees capture this naturally.
- Features interact in complex ways. Remember from episode #15 how we had to manually create floor * (1 - has_elevator) as an interaction feature? Trees discover interactions automatically -- a split on floor in a subtree that already split on elevator IS that interaction.
- You need interpretable rules. Doctors, loan officers, and regulators want to see the logic, not a vector of 200 weights.
- Data has outliers. Trees split on thresholds, so a value of 999,999 in a feature just goes to the same branch as any value above the split point. No weight gets pulled toward the outlier like in linear regression.
- No feature scaling needed. Trees don't care about units or magnitudes. Square meters in the range [30, 150] and floor in the range [0, 9] work fine together without StandardScaler.
Linear models win when:
- The true relationship is approximately linear. If price really does increase by EUR 2,500 per square meter, a linear model captures that exactly in one weight. A tree approximates it with a staircase of splits.
- You have many features relative to samples. With 500 features and 200 samples, a tree can easily overfit by finding spurious splits. A regularized linear model (Ridge from episode #11) is more stable.
- You need probabilistic outputs. Logistic regression gives you calibrated probabilities ("this patient has a 73% chance of high risk"). Trees give you the proportion of training samples in the leaf, which is a rougher probability estimate.
- The true pattern is smooth. Trees create rectangular decision boundaries -- step functions. If the real boundary is a smooth curve, the tree needs many splits to approximate it and can look jagged.
- Training speed matters for simple problems. A linear model solves in one matrix operation (the normal equation from episode #11). Building a tree requires trying every feature at every threshold at every node.
Here's a concrete comparison:
from sklearn.linear_model import LogisticRegression
# Compare on our dataset
models = {
"Logistic Regression": LogisticRegression(max_iter=1000),
"Decision Tree (d=3)": DecisionTreeClassifier(max_depth=3, random_state=42),
"Decision Tree (d=5)": DecisionTreeClassifier(max_depth=5, random_state=42),
"Decision Tree (tuned)": search.best_estimator_,
}
print(f"{'Model':>25s} {'CV Accuracy':>12s} {'CV F1':>10s}")
print("-" * 51)
for name, model in models.items():
    acc_scores = cross_val_score(model, X, high_price, cv=5, scoring='accuracy')
    f1_scores = cross_val_score(model, X, high_price, cv=5, scoring='f1')
    print(f"{name:>25s} {acc_scores.mean():>12.3f} {f1_scores.mean():>10.3f}")
For THIS dataset (where the true relationship is roughly linear), the logistic regression and the decision tree should be reasonably close. The tree might even be slightly worse -- it's approximating a linear boundary with rectangular splits. On a dataset with genuine nonlinear patterns (like the floor-elevator interaction from episode #15), the tree would crush the linear model without needing manually engineered features.
The good news? You don't have to choose upfront. Cross-validate both. Use the pipeline framework from #16. The winner depends on the data, and the tools we've built tell you which one wins on YOUR data.
Why a single tree is just the beginning
I want to be upfront about something. Single decision trees are instructive -- you've just learned how they work from scratch, and that understanding is crucial. But in practice, a single tree is rarely the best model for any task. They're too sensitive to the specific training data. Change a few samples and the tree can look completely different (different split features, different thresholds, different structure). This instability is their Achilles' heel.
The solution? Don't use one tree. Use HUNDREDS of trees. Build them slightly differently from each other, then combine their predictions. One tree might be wrong about a particular sample, but if 500 trees vote on it, the majority is usually right. This idea -- combining many weak models into one strong model -- is one of the most powerful concepts in all of machine learning. And the algorithms that do this (there are several, and they're coming up next in this series) consistently rank among the best performing models on structured/tabular data.
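To make the voting idea concrete without spoiling the next episode, here's a minimal sketch using only the tools from this post: one simple way to make each tree "slightly different" is to train it on a random resample of the training data, then let the trees vote. This is an illustration of the concept, not the production approach we'll cover next time:

# A minimal sketch of "many trees voting": each tree sees a different
# random resample of the training data, and the majority wins.
rng = np.random.default_rng(42)
trees = []
for _ in range(25):
    idx = rng.integers(0, len(X_train), len(X_train))  # sample with replacement
    t = DecisionTreeClassifier(random_state=42)
    t.fit(X_train[idx], y_train[idx])
    trees.append(t)

# Majority vote across the 25 trees
votes = np.array([t.predict(X_test) for t in trees])   # shape (25, n_test)
majority = (votes.mean(axis=0) >= 0.5).astype(int)
print(f"Single unlimited tree test accuracy: {np.mean(predict(tree_unlimited, X_test) == y_test):.1%}")
print(f"25-tree majority vote test accuracy: {np.mean(majority == y_test):.1%}")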
Understanding the single tree is the foundation. Everything that comes next builds directly on what we covered today.
Let's recap
We left the world of linear models today and built something fundamentally different. Here's what we covered:
- Decision trees ask yes/no questions about features, recursively splitting data into increasingly pure groups. No gradient descent, no weighted sums, no feature scaling required;
- Gini impurity measures how mixed a node is: 0 = pure (one class), 0.5 = maximally impure (50/50 for binary). Trees split to minimize Gini, or equivalently, to maximize information gain;
- We built a complete decision tree classifier from scratch -- find_best_split, build_tree, predict -- using pure NumPy. The recursive structure finds the optimal split at each node in a single pass through the data;
- Trees are interpretable: you can print the learned logic as a set of human-readable if/else rules. This is a major advantage in regulated domains and for debugging;
- Unlimited trees memorize training data (100% training accuracy, poor test accuracy). Pruning parameters -- max_depth, min_samples_leaf, min_samples_split, max_leaf_nodes -- control complexity and prevent overfitting;
- Regression trees work the same way but minimize variance instead of Gini impurity. The overfitting pattern is identical;
- Feature importances come free with tree training -- the tree tells you which features it relied on most, which is automated feature selection;
- Trees discover interactions automatically (a split on floor after splitting on elevator IS the interaction feature we created manually in episode #15);
- Trees beat linear models on nonlinear/discontinuous data, but linear models win on smooth linear relationships. Cross-validate both using the tools from episodes #13 and #16;
- A single tree is unstable and rarely optimal alone. The real power comes from combining many trees -- an idea we'll explore next.