Scaling to Millions of Users
Moving from a prototype recommendation system to one that serves 100M+ users at sub-100ms latency requires careful architectural decisions around embedding storage, model serving, caching, and the balance between pre-computation and real-time inference.
Embedding Table Sharding
At scale, embedding tables for users and items can be enormous. 100M users with 128-dimensional float32 embeddings = 100M × 128 × 4 bytes ≈ 51 GB (≈48 GiB). This does not fit in a single machine's memory, so you need to shard.
import hashlib
from typing import Optional
import numpy as np
class ShardedEmbeddingStore:
    """Distributed embedding storage across multiple shards.

    Entity IDs are mapped to shards with a stable hash (MD5 mod num_shards),
    so the same ID always lands on the same shard across processes.
    NOTE: this is *modulo* sharding, not true consistent hashing — changing
    the shard count remaps almost every key, so treat resharding as a bulk
    migration.

    Each shard is a separate Redis instance (or Memcached, or custom store).
    """

    def __init__(self, shard_clients: list, embedding_dim: int = 128):
        if not shard_clients:
            raise ValueError("at least one shard client is required")
        self.shards = shard_clients
        self.num_shards = len(shard_clients)
        self.embedding_dim = embedding_dim

    def _get_shard(self, entity_id: str) -> int:
        """Stable hash: the same entity always maps to the same shard."""
        # MD5 (not Python's builtin hash()) so the mapping is identical
        # across processes and interpreter restarts.
        hash_value = int(hashlib.md5(entity_id.encode()).hexdigest(), 16)
        return hash_value % self.num_shards

    def get_embedding(self, entity_id: str) -> Optional[np.ndarray]:
        """Fetch a single embedding, or None if it is not stored."""
        client = self.shards[self._get_shard(entity_id)]
        data = client.get(f"emb:{entity_id}")
        if data is None:
            return None
        # frombuffer returns a read-only view over the bytes; callers that
        # need to mutate the result should copy it.
        return np.frombuffer(data, dtype=np.float32)

    def set_embedding(self, entity_id: str, embedding: np.ndarray):
        """Store a single embedding.

        Raises ValueError if the embedding does not have exactly
        `embedding_dim` elements (silently storing the wrong size would
        corrupt later reads).
        """
        flat = np.asarray(embedding, dtype=np.float32).ravel()
        if flat.size != self.embedding_dim:
            raise ValueError(
                f"expected {self.embedding_dim}-dim embedding, got {flat.size}"
            )
        client = self.shards[self._get_shard(entity_id)]
        client.set(f"emb:{entity_id}", flat.tobytes())

    def get_batch(self, entity_ids: list[str]) -> dict[str, Optional[np.ndarray]]:
        """Batch fetch embeddings, grouping by shard so each shard receives
        one pipelined round trip instead of one request per ID.

        Returns a dict mapping every requested ID to its embedding, or None
        for IDs that are not stored.
        """
        shard_groups: dict[int, list[str]] = {}
        for entity_id in entity_ids:
            shard_groups.setdefault(self._get_shard(entity_id), []).append(entity_id)
        results: dict[str, Optional[np.ndarray]] = {}
        for shard_idx, ids in shard_groups.items():
            pipe = self.shards[shard_idx].pipeline()
            for entity_id in ids:
                pipe.get(f"emb:{entity_id}")
            values = pipe.execute()
            for entity_id, value in zip(ids, values):
                results[entity_id] = (
                    np.frombuffer(value, dtype=np.float32)
                    if value is not None else None
                )
        return results
Memory Planning
| Scale | Users | Items | Embedding Dim | Total Memory | Shards Needed |
|---|---|---|---|---|---|
| Small | 1M | 100K | 128 | ~560MB | 1 |
| Medium | 10M | 1M | 128 | ~5.6GB | 2–4 |
| Large | 100M | 10M | 128 | ~56GB | 8–16 |
| Massive | 1B | 100M | 256 | ~1.1TB | 64–128 |
Distributed Model Inference
Serving a ranking model at 100K+ QPS requires horizontal scaling of the inference service. The key challenge is balancing latency (fast enough for real-time) with throughput (handle the traffic).
import asyncio
from fastapi import FastAPI
import torch
import numpy as np
from concurrent.futures import ThreadPoolExecutor
# Module-level FastAPI application. No routes are registered in this excerpt —
# presumably handlers defined elsewhere dispatch into ModelServingCluster;
# TODO confirm against the full service code.
app = FastAPI()
class ModelServingCluster:
    """Distributed model serving with load balancing and batching.

    Key optimizations:
    1. Request batching: group requests and run batch inference
    2. Model parallelism: split model across GPUs for large models
    3. Request queuing: smooth out traffic spikes

    Inference runs on a thread pool (not the event loop), so slow forward
    passes do not stall request collection. If a batch fails, the error is
    propagated to every waiting caller instead of leaving futures unresolved.
    """

    def __init__(self, model_path: str, num_workers: int = 4, batch_size: int = 64,
                 max_batch_wait_ms: float = 5.0):
        self.model = torch.jit.load(model_path)
        self.model.eval()
        self.batch_size = batch_size
        self.max_batch_wait_ms = max_batch_wait_ms
        # NOTE(review): Queue() created outside a running loop — fine on
        # Python 3.10+ where Queue binds its loop lazily; confirm runtime version.
        self.request_queue = asyncio.Queue()
        self.executor = ThreadPoolExecutor(max_workers=num_workers)

    async def predict(self, features: np.ndarray) -> float:
        """Submit one prediction request and wait for its score.

        Raises whatever exception the batch inference raised, if any.
        """
        future = asyncio.get_running_loop().create_future()
        await self.request_queue.put((features, future))
        return await future

    def _run_batch(self, batch_features: list) -> np.ndarray:
        """Run one forward pass on a stacked batch; returns a flat score array.

        Executed on the thread pool so the event loop stays responsive.
        """
        batch_tensor = torch.tensor(np.stack(batch_features), dtype=torch.float32)
        with torch.no_grad():
            # reshape(-1) flattens (N, 1)-shaped model outputs to (N,).
            return self.model(batch_tensor).numpy().reshape(-1)

    async def batch_inference_loop(self):
        """Continuously collect requests into batches and run inference.

        Waits for one request, then collects more for up to max_batch_wait_ms
        or until batch_size is reached, whichever comes first.
        """
        loop = asyncio.get_running_loop()
        while True:
            # Block until at least one request arrives.
            features, future = await self.request_queue.get()
            batch_features = [features]
            batch_futures = [future]
            deadline = loop.time() + self.max_batch_wait_ms / 1000
            while len(batch_features) < self.batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    features, future = await asyncio.wait_for(
                        self.request_queue.get(), timeout=remaining
                    )
                except asyncio.TimeoutError:
                    break
                batch_features.append(features)
                batch_futures.append(future)
            # Run inference off the event loop; on failure, fail every
            # caller in the batch rather than hanging them forever.
            try:
                scores = await loop.run_in_executor(
                    self.executor, self._run_batch, batch_features
                )
            except Exception as exc:
                for fut in batch_futures:
                    if not fut.done():
                        fut.set_exception(exc)
                continue
            # Deliver results; skip futures already cancelled by the caller
            # (set_result on a cancelled future raises InvalidStateError).
            for fut, score in zip(batch_futures, scores):
                if not fut.done():
                    fut.set_result(float(score))
Caching Strategies
Caching is the single most impactful optimization for recommendation systems at scale. The key insight is that many users have similar contexts, and recommendations do not need to be unique for every request.
import hashlib
import json
import time
from typing import Optional
class RecommendationCache:
    """Multi-layer caching for recommendations.

    Layer 1: CDN/Edge cache for anonymous or segment-based recs (TTL: 5-15 min)
    Layer 2: Redis cache for user-level recs (TTL: 1-5 min)
    Layer 3: Local in-memory cache for hot users (TTL: local_ttl_seconds)
    """

    def __init__(self, redis_client, local_cache_size=10000, local_ttl_seconds=30):
        self.redis = redis_client
        self.local_cache = {}  # Simple LRU would be better in production
        self.local_cache_size = local_cache_size
        # Layer-3 TTL; previously hard-coded to 30s, now configurable.
        self.local_ttl_seconds = local_ttl_seconds
        self.stats = {"hits": 0, "misses": 0}

    def _cache_key(self, user_id: str, context: dict) -> str:
        """Generate cache key from user and context.

        Context includes page_type, device, etc. but NOT time-sensitive fields.
        Context must be JSON-serializable; sort_keys makes the hash order-independent.
        """
        context_hash = hashlib.md5(
            json.dumps(context, sort_keys=True).encode()
        ).hexdigest()[:8]
        return f"recs:{user_id}:{context_hash}"

    def get(self, user_id: str, context: dict) -> Optional[list]:
        """Return cached recommendations, or None on a miss in all layers."""
        key = self._cache_key(user_id, context)
        # Layer 3: Local cache (fastest). Expired entries are pruned lazily
        # here and the lookup falls through to Redis.
        if key in self.local_cache:
            entry = self.local_cache[key]
            if time.time() < entry["expires"]:
                self.stats["hits"] += 1
                return entry["data"]
            else:
                del self.local_cache[key]
        # Layer 2: Redis cache
        cached = self.redis.get(key)
        if cached:
            data = json.loads(cached)
            # Promote to local cache so hot users skip Redis next time.
            self.local_cache[key] = {
                "data": data, "expires": time.time() + self.local_ttl_seconds
            }
            self._evict_local_if_needed()
            self.stats["hits"] += 1
            return data
        self.stats["misses"] += 1
        return None

    def set(self, user_id: str, context: dict, recommendations: list,
            ttl_seconds: int = 120):
        """Store recommendations in Redis (ttl_seconds) and locally (shorter TTL)."""
        key = self._cache_key(user_id, context)
        self.redis.setex(key, ttl_seconds, json.dumps(recommendations))
        self.local_cache[key] = {
            "data": recommendations, "expires": time.time() + self.local_ttl_seconds
        }
        self._evict_local_if_needed()

    def invalidate(self, user_id: str):
        """Invalidate all cached recommendations for a user.

        Call this after a significant user action (purchase, profile update).
        """
        pattern = f"recs:{user_id}:*"
        # SCAN instead of KEYS: KEYS does a blocking full-keyspace scan,
        # which stalls Redis at production key counts.
        keys = list(self.redis.scan_iter(match=pattern))
        if keys:
            self.redis.delete(*keys)
        # Clear local cache for this user (snapshot the keys first so we
        # never delete while iterating the dict).
        prefix = f"recs:{user_id}:"
        for k in [k for k in self.local_cache if k.startswith(prefix)]:
            del self.local_cache[k]

    def _evict_local_if_needed(self):
        """Trim the local cache back to local_cache_size, dropping the
        entries that expire soonest."""
        if len(self.local_cache) > self.local_cache_size:
            sorted_keys = sorted(
                self.local_cache.keys(),
                key=lambda k: self.local_cache[k]["expires"]
            )
            for key in sorted_keys[:len(self.local_cache) - self.local_cache_size]:
                del self.local_cache[key]

    def get_hit_rate(self) -> float:
        """Fraction of get() calls served from any cache layer (0.0 if none yet)."""
        total = self.stats["hits"] + self.stats["misses"]
        return self.stats["hits"] / total if total > 0 else 0.0
Pre-Computation vs Real-Time Trade-offs
Not all recommendations need to be computed in real time. Understanding when to pre-compute vs compute on-the-fly is critical for scaling.
| Aspect | Pre-Computed (Offline) | Real-Time (Online) |
|---|---|---|
| Latency | <1ms (just a cache lookup) | 50–200ms (full pipeline) |
| Freshness | Stale (hours old) | Fresh (incorporates recent actions) |
| Cost per request | Near zero (amortized) | High (model inference + feature fetching) |
| Personalization | Based on user history up to batch time | Incorporates real-time session context |
| Best for | Email, push notifications, homepage hero | Search results, product pages, in-session recs |
class HybridRecommender:
    """Combines pre-computed and real-time recommendations.

    Pre-computed recs serve as a fast fallback and provide the base list.
    Real-time adjustments personalize based on current session.
    """

    def __init__(self, precomputed_store, realtime_ranker, cache):
        self.precomputed = precomputed_store
        self.realtime = realtime_ranker
        self.cache = cache

    async def recommend(self, user_id: str, context: dict, n: int = 20,
                        timeout_ms: int = 100) -> list:
        """Return up to n recommendations for the user in this context.

        Serves from cache when possible; otherwise re-ranks pre-computed
        candidates in real time, falling back to the pre-computed order if
        the ranker exceeds timeout_ms.
        """
        # 1. Check cache first. Compare against None explicitly: an empty
        #    list is a valid cached result and must not trigger a recompute
        #    on every request.
        cached = self.cache.get(user_id, context)
        if cached is not None:
            return cached[:n]
        # 2. Fetch pre-computed recs (always fast, ~1ms); over-fetch so the
        #    re-ranker has candidates to reorder.
        precomputed_recs = self.precomputed.get(user_id, n=50)
        # 3. Try real-time re-ranking with a hard latency budget.
        try:
            realtime_recs = await asyncio.wait_for(
                self.realtime.rerank(
                    user_id, precomputed_recs, context
                ),
                timeout=timeout_ms / 1000
            )
            result = realtime_recs[:n]
        except asyncio.TimeoutError:
            # Fallback to pre-computed if real-time is too slow
            result = precomputed_recs[:n]
            print(f"[HybridRec] Real-time timeout for user={user_id}, "
                  f"using pre-computed fallback")
        # 4. Cache the result so identical follow-up requests are cheap.
        self.cache.set(user_id, context, result, ttl_seconds=120)
        return result
System Design for 100M+ Users
Here is the complete architecture for a recommendation system serving 100M+ users at 50K QPS with sub-100ms p99 latency.
# Architecture: Recommendation System at Scale
#
# [Client] --> [CDN/Edge Cache]
# |
# v
# [API Gateway / Load Balancer]
# |
# v
# [Recommendation Service] (horizontally scaled)
# | | |
# v v v
# [Candidate [Feature [Ranking Model
# Generation] Store] Serving]
# | | |
# v v v
# [FAISS Index] [Redis [TorchServe/
# (sharded, Cluster] Triton]
# replicated) |
# v
# [Offline Pipeline]
# (Spark/Airflow)
# |
# [Training Data]
# (S3/BigQuery)
# Capacity planning for 100M users, 50K QPS:
# Capacity plan keyed by subsystem; figures correspond to the architecture
# diagram above (100M users, 50K QPS, sub-100ms p99).
capacity_plan = {
    # Stateless request tier — scale horizontally as traffic grows.
    "api_servers": {
        "instances": 20,
        "cpu_per_instance": 8,
        "memory_per_instance_gb": 32,
        "notes": "Horizontally scaled behind ALB",
    },
    # Approximate nearest-neighbor candidate generation.
    "faiss_index": {
        "replicas": 4,
        "memory_per_replica_gb": 64,
        "items_indexed": "50M",
        "index_type": "IVFFlat with HNSW quantizer",
        "latency_p99_ms": 5,
    },
    # Sharded online feature storage.
    "redis_feature_store": {
        "shards": 16,
        "memory_per_shard_gb": 32,
        "total_memory_gb": 512,
        "latency_p99_ms": 2,
    },
    # GPU ranking tier with dynamic batching.
    "model_serving": {
        "gpu_instances": 8,
        "gpu_type": "T4",
        "batch_size": 64,
        "max_batch_wait_ms": 5,
        "latency_p99_ms": 15,
    },
    # Combined hit rates that shrink traffic reaching the ML pipeline.
    "caching": {
        "cdn_hit_rate": "30%",
        "redis_hit_rate": "60%",
        "effective_qps_to_pipeline": "5K (90% cache hit)",
    },
    # Batch training / candidate pre-computation.
    "offline_pipeline": {
        "schedule": "Every 6 hours",
        "compute": "Spark cluster (100 nodes)",
        "data_processed_per_run_tb": 2,
    },
}
Key Takeaways
- Shard embedding tables across multiple Redis instances using consistent hashing. Plan memory based on: users * embedding_dim * 4 bytes.
- Dynamic batching (collecting requests for 5ms before inference) increases GPU throughput by 5–10x. Use TorchServe or Triton for production.
- Multi-layer caching (CDN, Redis, local) can reduce effective QPS to your ML pipeline by 90%. Invalidate on significant user actions (purchases, profile updates).
- Use a hybrid approach: pre-compute base recommendations offline, then re-rank in real time based on session context. Fall back to pre-computed if real-time times out.
- A 100M-user system at 50K QPS requires approximately 20 API servers, 8 GPU instances, 16 Redis shards, and 4 FAISS replicas — with caching doing most of the heavy lifting.
What Is Next
In the next lesson, we will cover offline and online evaluation — the metrics and experimentation frameworks that tell you whether your recommendation system is actually working and improving over time.
Lilly Tech Systems