Decision Trees: How Machines Make Decisions Like Humans
Have you ever played 20 Questions? That game where you guess what someone is thinking of by asking a series of yes-or-no questions, each one narrowing down the possibilities? If so, you already understand the core logic of Decision Trees. These elegant algorithms are among the most intuitive in machine learning because they mirror the way we naturally break complex choices down into simpler, sequential steps.
Why Decision Trees Feel So Natural
When you decide whether to go for a run, your brain might follow a path like this: "Is it raining? If yes, stay home. If no, do I have time? If yes, go running. If no, maybe later." That's exactly how a decision tree works—it asks a series of questions, each answer leading to the next question until reaching a final decision.
This human-like reasoning is why decision trees are so popular, especially when explainability matters. Try explaining to a doctor why your model predicted a certain diagnosis. With a neural network, good luck. With a decision tree? You can literally show them the path of questions that led to the conclusion.
How Decision Trees Work
A decision tree is a flowchart-like structure where:
- Internal nodes represent tests on features (e.g., "Is age > 30?")
- Branches represent the outcomes of those tests
- Leaf nodes represent the final predictions (class labels or values)
Imagine we're building a classifier to predict whether someone will buy a product based on age and income:
- Age > 35?
  - Yes → Income > 50K?
    - Yes → Will Buy
    - No → Will NOT Buy
  - No → Will NOT Buy

The tree learns these questions from data, finding the optimal splits that best separate the different classes.
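To make this concrete, here is the same toy tree written out by hand as nested conditionals. This is a minimal sketch: the thresholds come straight from the diagram above, whereas a real decision tree learns them from data.

def will_buy(age, income):
    """Hand-coded version of the toy tree above (thresholds are illustrative)."""
    if age > 35:                 # root node: test on age
        if income > 50_000:      # internal node: test on income
            return "Will Buy"    # leaf
        return "Will NOT Buy"    # leaf
    return "Will NOT Buy"        # leaf

print(will_buy(age=42, income=60_000))  # Will Buy
print(will_buy(age=28, income=80_000))  # Will NOT Buy

Every prediction is just a walk from the root to a leaf; the learning problem is deciding which questions to ask and in what order.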
Splitting Criteria: The Heart of Decision Trees
How does a decision tree know which question to ask first? It needs a way to measure how "good" a split is. There are two main approaches:
Entropy and Information Gain
Entropy measures the disorder or impurity in a set. In information theory terms, it quantifies the uncertainty in a random variable.
For a dataset with $c$ classes, entropy is calculated as:
$$H(S) = -\sum_{i=1}^{c} p_i \log_2(p_i)$$
Where $p_i$ is the proportion of samples belonging to class $i$.
- Entropy = 0: All samples belong to the same class (pure node)
- Entropy = 1 (for binary): Samples are evenly split between classes (maximum impurity)
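Here is the entropy formula as a minimal Python helper (entropy is our own illustrative function, not a library call):

import numpy as np

def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

print(entropy(["Yes", "Yes", "No", "No"]))   # 1.0    -> evenly split, maximum impurity
print(entropy(["Yes", "Yes", "Yes", "No"]))  # ~0.811 -> mostly one class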
Information Gain measures how much entropy we reduce after a split:
$$IG(S, A) = H(S) - \sum_{v \in Values(A)} \frac{|S_v|}{|S|} H(S_v)$$
The feature with the highest information gain becomes the splitting criterion.
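In code, information gain is just the parent's entropy minus the weighted entropy of the children. The sketch below reuses the entropy helper defined above; information_gain is again our own naming:

def information_gain(parent_labels, subsets):
    """IG = H(parent) - weighted average entropy of the child subsets."""
    n = len(parent_labels)
    weighted_children = sum(len(s) / n * entropy(s) for s in subsets)
    return entropy(parent_labels) - weighted_children

# A perfect split removes all uncertainty:
print(information_gain(["Yes", "Yes", "No", "No"],
                       [["Yes", "Yes"], ["No", "No"]]))  # 1.0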
Gini Impurity
Gini impurity is another impurity measure: it gives the probability that a randomly chosen sample from the set would be mislabeled if it were labeled at random according to the set's class distribution:
$$Gini(S) = 1 - \sum_{i=1}^{c} p_i^2$$
- Gini = 0: Pure node
- Gini = 0.5 (for binary): Maximum impurity
In practice, both metrics usually produce similar trees. Gini is computationally slightly faster since it doesn't require logarithm calculations.
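For comparison, here is the Gini formula in the same style (again a sketch, reusing the numpy import from the entropy snippet):

def gini(labels):
    """Gini impurity of a list of class labels."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

print(gini(["Yes", "Yes", "No", "No"]))   # 0.5   -> maximum impurity for two classes
print(gini(["Yes", "Yes", "Yes", "No"]))  # 0.375 -> mostly one class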
Building a Tree Step by Step
Let's walk through building a decision tree with a concrete example. Suppose we have data about whether customers churned based on their contract type and tenure:
| Contract | Tenure (months) | Churned |
|---|---|---|
| Monthly | 2 | Yes |
| Monthly | 8 | Yes |
| Annual | 12 | No |
| Monthly | 15 | No |
| Annual | 6 | No |
| Monthly | 3 | Yes |
Step 1: Calculate the initial entropy
We have 3 "Yes" and 3 "No", so:
$$H(S) = -\frac{3}{6}\log_2(\frac{3}{6}) - \frac{3}{6}\log_2(\frac{3}{6}) = 1.0$$
Step 2: Calculate information gain for each feature
For "Contract":
- Monthly: 3 Yes, 1 No → $H = 0.811$
- Annual: 0 Yes, 2 No → $H = 0$
$$IG(Contract) = 1.0 - (\frac{4}{6} \times 0.811 + \frac{2}{6} \times 0) = 0.459$$
For "Tenure > 10":
- Yes: 1 Yes, 2 No → $H = 0.918$
- No: 2 Yes, 1 No → $H = 0.918$
$$IG(Tenure > 7) = 1.0 - (\frac{3}{6} \times 0.918 + \frac{3}{6} \times 0.918) = 0.082$$
Step 3: Choose the best split
"Contract" has higher information gain, so it becomes the root. We then repeat recursively for each branch until we reach pure nodes or a stopping criterion.
Tree Pruning: Preventing Overfitting
Decision trees have a tendency to grow too complex, perfectly fitting the training data but failing on new data. This is called overfitting. Imagine a tree that creates a unique path for every single training example—perfect accuracy on training data, terrible on anything else.
Pre-pruning (Early Stopping)
Stop growing the tree before it becomes too complex by constraining it with hyperparameters such as these (a tuning sketch follows the list):
- max_depth: Limit how deep the tree can grow
- min_samples_split: Minimum samples required to split a node
- min_samples_leaf: Minimum samples required in a leaf node
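These limits interact, so in practice they are usually tuned with cross-validation rather than set by hand. A minimal sketch using scikit-learn's GridSearchCV on Iris (the grid values are arbitrary placeholders, not recommendations):

from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

param_grid = {
    "max_depth": [2, 3, 4, 5],
    "min_samples_split": [2, 5, 10],
    "min_samples_leaf": [1, 2, 5],
}

search = GridSearchCV(DecisionTreeClassifier(random_state=42), param_grid, cv=5)
search.fit(X, y)
print(search.best_params_)  # pre-pruning settings with the best cross-validation score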
Post-pruning
Grow the full tree first, then remove branches that don't improve validation performance:
- Cost-complexity pruning: Penalize tree complexity (the number of leaves) with a parameter $\alpha$ and prune the branches that don't justify their cost (see the sketch after this list)
- Reduced error pruning: Replace a subtree with a leaf whenever doing so does not increase error on a validation set
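Scikit-learn exposes cost-complexity pruning through the ccp_alpha parameter. One way to explore it is to compute the pruning path and compare the resulting trees on a held-out split; a rough sketch:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=42)

# Effective alphas at which subtrees would be pruned away
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X_train, y_train)

for alpha in path.ccp_alphas:
    pruned = DecisionTreeClassifier(random_state=42, ccp_alpha=alpha).fit(X_train, y_train)
    print(f"alpha={alpha:.4f}  leaves={pruned.get_n_leaves()}  "
          f"val accuracy={pruned.score(X_val, y_val):.3f}")

Larger alphas prune more aggressively; you would keep the smallest tree whose validation accuracy is still acceptable.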
Python Implementation with Scikit-learn
Let's build a complete, runnable example using the famous Iris dataset:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import accuracy_score, classification_report
# Load the Iris dataset
iris = load_iris()
X, y = iris.data, iris.target
feature_names = iris.feature_names
class_names = iris.target_names
# Split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.3, random_state=42
)
# Create and train the decision tree
clf = DecisionTreeClassifier(
criterion='gini', # or 'entropy' for information gain
max_depth=3, # limit depth to prevent overfitting
min_samples_split=5, # minimum samples to split a node
min_samples_leaf=2, # minimum samples in a leaf
random_state=42
)
clf.fit(X_train, y_train)
# Make predictions
y_pred = clf.predict(X_test)
# Evaluate the model
print("Accuracy:", accuracy_score(y_test, y_pred))
print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=class_names))

Output:
Accuracy: 0.9777777777777777
Classification Report:
              precision    recall  f1-score   support

      setosa       1.00      1.00      1.00        19
  versicolor       1.00      0.93      0.96        13
   virginica       0.92      1.00      0.96        13

    accuracy                           0.98        45
   macro avg       0.97      0.98      0.97        45
weighted avg       0.98      0.98      0.98        45

Visualizing the Tree
One of the best things about decision trees is that you can actually see the decision-making process:
# Create a figure with a larger size for better readability
plt.figure(figsize=(20, 10))
# Plot the decision tree
plot_tree(
clf,
feature_names=feature_names,
class_names=class_names,
filled=True, # color nodes by class
rounded=True, # rounded node boxes
fontsize=12,
proportion=True # show proportion of samples
)
plt.title("Decision Tree for Iris Classification", fontsize=16)
plt.tight_layout()
plt.savefig('decision_tree_iris.png', dpi=150, bbox_inches='tight')
plt.show()
# You can also export to text format
from sklearn.tree import export_text
tree_rules = export_text(clf, feature_names=feature_names)
print("Decision Tree Rules:")
print(tree_rules)

Output:
Decision Tree Rules:
|--- petal width (cm) <= 0.80
|   |--- class: setosa
|--- petal width (cm) > 0.80
|   |--- petal width (cm) <= 1.75
|   |   |--- petal length (cm) <= 4.95
|   |   |   |--- class: versicolor
|   |   |--- petal length (cm) > 4.95
|   |   |   |--- class: virginica
|   |--- petal width (cm) > 1.75
|   |   |--- class: virginica

This visualization shows exactly how the tree makes decisions, something that's nearly impossible with black-box models.
Advantages and Disadvantages
Advantages
- Interpretability: Easy to understand and explain, even to non-technical stakeholders
- No feature scaling required: Trees don't care about the scale of your features
- Handles both numerical and categorical data: Very flexible input types
- Non-linear relationships: Can capture complex decision boundaries
- Feature importance: Built-in measure of which features matter most (see the snippet after this list)
- Fast prediction: Predicting means walking a single root-to-leaf path, roughly O(log n) in the number of training samples for a balanced tree
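For instance, continuing with the clf fitted on Iris earlier, the learned importances are available directly on the estimator (they sum to 1 and reflect how much each feature reduced impurity across the tree):

# Assumes `clf` and `feature_names` from the Iris example above
for name, importance in zip(feature_names, clf.feature_importances_):
    print(f"{name}: {importance:.3f}")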
Disadvantages
- Overfitting: Prone to creating overly complex trees
- Instability: Small changes in data can lead to very different trees
- Biased toward features with more levels: Features with many categories can dominate
- Greedy algorithm: Locally optimal splits may not be globally optimal
- Poor extrapolation: Can't predict values outside the training range (for regression)
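The last point is easy to demonstrate with a regression tree: outside the training range, predictions simply repeat the value of the nearest leaf. A small illustration on toy data (our own example, not from the Iris walkthrough above):

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Train on x in [0, 9.5] with target y = 2x
X_train = np.arange(0, 10, 0.5).reshape(-1, 1)
y_train = 2 * X_train.ravel()

reg = DecisionTreeRegressor(max_depth=3).fit(X_train, y_train)

# Inside the training range the fit is reasonable; beyond it, the prediction goes flat
print(reg.predict([[5.0], [100.0]]))  # the value at 100.0 stays near the largest training targets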
When to Use Decision Trees
Decision trees shine in these scenarios:
- Explainability is crucial: Healthcare, finance, legal applications where you need to justify decisions
- Quick baseline model: Get a working model fast before trying more complex approaches
- Feature selection: Use feature importance to identify relevant variables
- Mixed data types: When you have both categorical and numerical features
- Non-linear relationships: When linear models aren't capturing the patterns
Consider alternatives when:
- High accuracy is the priority: Random Forests or Gradient Boosting often perform better
- Data is very high-dimensional: Trees can struggle with hundreds of features
- Smooth decision boundaries needed: Trees create rectangular decision regions
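On that last point: every split thresholds a single feature, so a tree partitions the feature space into axis-aligned rectangles. You can see this by plotting the boundary on two Iris features; the sketch below assumes scikit-learn 1.1+ for DecisionBoundaryDisplay:

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X2 = iris.data[:, 2:4]  # petal length and petal width only
tree2d = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X2, iris.target)

disp = DecisionBoundaryDisplay.from_estimator(
    tree2d, X2, response_method="predict",
    xlabel=iris.feature_names[2], ylabel=iris.feature_names[3], alpha=0.4,
)
disp.ax_.scatter(X2[:, 0], X2[:, 1], c=iris.target, edgecolor="k")
plt.title("Axis-aligned decision regions of a depth-3 tree")
plt.show()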
Conclusion
Decision trees are a fundamental building block in machine learning, combining simplicity with power. Their intuitive nature makes them perfect for learning and for situations where you need to explain your model's reasoning.
While a single decision tree might not win Kaggle competitions, understanding how they work is essential—especially since ensemble methods like Random Forests and Gradient Boosting are simply collections of decision trees working together.
Start with decision trees when tackling a new problem. They'll give you insights into your data, serve as a strong baseline, and help you understand what more complex models are doing under the hood.
In future posts, we'll explore Random Forests and how combining multiple trees creates models that are both accurate and robust. Stay tuned!