Data Drift Detection (Intermediate)
Data drift is the #1 cause of silent ML failures in production. When the distribution of incoming data changes from what your model was trained on, predictions degrade — often without any error signals. This lesson covers the statistical tests, production code, and alerting strategies you need to catch drift before it impacts users.
Types of Data Drift
| Drift Type | What Changes | Example | Detection Method |
|---|---|---|---|
| Feature Drift (Covariate Shift) | Input feature distributions P(X) | Average user age shifts from 35 to 55 | PSI, KS test per feature |
| Label Drift (Prior Probability Shift) | Target distribution P(Y) | Fraud rate jumps from 1% to 5% | Prediction distribution monitoring |
| Concept Drift | Relationship P(Y|X) | Same features, different correct label (e.g., COVID changed buying patterns) | Performance metrics with ground truth |
| Upstream Data Change | Schema, format, or semantics | Partner API renames a field, returns nulls | Schema validation, null rate monitoring |
Statistical Tests for Drift Detection
Population Stability Index (PSI)
PSI is the most widely used drift metric in industry. It compares two distributions by binning values and measuring divergence. Works for both numerical and categorical features.
# Production PSI Calculator
import numpy as np
from typing import List, Tuple, Dict
class PSICalculator:
    """Population Stability Index for drift detection.

    Compares a reference (training) distribution against a current
    (production) distribution by binning on reference percentiles and
    summing the per-bin divergence. Used by banks, insurance, and ML
    teams worldwide.

    PSI interpretation:
        PSI < 0.1 : No significant drift
        0.1-0.2   : Moderate drift (investigate)
        PSI > 0.2 : Significant drift (action required)
    """

    def __init__(self, n_bins: int = 10, epsilon: float = 1e-4):
        self.n_bins = n_bins
        self.epsilon = epsilon  # Avoid log(0) when a bin is empty in one sample

    def calculate_psi(self, reference: np.ndarray,
                      current: np.ndarray) -> dict:
        """Calculate PSI between reference (training) and current
        (production) distributions.

        Returns a dict with total PSI, a status string, per-bin
        contributions, and both sample sizes (JSON-serializable).
        """
        # Bin edges from reference percentiles, so each bin holds roughly
        # equal reference mass.
        breakpoints = np.percentile(
            reference,
            np.linspace(0, 100, self.n_bins + 1)
        )
        # Skewed or low-cardinality data yields duplicate percentile values,
        # and np.histogram requires strictly increasing edges — deduplicate
        # (this produces fewer, wider bins for such data).
        breakpoints = np.unique(breakpoints)
        if breakpoints.size < 2:
            # Degenerate case (constant reference): one all-covering bin.
            breakpoints = np.array([-np.inf, np.inf])
        # Open the outer edges so out-of-range production values are still
        # counted rather than silently dropped.
        breakpoints[0] = -np.inf
        breakpoints[-1] = np.inf

        # Calculate proportions in each bin
        ref_counts = np.histogram(reference, bins=breakpoints)[0]
        cur_counts = np.histogram(current, bins=breakpoints)[0]
        ref_pct = ref_counts / len(reference) + self.epsilon
        cur_pct = cur_counts / len(current) + self.epsilon

        # PSI formula: sum((cur - ref) * ln(cur / ref))
        psi_values = (cur_pct - ref_pct) * np.log(cur_pct / ref_pct)
        psi_total = float(np.sum(psi_values))

        return {
            "psi": round(psi_total, 4),
            "status": self._interpret(psi_total),
            # Cast numpy scalars to plain floats so json.dumps works.
            "bin_contributions": [round(float(v), 4) for v in psi_values],
            "reference_size": len(reference),
            "current_size": len(current)
        }

    def _interpret(self, psi: float) -> str:
        """Map a PSI value onto the standard drift status buckets."""
        if psi < 0.1:
            return "no_drift"
        elif psi < 0.2:
            return "moderate_drift"
        else:
            return "significant_drift"
# Example usage: compare a training baseline against shifted production data
np.random.seed(42)
baseline_sample = np.random.normal(50, 10, 10000)    # Training distribution
production_sample = np.random.normal(53, 12, 5000)   # Production (shifted)
calculator = PSICalculator()
report = calculator.calculate_psi(baseline_sample, production_sample)
print(f"PSI: {report['psi']} - Status: {report['status']}")
# Output: PSI: 0.0891 - Status: no_drift (borderline)
Kolmogorov-Smirnov Test
KS test measures the maximum difference between two cumulative distribution functions. It's distribution-free and works well for continuous features.
# KS Test for continuous feature drift
from scipy import stats
import numpy as np
class KSDriftDetector:
    """Kolmogorov-Smirnov test for feature drift.

    Best for: continuous numerical features.
    Returns p-value: low p-value = distributions are different.
    Typical threshold: p < 0.05 indicates drift.
    """

    def __init__(self, p_value_threshold: float = 0.05):
        self.p_value_threshold = p_value_threshold

    def test_drift(self, reference: np.ndarray,
                   current: np.ndarray) -> dict:
        """Run KS test between reference and current data."""
        ks_stat, p_val = stats.ks_2samp(reference, current)
        ref_mean = np.mean(reference)
        cur_mean = np.mean(current)
        return {
            "test": "kolmogorov_smirnov",
            "statistic": round(ks_stat, 4),
            "p_value": round(p_val, 6),
            "drift_detected": p_val < self.p_value_threshold,
            "reference_mean": round(ref_mean, 4),
            "current_mean": round(cur_mean, 4),
            "mean_shift": round(cur_mean - ref_mean, 4),
        }

    def test_all_features(self, reference_df, current_df,
                          feature_columns: list) -> dict:
        """Test drift across all features. Returns sorted by severity."""
        per_feature = {}
        for name in feature_columns:
            # Skip features not present in both frames.
            if name not in reference_df.columns or name not in current_df.columns:
                continue
            ref_vals = reference_df[name].dropna().values
            cur_vals = current_df[name].dropna().values
            if len(ref_vals) > 0 and len(cur_vals) > 0:
                per_feature[name] = self.test_drift(ref_vals, cur_vals)
        # Lowest p-value first: the most-drifted features lead the report.
        ordered = dict(sorted(per_feature.items(),
                              key=lambda item: item[1]["p_value"]))
        flagged = [name for name, res in ordered.items()
                   if res["drift_detected"]]
        return {
            "total_features": len(per_feature),
            "drifted_features": len(flagged),
            "drifted_names": flagged,
            "details": ordered,
        }
Chi-Squared Test for Categorical Features
# Chi-squared test for categorical feature drift
from scipy.stats import chi2_contingency
from collections import Counter
import numpy as np
class CategoricalDriftDetector:
    """Chi-squared test for categorical feature drift.

    Best for: categorical/string features.
    """

    def __init__(self, p_value_threshold: float = 0.05):
        self.p_value_threshold = p_value_threshold

    def test_drift(self, reference: list, current: list) -> dict:
        """Run chi-squared test between two categorical distributions."""
        # Union of categories seen in either window, in a stable order.
        categories = sorted(set(reference) | set(current))
        ref_freq = Counter(reference)
        cur_freq = Counter(current)

        # 2 x K contingency table: row 0 = reference, row 1 = current.
        observed = np.array([
            [ref_freq.get(cat, 0) for cat in categories],
            [cur_freq.get(cat, 0) for cat in categories],
        ])

        # Drop all-zero columns (categories absent from both samples).
        keep = observed.sum(axis=0) > 0
        observed = observed[:, keep]
        kept_categories = [cat for cat, flag in zip(categories, keep) if flag]

        if observed.shape[1] < 2:
            return {"test": "chi_squared", "status": "insufficient_categories"}

        chi2, p_value, dof, expected = chi2_contingency(observed)

        # Per-category proportion change, to show what shifted the most.
        category_shifts = {}
        for idx, cat in enumerate(kept_categories):
            ref_share = observed[0][idx] / observed[0].sum()
            cur_share = observed[1][idx] / observed[1].sum()
            category_shifts[cat] = {
                "reference_pct": round(ref_share, 4),
                "current_pct": round(cur_share, 4),
                "shift": round(cur_share - ref_share, 4),
            }

        return {
            "test": "chi_squared",
            "chi2_statistic": round(chi2, 4),
            "p_value": round(p_value, 6),
            "degrees_of_freedom": dof,
            "drift_detected": p_value < self.p_value_threshold,
            "category_shifts": category_shifts,
        }
# Example: detect drift in browser type feature
baseline_browsers = (["chrome"] * 500 + ["safari"] * 300
                     + ["firefox"] * 150 + ["edge"] * 50)
live_browsers = (["chrome"] * 400 + ["safari"] * 200 + ["firefox"] * 100
                 + ["edge"] * 250 + ["brave"] * 50)
browser_detector = CategoricalDriftDetector()
browser_result = browser_detector.test_drift(baseline_browsers, live_browsers)
print(f"Drift detected: {browser_result['drift_detected']}")
print(f"p-value: {browser_result['p_value']}")
Production Drift Monitoring Pipeline
# Complete drift monitoring pipeline for production
import json
import time
from datetime import datetime, timezone
from typing import Dict, List, Optional
class DriftMonitoringPipeline:
    """End-to-end drift monitoring for a deployed model.

    Combines PSI + KS tests for numerical features and chi-squared for
    categorical features, then rolls per-feature results up into a single
    ok/warning/critical status.

    Usage:
        pipeline = DriftMonitoringPipeline(
            model_name="fraud-detector",
            reference_path="s3://models/fraud/v2/reference_data.parquet"
        )
        pipeline.add_numerical_feature("transaction_amount", threshold=0.15)
        pipeline.add_numerical_feature("user_age", threshold=0.2)
        pipeline.add_categorical_feature("merchant_category")
        # Run on a schedule (e.g., every hour via Airflow/cron)
        results = pipeline.run(current_data_df, reference_data_df)
        pipeline.send_alerts(results)
    """

    def __init__(self, model_name: str, reference_path: str):
        self.model_name = model_name
        # NOTE(review): reference_path is stored but never loaded here;
        # callers currently pass reference data to run() directly — confirm
        # whether loading from this path is intended.
        self.reference_path = reference_path
        self.numerical_features: Dict[str, float] = {}  # name -> PSI threshold
        self.categorical_features: List[str] = []
        self.psi_calc = PSICalculator()
        self.ks_detector = KSDriftDetector()
        self.cat_detector = CategoricalDriftDetector()

    def add_numerical_feature(self, name: str, threshold: float = 0.2):
        """Register a numerical feature with a per-feature PSI threshold."""
        self.numerical_features[name] = threshold

    def add_categorical_feature(self, name: str):
        """Register a categorical feature (tested with chi-squared)."""
        self.categorical_features.append(name)

    def run(self, current_data, reference_data) -> dict:
        """Run all drift checks. Returns structured results.

        Args:
            current_data: DataFrame of recent production data.
            reference_data: DataFrame of baseline (training) data.

        Returns:
            Dict with per-feature details, the list of drifted features,
            and an overall ok/warning/critical status.
        """
        # Timezone-aware UTC timestamp (datetime.utcnow() is deprecated).
        timestamp = datetime.now(timezone.utc).isoformat()
        results = {
            "model": self.model_name,
            "timestamp": timestamp,
            "total_features_checked": 0,
            "drifted_features": [],
            "overall_status": "ok",
            "details": {}
        }
        # Check numerical features with PSI + KS
        for feature, psi_threshold in self.numerical_features.items():
            # Missing in either frame is itself a drift signal (often an
            # upstream schema change) — flag it instead of raising KeyError.
            if (feature not in current_data.columns
                    or feature not in reference_data.columns):
                results["details"][feature] = {"status": "missing_feature"}
                results["drifted_features"].append(feature)
                continue
            ref = reference_data[feature].dropna().values
            cur = current_data[feature].dropna().values
            psi_result = self.psi_calc.calculate_psi(ref, cur)
            ks_result = self.ks_detector.test_drift(ref, cur)
            # Either test tripping counts as drift for this feature.
            drift_detected = (psi_result["psi"] > psi_threshold or
                              ks_result["drift_detected"])
            results["details"][feature] = {
                "type": "numerical",
                "psi": psi_result,
                "ks_test": ks_result,
                "drift_detected": drift_detected
            }
            if drift_detected:
                results["drifted_features"].append(feature)
            results["total_features_checked"] += 1
        # Check categorical features with chi-squared
        for feature in self.categorical_features:
            if (feature not in current_data.columns
                    or feature not in reference_data.columns):
                results["details"][feature] = {"status": "missing_feature"}
                results["drifted_features"].append(feature)
                continue
            ref = reference_data[feature].dropna().tolist()
            cur = current_data[feature].dropna().tolist()
            chi2_result = self.cat_detector.test_drift(ref, cur)
            results["details"][feature] = {
                "type": "categorical",
                "chi_squared": chi2_result,
                # test_drift may return a status-only dict when there are
                # too few categories; default to "no drift" in that case.
                "drift_detected": chi2_result.get("drift_detected", False)
            }
            if chi2_result.get("drift_detected", False):
                results["drifted_features"].append(feature)
            results["total_features_checked"] += 1
        # Roll-up: any drifted feature -> warning; more than two -> critical.
        n_drifted = len(results["drifted_features"])
        if n_drifted == 0:
            results["overall_status"] = "ok"
        elif n_drifted <= 2:
            results["overall_status"] = "warning"
        else:
            results["overall_status"] = "critical"
        return results

    def send_alerts(self, results: dict):
        """Send alerts based on drift results.

        In production, integrate with Slack, PagerDuty, etc. No-op when
        the overall status is "ok"."""
        if results["overall_status"] == "ok":
            return  # No alert needed
        alert = {
            "title": f"Data Drift Alert: {self.model_name}",
            "severity": results["overall_status"],
            "drifted_features": results["drifted_features"],
            "timestamp": results["timestamp"],
            "action": "Review drift dashboard and check upstream data"
        }
        # In production: send to Slack webhook, PagerDuty, etc.
        print(f"ALERT [{alert['severity'].upper()}]: {json.dumps(alert, indent=2)}")
Alerting Thresholds: Practical Guidelines
| Metric | Warning | Critical | Notes |
|---|---|---|---|
| PSI (per feature) | > 0.1 | > 0.25 | Industry standard from credit risk modeling |
| KS statistic | > 0.1 | > 0.2 | Or use p-value < 0.01 |
| Null rate change | 3x baseline | 10x baseline | Often catches upstream API changes |
| # features drifted | > 10% of features | > 25% of features | Multiple features drifting = systemic issue |
| New category rate | > 1% of values | > 5% of values | Unknown categories in categorical features |
Pro Tip: Windowed Drift Detection
Don't compare against just the training data. Maintain a rolling reference window (e.g., the last 7 days) alongside the training baseline. This helps distinguish seasonal patterns from true drift. If both the training baseline AND the rolling window show drift, it's likely a real issue.
Common Mistake: Too Many Alerts
If you monitor 100 features with p-value < 0.05, you'll get ~5 false positives per check by random chance. Use Bonferroni correction (divide the threshold by the number of features) or focus alerts on the top-N most important features rather than alerting on every feature independently.
Lilly Tech Systems