Stratified K-Fold
This lesson explains Stratified K-Fold and how it maintains a balanced class distribution when validating machine learning models.
Why Stratification is Needed
In imbalanced datasets, some classes have far fewer samples than others.
Normal K-Fold splits the data without looking at the labels, so some folds may contain very few minority-class examples, or none at all.
Stratified K-Fold splits each class separately, so every fold has approximately the same class proportions as the original dataset.
Example:
Dataset: 90 “Not Spam”, 10 “Spam” emails
Normal 5-Fold → Some folds may have 0 “Spam” emails → poor model evaluation
Stratified 5-Fold → Each fold maintains ~10% “Spam” → fair evaluation
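The spam example above can be checked with a short sketch. The label array is synthetic (90 zeros for "Not Spam" followed by 10 ones for "Spam", deliberately left sorted so the effect is easy to see), and the helper name `spam_per_fold` is an assumption made for illustration, not part of scikit-learn.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Synthetic labels: 90 "Not Spam" (0) followed by 10 "Spam" (1).
y = np.array([0] * 90 + [1] * 10)
X = np.zeros((len(y), 1))  # placeholder features; only the labels matter here


def spam_per_fold(splitter):
    """Hypothetical helper: count "Spam" labels in each test fold."""
    return [int(y[test_idx].sum()) for _, test_idx in splitter.split(X, y)]


kfold_counts = spam_per_fold(KFold(n_splits=5))
stratified_counts = spam_per_fold(StratifiedKFold(n_splits=5))

print("KFold spam per fold:          ", kfold_counts)
print("StratifiedKFold spam per fold:", stratified_counts)
```

Without shuffling, normal K-Fold puts all 10 "Spam" emails into the last fold, while the stratified split places 2 in each fold of 20 emails (10%).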
Handling Imbalanced Datasets
Maintains class distribution in every fold
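Reduces fold-to-fold variance in evaluation metrics, because every fold is scored on a comparable class mix
Essential for metrics like Precision, Recall, F1-score, which become unstable or undefined when a fold contains few or no minority-class examples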
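As a rough illustration of the points above, the sketch below scores a classifier on Precision, Recall, and F1 across stratified folds. The dataset is synthetic and the sizes, class weights, and seeds are assumptions chosen only for the example.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_validate

# Assumed synthetic data: roughly 90% class 0 and 10% class 1.
X, y = make_classification(
    n_samples=500, n_features=10, weights=[0.9, 0.1], random_state=0
)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

# cross_validate accepts several scorers and returns one array per metric.
results = cross_validate(
    LogisticRegression(max_iter=1000),
    X,
    y,
    cv=skf,
    scoring=["precision", "recall", "f1"],
)

for name in ("precision", "recall", "f1"):
    fold_scores = results[f"test_{name}"]
    print(f"{name}: mean={fold_scores.mean():.3f}, std={fold_scores.std():.3f}")
```

The standard deviation printed for each metric gives a quick sense of how much the score moves from fold to fold.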
Difference from Normal K-Fold
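One practical difference shows up in how scikit-learn chooses a splitter when `cv` is given as a plain integer. The sketch below uses `check_cv`, the utility that resolves a `cv` argument into a splitter object; using Iris here is an assumption made only to have classification labels at hand.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import check_cv

X, y = load_iris(return_X_y=True)

# With an integer cv, the splitter depends on whether the task is classification.
cv_for_classifier = check_cv(5, y, classifier=True)
cv_for_regressor = check_cv(5, y, classifier=False)

print(type(cv_for_classifier).__name__)  # StratifiedKFold
print(type(cv_for_regressor).__name__)   # KFold
```

In other words, normal K-Fold ignores the labels when splitting, while Stratified K-Fold needs them, and scikit-learn defaults to the stratified variant for classifiers with binary or multiclass targets.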
Python Example: Stratified K-Fold
Stratified K-Fold Cross Validation Example in Python using Logistic Regression
This Python example evaluates a classification model with Stratified K-Fold Cross Validation. The code loads the Iris dataset, creates a Logistic Regression model, and applies a shuffled 5-fold Stratified K-Fold so that each fold keeps the same class distribution as the full dataset. Iris is balanced (50 samples per class), so each test fold of 30 samples contains 10 of each class. The code reports the accuracy for each fold and the average accuracy across folds.
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
import numpy as np
# Step 1: Load dataset
iris = load_iris()
X, y = iris.data, iris.target
# Step 2: Create Model
model = LogisticRegression(max_iter=200)
# Step 3: Stratified K-Fold Setup
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
# Step 4: Perform Cross Validation
scores = cross_val_score(model, X, y, cv=skf, scoring='accuracy')
# Step 5: Results
print("Accuracy for each fold:", scores)
print("Average Accuracy:", np.mean(scores))
Output:
Accuracy for each fold: [1. 0.96666667 0.93333333 1. 0.93333333]
Average Accuracy: 0.9666666666666668
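To see the stratification itself rather than only the scores, the folds can be inspected directly. This is a minimal sketch that reuses the same splitter settings as above; the variable name `fold_class_counts` is introduced here for illustration.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold

X, y = load_iris(return_X_y=True)
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

fold_class_counts = []
for fold, (train_idx, test_idx) in enumerate(skf.split(X, y), start=1):
    # Count how many samples of each of the 3 classes land in this test fold.
    counts = np.bincount(y[test_idx], minlength=3)
    fold_class_counts.append(counts.tolist())
    print(f"Fold {fold}: test class counts = {counts.tolist()}")
```

Every test fold holds 10 samples of each of the three Iris classes, matching the 1:1:1 ratio of the full dataset.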
Key Points:
Each fold has approximately the same class distribution as the full dataset
Gives more reliable model evaluation on imbalanced datasets than Normal K-Fold