Real-time Inference with Kafka
Deploy ML models as Kafka consumers that produce predictions in real time, with strategies for scaling, model updates, and monitoring.
Architecture Pattern
In a Kafka-based inference system, models run as consumers that read events, enrich them with features, run predictions, and produce results to an output topic.
Basic Inference Consumer
Python — ML Inference Consumer
from confluent_kafka import Consumer, Producer
import json
import joblib
import numpy as np
import redis
import time
# Load model
model = joblib.load("fraud_model_v2.pkl")
redis_client = redis.Redis(host='localhost', port=6379)
consumer = Consumer({
    'bootstrap.servers': 'localhost:9092',
    'group.id': 'fraud-inference',
    'auto.offset.reset': 'latest',
    'max.poll.interval.ms': 60000
})
consumer.subscribe(['transactions'])
producer = Producer({'bootstrap.servers': 'localhost:9092'})

def get_features(user_id, transaction):
    """Combine real-time features with transaction data."""
    # Get precomputed streaming features from Redis
    cached = redis_client.get(f"features:{user_id}")
    streaming_features = json.loads(cached) if cached else {}
    return np.array([[
        transaction['amount'],
        transaction['merchant_category'],
        streaming_features.get('tx_count_1h', 0),
        streaming_features.get('avg_amount_1h', 0),
        streaming_features.get('unique_merchants_1h', 0),
    ]])

while True:
    msg = consumer.poll(0.1)
    if msg is None or msg.error():
        continue
    start = time.time()
    tx = json.loads(msg.value())

    # Build features and predict
    features = get_features(tx['user_id'], tx)
    fraud_prob = model.predict_proba(features)[0][1]

    # Produce prediction
    result = {
        'transaction_id': tx['transaction_id'],
        'user_id': tx['user_id'],
        'fraud_probability': float(fraud_prob),
        'is_fraud': fraud_prob > 0.5,
        'model_version': 'v2',
        'latency_ms': (time.time() - start) * 1000
    }
    producer.produce('fraud-predictions', value=json.dumps(result).encode())
    producer.poll(0)  # serve pending delivery callbacks
Model Hot-Swapping
Python — Hot-swap Models Without Downtime
import threading

class ModelServer:
    def __init__(self, model_path):
        self.model = joblib.load(model_path)
        self.model_path = model_path
        self.lock = threading.Lock()
        self._start_watcher()

    def predict(self, features):
        with self.lock:
            return self.model.predict_proba(features)

    def reload(self, new_path):
        new_model = joblib.load(new_path)  # Load first
        with self.lock:
            self.model = new_model  # Swap atomically
            self.model_path = new_path
        print(f"Model swapped to {new_path}")

    def _start_watcher(self):
        """Watch for new model versions via Kafka topic."""
        def watch():
            c = Consumer({
                'bootstrap.servers': 'localhost:9092',
                'group.id': 'model-updates',
                'auto.offset.reset': 'latest'
            })
            c.subscribe(['model-deployments'])
            while True:
                msg = c.poll(5.0)
                if msg and not msg.error():
                    update = json.loads(msg.value())
                    self.reload(update['model_path'])
        t = threading.Thread(target=watch, daemon=True)
        t.start()

server = ModelServer("fraud_model_v2.pkl")
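The watcher thread only reads the `model_path` field from each message on `model-deployments`, so a deployment pipeline triggers a swap by producing one JSON event to that topic. A sketch of building that event; `make_deployment_event` is a hypothetical helper, not part of the class above, and the `version`/`deployed_at` fields are extra metadata the watcher ignores:

```python
import json
import time

def make_deployment_event(model_path, version):
    """Build the JSON payload the ModelServer watcher expects.

    Only 'model_path' is consumed by the watcher; 'version' and
    'deployed_at' are included for audit trails.
    """
    return json.dumps({
        'model_path': model_path,
        'version': version,
        'deployed_at': time.time(),
    }).encode()

# Producing it (assumes a Producer like the one in the basic consumer example):
# producer.produce('model-deployments',
#                  value=make_deployment_event('fraud_model_v3.pkl', 'v3'))
# producer.flush()
```

Because the watcher group uses `auto.offset.reset: latest`, instances that start after a deployment event will not see it; they load whatever path they were constructed with.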
Scaling Inference
- Horizontal scaling: Add more consumer instances to the same consumer group. Kafka distributes partitions automatically.
- Partition count: Set partitions equal to your maximum expected consumer count. You can't have more consumers than partitions.
- Batch inference: Accumulate messages into mini-batches for GPU inference efficiency.
- Async processing: Use async I/O for feature lookups to avoid blocking on Redis/DB calls.
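The micro-batching point can be sketched without Kafka: buffer decoded feature rows and flush them to the model when either a size or an age threshold is hit. The thresholds below (32 rows, 50 ms) are illustrative defaults, not tuned values:

```python
import time

class MicroBatcher:
    """Accumulate feature rows; flush when the batch is full or stale."""

    def __init__(self, predict_fn, max_size=32, max_wait_s=0.05):
        self.predict_fn = predict_fn  # e.g. model.predict_proba on a 2-D batch
        self.max_size = max_size
        self.max_wait_s = max_wait_s
        self.rows = []
        self.oldest = None  # monotonic timestamp of the first buffered row

    def add(self, row):
        """Buffer one row; return batch predictions if a flush was triggered."""
        if self.oldest is None:
            self.oldest = time.monotonic()
        self.rows.append(row)
        if (len(self.rows) >= self.max_size
                or time.monotonic() - self.oldest >= self.max_wait_s):
            return self.flush()
        return None  # still accumulating

    def flush(self):
        """Run one model call over everything buffered; None if empty."""
        if not self.rows:
            return None
        batch, self.rows, self.oldest = self.rows, [], None
        return self.predict_fn(batch)
```

In the consumer loop you would call `batcher.add(row)` per message and `batcher.flush()` whenever poll() returns None, so a quiet topic still drains the buffer instead of holding predictions past the latency budget.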
Monitoring Predictions
Python — Prediction Monitoring Metrics
from prometheus_client import Counter, Histogram, start_http_server
# Metrics
predictions_total = Counter('ml_predictions_total', 'Total predictions', ['model_version'])
prediction_latency = Histogram('ml_prediction_latency_seconds', 'Prediction latency')
fraud_score = Histogram('ml_fraud_score', 'Fraud probability distribution', buckets=[0.1, 0.3, 0.5, 0.7, 0.9])
start_http_server(8000) # Expose metrics endpoint
# In your prediction loop:
with prediction_latency.time():
    prob = model.predict_proba(features)[0][1]
predictions_total.labels(model_version='v2').inc()
fraud_score.observe(prob)
Track training-serving skew: Log the feature values used at inference time to a Kafka topic. Compare the distribution of served features against training features to detect data drift early.
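One common drift statistic over those logged features is the Population Stability Index (PSI): bin a feature identically for the training and serving samples, then sum (sᵢ − tᵢ)·ln(sᵢ/tᵢ) over bins, where tᵢ and sᵢ are the training and serving proportions in bin i. A widely used rule of thumb flags PSI above 0.2 as significant drift. A pure-Python sketch; the epsilon guard and the 0.2 threshold are conventions, not fixed requirements:

```python
import math

def psi(train_counts, serve_counts, eps=1e-6):
    """Population Stability Index between two binned distributions.

    Both arguments are per-bin counts computed over the same bin edges.
    eps guards against log(0) on empty bins.
    """
    t_total = sum(train_counts)
    s_total = sum(serve_counts)
    total = 0.0
    for t, s in zip(train_counts, serve_counts):
        t_p = max(t / t_total, eps)  # training proportion in this bin
        s_p = max(s / s_total, eps)  # serving proportion in this bin
        total += (s_p - t_p) * math.log(s_p / t_p)
    return total

# Identical distributions score 0; a shifted serving distribution scores higher:
# psi([25, 25, 25, 25], [25, 25, 25, 25])  -> 0.0
# psi([25, 25, 25, 25], [10, 15, 30, 45])  -> above the 0.2 drift threshold
```

A small job consuming the logged-features topic can recompute per-feature PSI on a schedule and alert when any feature crosses the threshold.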