Intermediate

Step 3: Data Drift Monitoring

Detect when your input data changes by comparing feature distributions with PSI and KS tests. Visualize drift with overlay histograms and trigger alerts when thresholds are crossed.

Why Drift Monitoring Matters

ML models are trained on historical data. When the real-world data distribution shifts, model accuracy degrades — even if the model code is unchanged. Common causes include:

  • Upstream data changes: A partner API changes their schema or units
  • Seasonal patterns: Holiday shopping behavior differs from normal
  • Population shift: New user demographics join the platform
  • Data pipeline bugs: A join starts returning nulls after a schema migration
💡
Key insight: Drift detection catches model degradation before ground truth labels arrive. You can detect drift in minutes, but labels (actual outcomes) may take days or weeks to materialize.

Statistical Tests for Drift

Create utils/stats.py with the two most widely used drift detection methods:

# utils/stats.py
import numpy as np
from scipy import stats


def calculate_psi(reference, current, n_bins=10):
    """
    Population Stability Index (PSI).
    Measures how much a distribution has shifted.

    PSI < 0.1  : No significant drift
    PSI 0.1-0.2: Moderate drift (investigate)
    PSI > 0.2  : Significant drift (action needed)

    Args:
        reference: Array-like of reference (training) values
        current: Array-like of current (production) values
        n_bins: Number of histogram bins

    Returns:
        float: PSI value, rounded to 6 decimal places
    """
    # Accept any array-like input (lists, pandas Series, numpy arrays)
    reference = np.asarray(reference, dtype=float)
    current = np.asarray(current, dtype=float)

    # Bin over the combined range of BOTH samples so every value lands in
    # some bin; np.histogram silently drops values outside the bin range,
    # which would skew the proportions.
    min_val = min(reference.min(), current.min())
    max_val = max(reference.max(), current.max())
    if min_val == max_val:
        # Degenerate case: all values identical. Widen the range so that
        # np.linspace yields strictly increasing bin edges (np.histogram
        # rejects non-monotonic bins).
        min_val -= 0.5
        max_val += 0.5
    bins = np.linspace(min_val, max_val, n_bins + 1)

    # Raw counts per bin for each sample
    ref_counts, _ = np.histogram(reference, bins=bins)
    cur_counts, _ = np.histogram(current, bins=bins)

    # Convert to proportions; the epsilon keeps empty bins from producing
    # log(0) or division by zero in the PSI formula
    eps = 1e-6
    ref_pct = ref_counts / len(reference) + eps
    cur_pct = cur_counts / len(current) + eps

    # PSI formula: sum of (current - reference) * ln(current / reference)
    psi = np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct))
    return round(float(psi), 6)


def calculate_ks_test(reference, current):
    """
    Kolmogorov-Smirnov two-sample test.
    Compares the empirical CDFs of two samples to decide whether they
    plausibly come from the same underlying distribution.

    Args:
        reference: Array of reference values
        current: Array of current values

    Returns:
        tuple: (ks_statistic, p_value), each rounded to 6 decimal places
        - ks_statistic: Maximum distance between the two CDFs (0 to 1)
        - p_value: Probability of observing a statistic at least this
          large if both samples came from the same distribution;
          p < 0.05 = statistically significant drift
    """
    result = stats.ks_2samp(reference, current)
    return round(result.statistic, 6), round(result.pvalue, 6)


def get_drift_severity(psi, ks_p_value):
    """Classify drift severity from PSI and KS results as (label, hex color)."""
    # Critical: strong PSI shift or very confident KS rejection
    if psi > 0.2 or ks_p_value < 0.01:
        return "critical", "#ef4444"
    # Warning: moderate PSI shift or conventional significance level
    if psi > 0.1 or ks_p_value < 0.05:
        return "warning", "#f59e0b"
    return "stable", "#22c55e"

Distribution Overlay Charts

Create views/drift_monitor.py with visualizations that compare reference and current distributions:

# views/drift_monitor.py
import streamlit as st
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import numpy as np
from utils.stats import calculate_psi, calculate_ks_test, get_drift_severity


def render_drift_overview(reference_df, current_df, feature_names):
    """
    Render a summary table showing drift status for every monitored feature.
    """
    st.subheader("Drift Overview")

    rows = []
    for name in feature_names:
        ref_values = reference_df[name].values
        cur_values = current_df[name].values

        psi_value = calculate_psi(ref_values, cur_values)
        ks_stat, ks_p = calculate_ks_test(ref_values, cur_values)
        severity, _ = get_drift_severity(psi_value, ks_p)

        rows.append({
            "Feature": name,
            "PSI": psi_value,
            "KS Statistic": ks_stat,
            "KS p-value": ks_p,
            "Status": severity.upper(),
        })

    summary = pd.DataFrame(rows)

    # Headline counts for the alert metrics
    status_counts = summary["Status"].value_counts()
    critical_count = int(status_counts.get("CRITICAL", 0))
    warning_count = int(status_counts.get("WARNING", 0))

    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Features Monitored", len(feature_names))
    with col2:
        st.metric("Critical Drift", critical_count,
                   delta=f"{critical_count} features", delta_color="inverse")
    with col3:
        st.metric("Warnings", warning_count,
                   delta=f"{warning_count} features", delta_color="inverse")

    def _row_style(row):
        # Tint the whole row by its drift status (8-digit hex = RGBA)
        if row["Status"] == "CRITICAL":
            css = "background-color: #dc262620"
        elif row["Status"] == "WARNING":
            css = "background-color: #f59e0b20"
        else:
            css = ""
        return [css] * len(row)

    st.dataframe(
        summary.style.apply(_row_style, axis=1),
        use_container_width=True,
        hide_index=True,
    )

Feature Distribution Comparison

Overlay histograms show exactly how a feature has shifted between training and production:

def render_distribution_comparison(reference_df, current_df, feature_name):
    """
    Render overlaid histograms comparing one feature's reference and
    current distributions, with its drift metrics shown inline.
    """
    st.subheader(f"Distribution: {feature_name}")

    reference_values = reference_df[feature_name].values
    current_values = current_df[feature_name].values

    # Drift metrics for this single feature
    psi_value = calculate_psi(reference_values, current_values)
    ks_stat, ks_p = calculate_ks_test(reference_values, current_values)
    severity, _ = get_drift_severity(psi_value, ks_p)

    # Inline metric row: PSI, KS statistic, and the severity label
    metric_values = (
        ("PSI", f"{psi_value:.4f}"),
        ("KS Statistic", f"{ks_stat:.4f}"),
        ("Status", severity.upper()),
    )
    for column, (label, value) in zip(st.columns(3), metric_values):
        with column:
            st.metric(label, value)

    # Semi-transparent overlaid histograms so both distributions stay visible
    fig = go.Figure()
    for values, trace_name, shade in (
        (reference_values, "Reference (Training)", "#6366f1"),
        (current_values, "Current (Production)", "#f59e0b"),
    ):
        fig.add_trace(go.Histogram(
            x=values,
            name=trace_name,
            opacity=0.6,
            marker_color=shade,
            nbinsx=40,
        ))

    fig.update_layout(
        barmode="overlay",
        title=f"{feature_name} Distribution Comparison",
        xaxis_title="Value",
        yaxis_title="Count",
        height=400,
        template="plotly_dark",
        margin=dict(l=20, r=20, t=50, b=20),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )

    st.plotly_chart(fig, use_container_width=True)

PSI Heatmap Across Features

A heatmap showing PSI values for all features over time helps identify systematic drift patterns:

def render_psi_heatmap(reference_df, current_df, feature_names):
    """
    Render a features-by-weeks heatmap of PSI values.
    Weekly values are simulated by progressively shifting the current data.
    """
    st.subheader("PSI Trend (Last 8 Weeks)")

    week_labels = [f"Week {i+1}" for i in range(8)]

    def _weekly_psi(ref_values, cur_values, week_idx):
        # Simulate increasing drift over time plus a little noise
        scale = 1 + week_idx * 0.15
        simulated = cur_values * scale + np.random.normal(0, 0.1, len(cur_values))
        # Cap at 0.5 so extreme values don't wash out the color scale
        return min(calculate_psi(ref_values, simulated), 0.5)

    psi_matrix = [
        [
            _weekly_psi(reference_df[name].values, current_df[name].values, week_idx)
            for week_idx in range(8)
        ]
        for name in feature_names
    ]

    fig = go.Figure(data=go.Heatmap(
        z=psi_matrix,
        x=week_labels,
        y=feature_names,
        colorscale=[
            [0, "#22c55e"],      # Green = no drift
            [0.2, "#f59e0b"],    # Yellow = warning
            [0.5, "#ef4444"],    # Red = critical
            [1.0, "#7f1d1d"],    # Dark red = severe
        ],
        text=[[f"{value:.3f}" for value in row] for row in psi_matrix],
        texttemplate="%{text}",
        hovertemplate="Feature: %{y}<br>%{x}<br>PSI: %{z:.4f}<extra></extra>",
        zmin=0,
        zmax=0.5,
    ))

    fig.update_layout(
        title="Feature Drift PSI Over Time",
        height=400,
        template="plotly_dark",
        margin=dict(l=20, r=20, t=50, b=20),
    )

    st.plotly_chart(fig, use_container_width=True)

Wiring It Into the App

Update app.py to connect the drift monitoring views:

# In app.py, add drift monitoring:
# NOTE: this is a fragment to paste into app.py — the `elif` below extends
# the existing page-routing `if` chain inside main().
from views.drift_monitor import (
    render_drift_overview,
    render_distribution_comparison,
    render_psi_heatmap,
)
from connectors import get_drift_data

# Inside main(), replace the Data Drift branch:
elif page == "Data Drift":
    # get_drift_data() supplies the training-time snapshot, the latest
    # production sample, and the list of features to monitor
    reference_df, current_df, feature_names = get_drift_data()

    # Summary table and 8-week PSI heatmap across all features
    render_drift_overview(reference_df, current_df, feature_names)
    render_psi_heatmap(reference_df, current_df, feature_names)

    st.markdown("---")
    st.subheader("Feature Deep Dive")
    # Per-feature overlay histogram for whichever feature the user selects
    selected_feature = st.selectbox("Select Feature", feature_names)
    render_distribution_comparison(reference_df, current_df, selected_feature)
Production tip: Store reference distributions at training time and load them from S3 or a database. Never recompute reference distributions from production data — comparing production data to itself masks exactly the drift you are trying to detect.

What Is Next

Drift monitoring is in place. In the next lesson, we will build Step 4: Cost and Infrastructure — GPU utilization gauges, API cost breakdowns, inference latency percentile charts, and resource usage time series.