Step 3: Data Drift Monitoring
Detect when your input data changes by comparing feature distributions with PSI and KS tests. Visualize drift with overlay histograms and trigger alerts when thresholds are crossed.
Why Drift Monitoring Matters
ML models are trained on historical data. When the real-world data distribution shifts, model accuracy degrades — even if the model code is unchanged. Common causes include:
- Upstream data changes: A partner API changes their schema or units
- Seasonal patterns: Holiday shopping behavior differs from normal
- Population shift: New user demographics join the platform
- Data pipeline bugs: A join starts returning nulls after a schema migration
Statistical Tests for Drift
Create utils/stats.py with the two most widely used drift detection methods:
# utils/stats.py
import numpy as np
from scipy import stats
def calculate_psi(reference, current, n_bins=10):
    """
    Population Stability Index (PSI).

    Measures how much a distribution has shifted between a reference
    (training) sample and a current (production) sample.

    Interpretation:
        PSI < 0.1  : No significant drift
        PSI 0.1-0.2: Moderate drift (investigate)
        PSI > 0.2  : Significant drift (action needed)

    Args:
        reference: Array-like of reference (training) values.
        current: Array-like of current (production) values.
        n_bins: Number of histogram bins.

    Returns:
        float: PSI value, rounded to 6 decimal places.
    """
    # Accept plain lists/tuples as well as ndarrays.
    reference = np.asarray(reference)
    current = np.asarray(current)
    # Bin edges span the combined range of BOTH samples so that every
    # value lands in a bin and the two histograms share identical edges.
    min_val = min(reference.min(), current.min())
    max_val = max(reference.max(), current.max())
    bins = np.linspace(min_val, max_val, n_bins + 1)
    # Count observations per bin for each sample.
    ref_counts, _ = np.histogram(reference, bins=bins)
    cur_counts, _ = np.histogram(current, bins=bins)
    # Convert to proportions; the small epsilon keeps empty bins from
    # producing log(0) or division by zero.
    eps = 1e-6
    ref_pct = ref_counts / len(reference) + eps
    cur_pct = cur_counts / len(current) + eps
    # PSI formula: sum of (current - reference) * ln(current / reference)
    psi = np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct))
    return round(psi, 6)
def calculate_ks_test(reference, current):
    """
    Two-sample Kolmogorov-Smirnov test.

    Compares the empirical CDFs of the two samples to judge whether they
    plausibly come from the same underlying distribution.

    Args:
        reference: Array of reference values.
        current: Array of current values.

    Returns:
        tuple: (ks_statistic, p_value), each rounded to 6 decimal places.
            - ks_statistic: largest vertical gap between the two empirical
              CDFs (ranges from 0 to 1).
            - p_value: probability of observing a gap at least this large
              if both samples share one distribution; p < 0.05 is commonly
              read as statistically significant drift.
    """
    result = stats.ks_2samp(reference, current)
    return round(result[0], 6), round(result[1], 6)
def get_drift_severity(psi, ks_p_value):
    """
    Map PSI and KS-test results onto a severity label plus a display color.

    Returns:
        tuple: (severity, hex_color) where severity is one of
        "critical", "warning", or "stable".
    """
    is_critical = psi > 0.2 or ks_p_value < 0.01
    is_warning = psi > 0.1 or ks_p_value < 0.05
    if is_critical:
        return "critical", "#ef4444"
    if is_warning:
        return "warning", "#f59e0b"
    return "stable", "#22c55e"
Distribution Overlay Charts
Create views/drift_monitor.py with visualizations that compare reference and current distributions:
# views/drift_monitor.py
import streamlit as st
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import numpy as np
from utils.stats import calculate_psi, calculate_ks_test, get_drift_severity
def render_drift_overview(reference_df, current_df, feature_names):
    """
    Render a summary table of drift status across all monitored features.

    For each feature, computes PSI and the two-sample KS test against the
    reference data, classifies the severity, and shows headline metrics
    plus a color-coded results table.
    """
    st.subheader("Drift Overview")

    rows = []
    for name in feature_names:
        ref_values = reference_df[name].values
        cur_values = current_df[name].values
        psi_value = calculate_psi(ref_values, cur_values)
        ks_statistic, ks_pvalue = calculate_ks_test(ref_values, cur_values)
        severity, _ = get_drift_severity(psi_value, ks_pvalue)
        rows.append({
            "Feature": name,
            "PSI": psi_value,
            "KS Statistic": ks_statistic,
            "KS p-value": ks_pvalue,
            "Status": severity.upper(),
        })
    summary = pd.DataFrame(rows)

    # Headline alert counts.
    n_critical = int((summary["Status"] == "CRITICAL").sum())
    n_warning = int((summary["Status"] == "WARNING").sum())
    col1, col2, col3 = st.columns(3)
    col1.metric("Features Monitored", len(feature_names))
    col2.metric("Critical Drift", n_critical,
                delta=f"{n_critical} features", delta_color="inverse")
    col3.metric("Warnings", n_warning,
                delta=f"{n_warning} features", delta_color="inverse")

    def _row_style(row):
        # One background color per row, chosen by drift status.
        if row["Status"] == "CRITICAL":
            css = "background-color: #dc262620"
        elif row["Status"] == "WARNING":
            css = "background-color: #f59e0b20"
        else:
            css = ""
        return [css] * len(row)

    st.dataframe(
        summary.style.apply(_row_style, axis=1),
        use_container_width=True,
        hide_index=True,
    )
Feature Distribution Comparison
Overlay histograms show exactly how a feature has shifted between training and production:
def render_distribution_comparison(reference_df, current_df, feature_name):
    """
    Render overlaid reference vs. current histograms for one feature,
    with PSI / KS drift metrics displayed above the chart.
    """
    st.subheader(f"Distribution: {feature_name}")

    reference_values = reference_df[feature_name].values
    current_values = current_df[feature_name].values

    # Drift metrics for this single feature.
    psi_value = calculate_psi(reference_values, current_values)
    ks_statistic, ks_pvalue = calculate_ks_test(reference_values, current_values)
    severity, _ = get_drift_severity(psi_value, ks_pvalue)

    # Inline metric strip above the chart.
    metric_cols = st.columns(3)
    metric_cols[0].metric("PSI", f"{psi_value:.4f}")
    metric_cols[1].metric("KS Statistic", f"{ks_statistic:.4f}")
    metric_cols[2].metric("Status", severity.upper())

    # Semi-transparent overlaid histograms make the shift visible.
    traces = [
        go.Histogram(
            x=reference_values,
            name="Reference (Training)",
            opacity=0.6,
            marker_color="#6366f1",
            nbinsx=40,
        ),
        go.Histogram(
            x=current_values,
            name="Current (Production)",
            opacity=0.6,
            marker_color="#f59e0b",
            nbinsx=40,
        ),
    ]
    fig = go.Figure(data=traces)
    fig.update_layout(
        barmode="overlay",
        title=f"{feature_name} Distribution Comparison",
        xaxis_title="Value",
        yaxis_title="Count",
        height=400,
        template="plotly_dark",
        margin=dict(l=20, r=20, t=50, b=20),
        legend=dict(orientation="h", yanchor="bottom", y=1.02,
                    xanchor="right", x=1),
    )
    st.plotly_chart(fig, use_container_width=True)
PSI Heatmap Across Features
A heatmap showing PSI values for all features over time helps identify systematic drift patterns:
def render_psi_heatmap(reference_df, current_df, feature_names):
    """
    Render a heatmap of PSI values per feature over the last 8 weeks.

    Weekly values are simulated by progressively scaling and jittering the
    current data, so drift appears to grow over time. PSI is capped at 0.5
    purely so the color scale stays readable.
    """
    st.subheader("PSI Trend (Last 8 Weeks)")

    week_labels = [f"Week {i+1}" for i in range(8)]
    psi_rows = []
    for feature in feature_names:
        reference_values = reference_df[feature].values
        current_values = current_df[feature].values
        weekly = []
        for week_idx in range(8):
            # Larger scaling plus Gaussian noise each week simulates
            # steadily increasing drift.
            scale = 1 + week_idx * 0.15
            simulated = (current_values * scale
                         + np.random.normal(0, 0.1, len(current_values)))
            weekly.append(min(calculate_psi(reference_values, simulated), 0.5))
        psi_rows.append(weekly)

    heatmap = go.Heatmap(
        z=psi_rows,
        x=week_labels,
        y=feature_names,
        # Green = no drift, yellow = warning, red = critical, dark red = severe.
        colorscale=[
            [0, "#22c55e"],
            [0.2, "#f59e0b"],
            [0.5, "#ef4444"],
            [1.0, "#7f1d1d"],
        ],
        text=[[f"{value:.3f}" for value in weekly] for weekly in psi_rows],
        texttemplate="%{text}",
        hovertemplate="Feature: %{y}<br>%{x}<br>PSI: %{z:.4f}<extra></extra>",
        zmin=0,
        zmax=0.5,
    )
    fig = go.Figure(data=heatmap)
    fig.update_layout(
        title="Feature Drift PSI Over Time",
        height=400,
        template="plotly_dark",
        margin=dict(l=20, r=20, t=50, b=20),
    )
    st.plotly_chart(fig, use_container_width=True)
Wiring It Into the App
Update app.py to connect the drift monitoring views:
# In app.py, add drift monitoring:
from views.drift_monitor import (
render_drift_overview,
render_distribution_comparison,
render_psi_heatmap,
)
from connectors import get_drift_data
# Inside main(), replace the Data Drift branch:
elif page == "Data Drift":
reference_df, current_df, feature_names = get_drift_data()
render_drift_overview(reference_df, current_df, feature_names)
render_psi_heatmap(reference_df, current_df, feature_names)
st.markdown("---")
st.subheader("Feature Deep Dive")
selected_feature = st.selectbox("Select Feature", feature_names)
render_distribution_comparison(reference_df, current_df, selected_feature)
What's Next
Drift monitoring is in place. In the next lesson, we will build Step 4: Cost and Infrastructure — GPU utilization gauges, API cost breakdowns, inference latency percentile charts, and resource usage time series.
Lilly Tech Systems