Data Drift Detection (Intermediate)

Data drift is the #1 cause of silent ML failures in production. When the distribution of incoming data changes from what your model was trained on, predictions degrade — often without any error signals. This lesson covers the statistical tests, production code, and alerting strategies you need to catch drift before it impacts users.

Types of Data Drift

| Drift Type | What Changes | Example | Detection Method |
|---|---|---|---|
| Feature Drift (Covariate Shift) | Input feature distributions P(X) | Average user age shifts from 35 to 55 | PSI, KS test per feature |
| Label Drift (Prior Probability Shift) | Target distribution P(Y) | Fraud rate jumps from 1% to 5% | Prediction distribution monitoring |
| Concept Drift | Relationship P(Y\|X) | Same features, different correct label (e.g., COVID changed buying patterns) | Performance metrics with ground truth |
| Upstream Data Change | Schema, format, or semantics | Partner API renames a field, returns nulls | Schema validation, null rate monitoring |

Statistical Tests for Drift Detection

Population Stability Index (PSI)

PSI is the most widely used drift metric in industry. It compares two distributions by binning values and measuring divergence. Works for both numerical and categorical features.

# Production PSI Calculator
import numpy as np
from typing import List, Tuple, Dict

class PSICalculator:
    """Population Stability Index for drift detection.
    Used by banks, insurance, and ML teams worldwide.

    PSI interpretation:
      PSI < 0.1  : No significant drift
      0.1-0.2    : Moderate drift (investigate)
      PSI > 0.2  : Significant drift (action required)
    """

    def __init__(self, n_bins: int = 10, epsilon: float = 1e-4):
        """
        Args:
            n_bins: Number of quantile bins built from the reference data.
            epsilon: Small constant added to bin proportions to avoid log(0).
        """
        self.n_bins = n_bins
        self.epsilon = epsilon  # Avoid log(0)

    def calculate_psi(self, reference: np.ndarray,
                      current: np.ndarray) -> dict:
        """Calculate PSI between reference (training) and current
        (production) distributions.

        Raises:
            ValueError: if either sample is empty, or the reference has no
                variation (a single repeated value), so bins cannot be built.
        """
        if len(reference) == 0 or len(current) == 0:
            raise ValueError("reference and current must be non-empty")

        # Quantile breakpoints from the reference distribution.
        # np.unique removes duplicate edges that appear when the reference
        # has many tied values; duplicates would create zero-width bins.
        breakpoints = np.unique(np.percentile(
            reference,
            np.linspace(0, 100, self.n_bins + 1)
        ))
        if breakpoints.size < 2:
            raise ValueError("reference has no variation; PSI is undefined")

        # Open the outer bins so out-of-range production values still count.
        breakpoints[0] = -np.inf
        breakpoints[-1] = np.inf

        # Calculate proportions in each bin
        ref_counts = np.histogram(reference, bins=breakpoints)[0]
        cur_counts = np.histogram(current, bins=breakpoints)[0]

        ref_pct = ref_counts / len(reference) + self.epsilon
        cur_pct = cur_counts / len(current) + self.epsilon

        # PSI formula: sum((cur - ref) * ln(cur / ref))
        psi_values = (cur_pct - ref_pct) * np.log(cur_pct / ref_pct)
        # Native float so the result dict is JSON-serializable downstream.
        psi_total = float(np.sum(psi_values))

        return {
            "psi": round(psi_total, 4),
            "status": self._interpret(psi_total),
            "bin_contributions": [round(float(v), 4) for v in psi_values],
            "reference_size": len(reference),
            "current_size": len(current)
        }

    def _interpret(self, psi: float) -> str:
        """Map a PSI value to the standard industry severity label."""
        if psi < 0.1:
            return "no_drift"
        elif psi < 0.2:
            return "moderate_drift"
        else:
            return "significant_drift"


# Example usage
np.random.seed(42)
reference_sample = np.random.normal(50, 10, 10000)   # Training distribution
production_sample = np.random.normal(53, 12, 5000)   # Production (shifted)

calculator = PSICalculator()
report = calculator.calculate_psi(reference_sample, production_sample)
print(f"PSI: {report['psi']} - Status: {report['status']}")
# Output: PSI: 0.0891 - Status: no_drift (borderline)

Kolmogorov-Smirnov Test

KS test measures the maximum difference between two cumulative distribution functions. It's distribution-free and works well for continuous features.

# KS Test for continuous feature drift
from scipy import stats
import numpy as np

class KSDriftDetector:
    """Two-sample Kolmogorov-Smirnov drift test.
    Best for: continuous numerical features.

    Returns p-value: low p-value = distributions are different.
    Typical threshold: p < 0.05 indicates drift.
    """

    def __init__(self, p_value_threshold: float = 0.05):
        # p-values below this threshold flag drift
        self.p_value_threshold = p_value_threshold

    def test_drift(self, reference: np.ndarray,
                   current: np.ndarray) -> dict:
        """Compare reference vs. current samples with a two-sample KS test."""
        statistic, p_value = stats.ks_2samp(reference, current)
        ref_mean = np.mean(reference)
        cur_mean = np.mean(current)

        return {
            "test": "kolmogorov_smirnov",
            "statistic": round(statistic, 4),
            "p_value": round(p_value, 6),
            "drift_detected": p_value < self.p_value_threshold,
            "reference_mean": round(ref_mean, 4),
            "current_mean": round(cur_mean, 4),
            "mean_shift": round(cur_mean - ref_mean, 4),
        }

    def test_all_features(self, reference_df, current_df,
                           feature_columns: list) -> dict:
        """Run the KS test on every shared feature; most-drifted first."""
        results = {}
        for name in feature_columns:
            # Skip columns absent from either frame
            if name not in reference_df.columns or name not in current_df.columns:
                continue
            ref_vals = reference_df[name].dropna().values
            cur_vals = current_df[name].dropna().values
            # Skip columns that are entirely null on either side
            if len(ref_vals) == 0 or len(cur_vals) == 0:
                continue
            results[name] = self.test_drift(ref_vals, cur_vals)

        # Lowest p-value first = most drifted first
        ordered = sorted(results.items(), key=lambda item: item[1]["p_value"])
        sorted_results = dict(ordered)

        drifted = [name for name, res in ordered if res["drift_detected"]]

        return {
            "total_features": len(results),
            "drifted_features": len(drifted),
            "drifted_names": drifted,
            "details": sorted_results,
        }

Chi-Squared Test for Categorical Features

# Chi-squared test for categorical feature drift
from scipy.stats import chi2_contingency
from collections import Counter
import numpy as np

class CategoricalDriftDetector:
    """Chi-squared test for categorical feature drift.
    Best for: categorical/string features.
    """

    def __init__(self, p_value_threshold: float = 0.05):
        # p-values below this threshold flag drift
        self.p_value_threshold = p_value_threshold

    def test_drift(self, reference: list, current: list) -> dict:
        """Run chi-squared test between two categorical distributions.

        Returns:
            A full result dict (with "drift_detected") on success, or a
            {"test": ..., "status": ...} dict when the inputs cannot support
            the test (empty sample, or fewer than two observed categories).
        """
        # Guard: an empty sample would produce an all-zero contingency row,
        # which makes chi2_contingency raise ValueError.
        if not reference or not current:
            return {"test": "chi_squared", "status": "insufficient_data"}

        # Get all categories
        all_categories = sorted(set(reference) | set(current))

        ref_counts = Counter(reference)
        cur_counts = Counter(current)

        # Build contingency table: row 0 = reference, row 1 = current
        observed = np.array([
            [ref_counts.get(cat, 0) for cat in all_categories],
            [cur_counts.get(cat, 0) for cat in all_categories]
        ])

        # Remove columns with all zeros (can't contribute to the test)
        non_zero = observed.sum(axis=0) > 0
        observed = observed[:, non_zero]
        filtered_cats = [c for c, nz in zip(all_categories, non_zero) if nz]

        # Chi-squared needs at least a 2x2 table
        if observed.shape[1] < 2:
            return {"test": "chi_squared", "status": "insufficient_categories"}

        chi2, p_value, dof, expected = chi2_contingency(observed)

        # Find which categories shifted most (proportion deltas per category)
        category_shifts = {}
        for i, cat in enumerate(filtered_cats):
            ref_pct = observed[0][i] / observed[0].sum()
            cur_pct = observed[1][i] / observed[1].sum()
            category_shifts[cat] = {
                "reference_pct": round(float(ref_pct), 4),
                "current_pct": round(float(cur_pct), 4),
                "shift": round(float(cur_pct - ref_pct), 4)
            }

        # Native Python types so the result dict is JSON-serializable
        return {
            "test": "chi_squared",
            "chi2_statistic": round(float(chi2), 4),
            "p_value": round(float(p_value), 6),
            "degrees_of_freedom": int(dof),
            "drift_detected": p_value < self.p_value_threshold,
            "category_shifts": category_shifts
        }


# Example: detect drift in browser type feature
ref_browsers = ["chrome"] * 500 + ["safari"] * 300 + ["firefox"] * 150 + ["edge"] * 50
cur_browsers = (["chrome"] * 400 + ["safari"] * 200 + ["firefox"] * 100
                + ["edge"] * 250 + ["brave"] * 50)

browser_detector = CategoricalDriftDetector()
browser_report = browser_detector.test_drift(ref_browsers, cur_browsers)
print(f"Drift detected: {browser_report['drift_detected']}")
print(f"p-value: {browser_report['p_value']}")

Production Drift Monitoring Pipeline

# Complete drift monitoring pipeline for production
import json
import time
from datetime import datetime, timezone
from typing import Dict, List, Optional

class DriftMonitoringPipeline:
    """End-to-end drift monitoring for a deployed model.

    Usage:
        pipeline = DriftMonitoringPipeline(
            model_name="fraud-detector",
            reference_path="s3://models/fraud/v2/reference_data.parquet"
        )
        pipeline.add_numerical_feature("transaction_amount", threshold=0.15)
        pipeline.add_numerical_feature("user_age", threshold=0.2)
        pipeline.add_categorical_feature("merchant_category")

        # Run on a schedule (e.g., every hour via Airflow/cron).
        # The caller loads both DataFrames (e.g. the reference from
        # pipeline.reference_path) and passes them in.
        results = pipeline.run(current_data_df, reference_data_df)
        pipeline.send_alerts(results)
    """

    def __init__(self, model_name: str, reference_path: str):
        self.model_name = model_name
        # Where the reference (training) snapshot lives; informational only —
        # run() receives the already-loaded reference DataFrame.
        self.reference_path = reference_path
        # feature name -> PSI threshold that flags drift
        self.numerical_features: Dict[str, float] = {}
        self.categorical_features: List[str] = []
        self.psi_calc = PSICalculator()
        self.ks_detector = KSDriftDetector()
        self.cat_detector = CategoricalDriftDetector()

    def add_numerical_feature(self, name: str, threshold: float = 0.2):
        """Register a numerical feature; drift is flagged when PSI > threshold."""
        self.numerical_features[name] = threshold

    def add_categorical_feature(self, name: str):
        """Register a categorical feature (checked with chi-squared)."""
        self.categorical_features.append(name)

    def run(self, current_data, reference_data) -> dict:
        """Run all drift checks. Returns structured results.

        Args:
            current_data: DataFrame of recent production data.
            reference_data: DataFrame of the training/baseline data.
        """
        results = {
            "model": self.model_name,
            # Timezone-aware UTC; datetime.utcnow() is deprecated
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "total_features_checked": 0,
            "drifted_features": [],
            "overall_status": "ok",
            "details": {}
        }

        self._check_numerical(current_data, reference_data, results)
        self._check_categorical(current_data, reference_data, results)

        # Overall severity scales with how many features drifted
        n_drifted = len(results["drifted_features"])
        if n_drifted == 0:
            results["overall_status"] = "ok"
        elif n_drifted <= 2:
            results["overall_status"] = "warning"
        else:
            results["overall_status"] = "critical"

        return results

    def _check_numerical(self, current_data, reference_data, results: dict):
        """Run PSI + KS on each registered numerical feature (mutates results)."""
        for feature, psi_threshold in self.numerical_features.items():
            # A feature missing from either frame is itself a drift signal
            # (often an upstream schema change)
            if (feature not in current_data.columns
                    or feature not in reference_data.columns):
                results["details"][feature] = {"status": "missing_feature"}
                results["drifted_features"].append(feature)
                continue

            ref = reference_data[feature].dropna().values
            cur = current_data[feature].dropna().values
            if len(ref) == 0 or len(cur) == 0:
                # All-null column: can't compute PSI/KS, but worth flagging
                results["details"][feature] = {"status": "no_data"}
                results["drifted_features"].append(feature)
                continue

            psi_result = self.psi_calc.calculate_psi(ref, cur)
            ks_result = self.ks_detector.test_drift(ref, cur)

            # Either test alone is enough to flag the feature
            drift_detected = (psi_result["psi"] > psi_threshold or
                              ks_result["drift_detected"])

            results["details"][feature] = {
                "type": "numerical",
                "psi": psi_result,
                "ks_test": ks_result,
                "drift_detected": drift_detected
            }

            if drift_detected:
                results["drifted_features"].append(feature)

            results["total_features_checked"] += 1

    def _check_categorical(self, current_data, reference_data, results: dict):
        """Run chi-squared on each registered categorical feature (mutates results)."""
        for feature in self.categorical_features:
            if (feature not in current_data.columns
                    or feature not in reference_data.columns):
                results["details"][feature] = {"status": "missing_feature"}
                results["drifted_features"].append(feature)
                continue

            ref = reference_data[feature].dropna().tolist()
            cur = current_data[feature].dropna().tolist()

            chi2_result = self.cat_detector.test_drift(ref, cur)
            # test_drift may return a status-only dict (e.g. a single shared
            # category); .get keeps that case from raising
            detected = chi2_result.get("drift_detected", False)

            results["details"][feature] = {
                "type": "categorical",
                "chi_squared": chi2_result,
                "drift_detected": detected
            }

            if detected:
                results["drifted_features"].append(feature)

            results["total_features_checked"] += 1

    def send_alerts(self, results: dict):
        """Send alerts based on drift results.
        In production, integrate with Slack, PagerDuty, etc."""
        if results["overall_status"] == "ok":
            return  # No alert needed

        alert = {
            "title": f"Data Drift Alert: {self.model_name}",
            "severity": results["overall_status"],
            "drifted_features": results["drifted_features"],
            "timestamp": results["timestamp"],
            "action": "Review drift dashboard and check upstream data"
        }
        # In production: send to Slack webhook, PagerDuty, etc.
        print(f"ALERT [{alert['severity'].upper()}]: {json.dumps(alert, indent=2)}")

Alerting Thresholds: Practical Guidelines

| Metric | Warning | Critical | Notes |
|---|---|---|---|
| PSI (per feature) | > 0.1 | > 0.25 | Industry standard from credit risk modeling |
| KS statistic | > 0.1 | > 0.2 | Or use p-value < 0.01 |
| Null rate change | 3x baseline | 10x baseline | Often catches upstream API changes |
| # features drifted | > 10% of features | > 25% of features | Multiple features drifting = systemic issue |
| New category rate | > 1% of values | > 5% of values | Unknown categories in categorical features |
**Pro Tip: Windowed Drift Detection**
Don't compare against just the training data. Maintain a rolling reference window (e.g., last 7 days) alongside the training baseline. This helps distinguish seasonal patterns from true drift. If both the training baseline AND the rolling window show drift, it's likely a real issue.
**Common Mistake: Too Many Alerts**
If you monitor 100 features with p-value < 0.05, you'll get ~5 false positives per check by random chance. Use Bonferroni correction (divide threshold by number of features) or focus alerts on the top-N most important features rather than alerting on every feature independently.