MLflow Model Registry
Manage the full lifecycle of ML models with versioning, stage transitions, and collaborative approval workflows.
Model Registry Concepts
The Model Registry is a centralized model store that provides:
- Model versioning: Track multiple versions of each registered model.
- Stage management: Move models through stages (Staging, Production, Archived).
- Lineage: Link each model version back to the experiment run that created it.
- Annotations: Add descriptions, tags, and comments for collaboration.
Registering Models
import mlflow

# Method 1: Register during logging
# train_model, X_train, and y_train are placeholders for your own training code.
with mlflow.start_run():
    model = train_model(X_train, y_train)
    mlflow.sklearn.log_model(
        model,
        artifact_path="model",
        registered_model_name="churn-predictor",  # Auto-registers
    )

# Method 2: Register an existing run's model
# Replace abc123def456 with the ID of the run that logged the model.
result = mlflow.register_model(
    model_uri="runs:/abc123def456/model",
    name="churn-predictor",
)
print(f"Version: {result.version}")
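The model_uri in Method 2 follows the pattern runs:/<run_id>/<artifact_path>, where the artifact path matches the value passed to log_model. A minimal sketch of a helper that assembles such a URI; the name build_runs_uri is hypothetical and not part of MLflow:

```python
def build_runs_uri(run_id: str, artifact_path: str = "model") -> str:
    """Build a runs:/ URI pointing at a model logged under a given run."""
    if not run_id:
        raise ValueError("run_id must be a non-empty string")
    # Strip stray slashes so the pieces join with exactly one separator.
    return f"runs:/{run_id.strip('/')}/{artifact_path.strip('/')}"


print(build_runs_uri("abc123def456"))  # runs:/abc123def456/model
```

The result can be passed as model_uri to mlflow.register_model.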
Model Versions
Each time you register a model with the same name, a new version is created automatically:
from mlflow import MlflowClient

client = MlflowClient()

# List all versions of a model
versions = client.search_model_versions("name='churn-predictor'")
for v in versions:
    print(f"Version {v.version}: stage={v.current_stage}, run_id={v.run_id}")

# Get a specific version
version = client.get_model_version("churn-predictor", version=3)

# Update version description
client.update_model_version(
    name="churn-predictor",
    version=3,
    description="Improved model with new features. F1: 0.94",
)
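When picking the newest version out of a search result, it helps to compare version numbers as integers, since they may be returned as strings. A small sketch, using stand-in objects rather than real search results; the helper name latest_version is hypothetical:

```python
from types import SimpleNamespace


def latest_version(versions):
    """Return the entry with the highest version number, or None if empty."""
    # Compare as integers: as raw strings, "9" would rank above "10".
    return max(versions, key=lambda v: int(v.version), default=None)


# Stand-ins exposing the same .version attribute as the search results.
sample = [SimpleNamespace(version=str(n)) for n in (2, 10, 9)]
newest = latest_version(sample)
print(newest.version)  # 10
```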
Stage Transitions
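Models move through stages as they mature. MLflow 2.9 and later deprecate stages in favor of model version aliases and tags, though the stage-based API still applies to existing setups. The stages are: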
- None: Initial state when a model is first registered. Not yet evaluated for any environment.
- Staging: The model is being tested and validated, for example through integration tests, shadow deployments, or A/B tests.
- Production: The model is approved and serving live traffic. Typically only one version per model should be in Production.
- Archived: The model version is retired. It is kept for audit and rollback purposes but no longer serves traffic.
from mlflow import MlflowClient

client = MlflowClient()

# Transition to Staging
client.transition_model_version_stage(
    name="churn-predictor",
    version=3,
    stage="Staging",
)

# After validation, promote to Production
client.transition_model_version_stage(
    name="churn-predictor",
    version=3,
    stage="Production",
    archive_existing_versions=True,  # Archive the current Production version
)

# Archive an old version
client.transition_model_version_stage(
    name="churn-predictor",
    version=2,
    stage="Archived",
)
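In the calls above the target stage is just a string argument, so a team that wants a fixed promotion path can check each move before making it. The sketch below encodes one possible policy; the table and the function name is_allowed_transition are assumptions for illustration, not MLflow behavior:

```python
# Allowed moves under an assumed team policy (hypothetical, adjust to taste).
ALLOWED_TRANSITIONS = {
    "None": {"Staging", "Archived"},
    "Staging": {"Production", "Archived", "None"},
    "Production": {"Archived"},
    "Archived": {"Staging"},
}


def is_allowed_transition(current_stage: str, target_stage: str) -> bool:
    """Return True if the policy table permits moving between the two stages."""
    # Unknown current stages map to an empty set, so they are always rejected.
    return target_stage in ALLOWED_TRANSITIONS.get(current_stage, set())


print(is_allowed_transition("Staging", "Production"))  # True
print(is_allowed_transition("None", "Production"))     # False
```

A caller could read current_stage from client.get_model_version and only call transition_model_version_stage when the check passes.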
Model Approval Workflows
import mlflow
from mlflow import MlflowClient
from sklearn.metrics import accuracy_score, f1_score


def approve_model(model_name, version, test_data):
    """Validate and promote a model through stages."""
    client = MlflowClient()

    # Load the model version under review
    model_uri = f"models:/{model_name}/{version}"
    model = mlflow.pyfunc.load_model(model_uri)

    # Run validation checks (test_data is assumed to expose .X and .y)
    predictions = model.predict(test_data.X)
    accuracy = accuracy_score(test_data.y, predictions)
    f1 = f1_score(test_data.y, predictions)

    # Check against thresholds
    if accuracy >= 0.90 and f1 >= 0.88:
        client.transition_model_version_stage(
            name=model_name,
            version=version,
            stage="Production",
            archive_existing_versions=True,
        )
        print(f"Model v{version} promoted to Production!")
    else:
        print(f"Model v{version} failed validation: acc={accuracy:.3f}, f1={f1:.3f}")
        client.set_model_version_tag(model_name, version, "validation", "failed")
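The threshold check inside approve_model can also be pulled out into a plain function, which makes it easy to unit test and to report exactly which metric fell short. A minimal sketch; failed_checks and THRESHOLDS are hypothetical names, and the threshold values simply mirror the ones used in approve_model:

```python
def failed_checks(metrics: dict, thresholds: dict) -> list:
    """Return the names of metrics that are missing or below their minimum."""
    failures = []
    for name, minimum in thresholds.items():
        value = metrics.get(name)
        # A missing metric counts as a failure rather than a silent pass.
        if value is None or value < minimum:
            failures.append(name)
    return failures


# Same minimums as the accuracy and F1 checks in approve_model.
THRESHOLDS = {"accuracy": 0.90, "f1": 0.88}

print(failed_checks({"accuracy": 0.93, "f1": 0.85}, THRESHOLDS))  # ['f1']
```

An empty list means every check passed, so the promotion branch would run only when failed_checks returns no names.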
Webhooks and Notifications
Set up webhooks to notify your team when models change stage:
- Slack notifications: Alert the team when a model is promoted to production.
- CI/CD triggers: Automatically deploy when a model reaches the Production stage.
- Email alerts: Notify stakeholders of model updates.
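The text does not say which webhook mechanism is used, so the following is only a sketch under stated assumptions: a job periodically snapshots each version's stage (for example from search_model_versions), compares it with the previous snapshot, and posts a message to a Slack incoming webhook. The function names and the snapshot format are hypothetical, and the webhook URL would be whatever Slack issues for your workspace.

```python
import json
import urllib.request


def detect_stage_changes(previous: dict, current: dict) -> list:
    """Compare two {version: stage} snapshots and list (version, old, new) changes."""
    changes = []
    for version, new_stage in current.items():
        # Versions absent from the earlier snapshot are treated as stage "None".
        old_stage = previous.get(version, "None")
        if old_stage != new_stage:
            changes.append((version, old_stage, new_stage))
    return changes


def build_slack_payload(model_name: str, version: str, old_stage: str, new_stage: str) -> dict:
    """Build the JSON body for a Slack incoming webhook (a single text field)."""
    return {"text": f"{model_name} v{version} moved from {old_stage} to {new_stage}"}


def post_json(url: str, payload: dict) -> int:
    """POST a JSON payload to the given URL and return the HTTP status code."""
    request = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=10) as response:
        return response.status


# Example snapshots; post_json is not called here because it needs a real URL.
before = {"2": "Production", "3": "Staging"}
after = {"2": "Archived", "3": "Production"}
for version, old, new in detect_stage_changes(before, after):
    print(build_slack_payload("churn-predictor", version, old, new)["text"])
```

The same change list could drive a CI/CD trigger or an email alert instead of a Slack message.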
Collaboration Workflows
- Model Tags: Use tags to communicate model status, ownership, and metadata, for example team=fraud, reviewed_by=alice, approved=true.
- Descriptions: Add detailed descriptions to model versions explaining changes, performance improvements, and known limitations.
- Lineage: Every model version registered from a run links back to that training run, so a prediction can be traced to the code, parameters, and any data logged with the run.
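Tags like these can be applied in one pass with set_model_version_tag, the same client method used in the approval workflow. A minimal sketch; apply_tags is a hypothetical helper, and RecordingClient is a stand-in that records calls so the example runs without a tracking server:

```python
def apply_tags(client, model_name: str, version, tags: dict) -> None:
    """Set each key/value pair as a tag on one model version."""
    for key, value in tags.items():
        # Pass values as strings so booleans and numbers are stored as text.
        client.set_model_version_tag(model_name, version, key, str(value))


class RecordingClient:
    """Stand-in for MlflowClient that records calls instead of sending them."""

    def __init__(self):
        self.calls = []

    def set_model_version_tag(self, name, version, key, value):
        self.calls.append((name, version, key, value))


demo_client = RecordingClient()
apply_tags(demo_client, "churn-predictor", 3, {"team": "fraud", "reviewed_by": "alice"})
print(demo_client.calls)
```

With a real registry, pass an MlflowClient instance in place of RecordingClient.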
Programmatic Registry Management
from mlflow import MlflowClient

client = MlflowClient()

# Create a registered model
client.create_registered_model(
    name="fraud-detector",
    description="Real-time fraud detection model for payment transactions",
    tags={"team": "payments", "priority": "critical"},
)

# Search registered models
models = client.search_registered_models(filter_string="name LIKE '%churn%'")

# Get the latest Production version
prod_version = client.get_latest_versions("churn-predictor", stages=["Production"])
if prod_version:
    print(f"Production model: v{prod_version[0].version}")

# Delete a model version (use with caution)
client.delete_model_version(name="churn-predictor", version=1)

# Delete entire registered model (use with extreme caution)
# client.delete_registered_model(name="old-model")
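Because deletion is hard to undo, one option is to guard it with a stage check first. The sketch below refuses to delete versions in Staging or Production; the helper name, the protected set, and the FakeClient stand-in are assumptions for illustration, while get_model_version and delete_model_version are the client methods used earlier on this page.

```python
from types import SimpleNamespace

# Stages that should never be deleted directly (assumed policy).
PROTECTED_STAGES = {"Staging", "Production"}


def delete_version_if_unused(client, model_name: str, version) -> bool:
    """Delete a model version unless it is in a protected stage.

    Returns True if the version was deleted, False if it was skipped.
    """
    current = client.get_model_version(model_name, version)
    if current.current_stage in PROTECTED_STAGES:
        return False
    client.delete_model_version(name=model_name, version=version)
    return True


class FakeClient:
    """Stand-in for MlflowClient backed by an in-memory {version: stage} dict."""

    def __init__(self, stages):
        self.stages = dict(stages)

    def get_model_version(self, name, version):
        return SimpleNamespace(name=name, version=version, current_stage=self.stages[version])

    def delete_model_version(self, name, version):
        del self.stages[version]


fake = FakeClient({1: "Archived", 2: "Production"})
print(delete_version_if_unused(fake, "churn-predictor", 1))  # True
print(delete_version_if_unused(fake, "churn-predictor", 2))  # False
```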
Lilly Tech Systems