Stratified K-Fold

  • This lesson explains Stratified K-Fold cross-validation and how it preserves the dataset's class distribution in every fold when validating machine learning models.

    Why Stratification is Needed

    • In imbalanced datasets, some classes are much smaller than others.

    • Normal K-Fold ignores class labels when splitting, so some folds may contain few or no minority-class examples.

    • Stratified K-Fold ensures each fold has approximately the same proportion of classes as the original dataset.

    Example:

    • Dataset: 90 “Not Spam”, 10 “Spam” emails

    • Normal 5-Fold → Some folds may have 0 “Spam” emails → poor model evaluation

    • Stratified 5-Fold → Each fold of 20 emails contains 2 “Spam” (10%) → fair evaluation
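
The spam example above can be reproduced with a small sketch. The labels below are hypothetical and deliberately sorted (90 "Not Spam" followed by 10 "Spam"), which is the worst case for an unshuffled Normal K-Fold; the features are placeholders because only the labels matter here.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Hypothetical labels matching the example: 90 "Not Spam" (0), then 10 "Spam" (1)
y = np.array([0] * 90 + [1] * 10)
X = np.zeros((100, 1))  # placeholder features; only the labels matter here

def spam_per_fold(splitter):
    """Count the "Spam" emails that land in each test fold."""
    return [int(y[test_idx].sum()) for _, test_idx in splitter.split(X, y)]

kfold_counts = spam_per_fold(KFold(n_splits=5))
strat_counts = spam_per_fold(StratifiedKFold(n_splits=5))

print("Normal K-Fold, Spam per fold:    ", kfold_counts)   # [0, 0, 0, 0, 10]
print("Stratified K-Fold, Spam per fold:", strat_counts)   # [2, 2, 2, 2, 2]
```

With `shuffle=True`, Normal K-Fold would spread the "Spam" emails more evenly on average, but the count per fold would still vary from run to run and could be zero.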


    Handling Imbalanced Datasets

    • Maintains class distribution in every fold

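    • Reduces fold-to-fold variance in model evaluation metrics

    • Essential for metrics like Precision, Recall, F1-score, which become unstable or undefined when a fold contains few or no minority-class examples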
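
As a rough illustration of the points above, the sketch below scores a model with Precision, Recall, and F1 on stratified folds. The dataset is synthetic (an assumption for illustration, roughly 90% negatives and 10% positives), and the model and parameter choices are arbitrary.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_validate

# Synthetic, illustrative data: roughly 90% class 0 and 10% class 1
X, y = make_classification(
    n_samples=500, n_features=10, weights=[0.9, 0.1], random_state=0
)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# cross_validate accepts several metrics at once and returns one array per metric
results = cross_validate(
    LogisticRegression(max_iter=1000), X, y, cv=skf,
    scoring=["precision", "recall", "f1"],
)

for name in ("precision", "recall", "f1"):
    fold_scores = results[f"test_{name}"]
    print(f"{name}: mean={fold_scores.mean():.3f}, std={fold_scores.std():.3f}")
```

The standard deviation printed for each metric gives a quick sense of how much the score changes from fold to fold.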


    Difference from Normal K-Fold

    Aspect             | Normal K-Fold                               | Stratified K-Fold
    -------------------+---------------------------------------------+-----------------------------------------
    Data Split         | Into K folds without regard to class labels | Split while preserving class proportions
    Use Case           | Any dataset, mainly regression              | Imbalanced classification datasets
    Risk               | Minority classes may be missing in folds    | Low: class distribution is maintained
    Metric Reliability | Can vary widely                             | More stable and reliable metrics
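
The use-case split in the comparison above is also how scikit-learn chooses a default splitter. The sketch below uses `check_cv` to show which splitter an integer `cv` value resolves to; the two target arrays are hypothetical and exist only to trigger each case.

```python
import numpy as np
from sklearn.model_selection import check_cv

y_class = np.array([0] * 90 + [1] * 10)  # hypothetical class labels
y_reg = np.linspace(0.0, 1.0, 100)       # hypothetical continuous target

# An integer cv is turned into a concrete splitter object
cv_for_classifier = check_cv(5, y_class, classifier=True)
cv_for_regressor = check_cv(5, y_reg, classifier=False)

print(type(cv_for_classifier).__name__)  # StratifiedKFold
print(type(cv_for_regressor).__name__)   # KFold
```

Passing an explicit splitter object instead of an integer is still worthwhile when shuffling or a fixed random seed is needed.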

    Python Example: Stratified K-Fold

Stratified K-Fold Cross Validation Example in Python using Logistic Regression

This Python example demonstrates how to evaluate a classification model using Stratified K-Fold Cross Validation. The code loads the Iris dataset, creates a Logistic Regression model, and applies Stratified K-Fold so that each fold keeps the same class distribution as the original dataset. Iris is balanced (50 samples in each of 3 classes), so each of the 5 test folds contains 10 samples per class. The code then reports the accuracy for each fold and the average accuracy to assess the model’s performance.

from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
import numpy as np

# Step 1: Load dataset
iris = load_iris()
X, y = iris.data, iris.target

# Step 2: Create Model
model = LogisticRegression(max_iter=200)

# Step 3: Stratified K-Fold Setup
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Step 4: Perform Cross Validation
scores = cross_val_score(model, X, y, cv=skf, scoring='accuracy')

# Step 5: Results
print("Accuracy for each fold:", scores)
print("Average Accuracy:", np.mean(scores))
    Output:

    Accuracy for each fold: [1.         0.96666667 0.93333333 1.         0.93333333]

    Average Accuracy: 0.9666666666666668
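
To check that the folds in the example above really are stratified, the class counts in each test fold can be printed. This is a small verification sketch that reuses the same splitter settings; it is not part of the model evaluation itself.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold

X, y = load_iris(return_X_y=True)
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Count how many samples of each of the 3 classes fall in every test fold
fold_counts = [
    np.bincount(y[test_idx], minlength=3).tolist()
    for _, test_idx in skf.split(X, y)
]

for i, counts in enumerate(fold_counts, start=1):
    print(f"Fold {i}: {counts}")  # each fold prints [10, 10, 10]
```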


    Key Points:

    • Each fold has approximately the same class distribution as the full dataset

    • Gives more reliable model evaluation on imbalanced datasets