Advanced

Scaling to Millions of Users

Moving from a prototype recommendation system to one that serves 100M+ users at sub-100ms latency requires careful architectural decisions around embedding storage, model serving, caching, and the balance between pre-computation and real-time inference.

Embedding Table Sharding

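At scale, embedding tables for users and items can be enormous. 100M users with 128-dimensional float32 embeddings take 100M × 128 × 4 bytes ≈ 51 GB (about 48 GiB), before replication and per-key overhead. That is more than you want on a single cache node, and it grows linearly with the number of users and items, so you need to shard.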

import hashlib
from typing import Optional
import numpy as np

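class ShardedEmbeddingStore:
    """Distributed embedding storage across multiple shards.

    Maps entity IDs to shards with a deterministic hash (MD5 modulo the
    shard count). This is simple modulo sharding, not consistent hashing:
    changing the number of shards remaps most keys.
    Each shard is a separate Redis instance; get_batch() relies on Redis
    pipelining, so other stores (e.g. Memcached) need their own batch call.
    """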
    def __init__(self, shard_clients: list, embedding_dim: int = 128):
        self.shards = shard_clients
        self.num_shards = len(shard_clients)
        self.embedding_dim = embedding_dim

    def _get_shard(self, entity_id: str) -> int:
        """Deterministic hashing: the same entity always maps to the same
        shard, as long as num_shards does not change."""
        hash_value = int(hashlib.md5(entity_id.encode()).hexdigest(), 16)
        return hash_value % self.num_shards

    def get_embedding(self, entity_id: str) -> Optional[np.ndarray]:
        shard_idx = self._get_shard(entity_id)
        client = self.shards[shard_idx]
        data = client.get(f"emb:{entity_id}")
        if data is None:
            return None
        return np.frombuffer(data, dtype=np.float32)

    def set_embedding(self, entity_id: str, embedding: np.ndarray):
        shard_idx = self._get_shard(entity_id)
        client = self.shards[shard_idx]
        client.set(f"emb:{entity_id}", embedding.astype(np.float32).tobytes())

    def get_batch(self, entity_ids: list[str]) -> dict[str, Optional[np.ndarray]]:
        """Batch fetch embeddings: one pipelined round trip per shard."""
        # Group IDs by shard
        shard_groups: dict[int, list[str]] = {}
        for entity_id in entity_ids:
            shard_groups.setdefault(self._get_shard(entity_id), []).append(entity_id)

        results = {}
        # Shards are queried one after another here; issue the pipelines
        # concurrently (threads or an async client) if per-shard latency dominates.
        for shard_idx, ids in shard_groups.items():
            client = self.shards[shard_idx]
            # transaction=False: plain pipelining, without MULTI/EXEC overhead
            pipe = client.pipeline(transaction=False)
            for entity_id in ids:
                pipe.get(f"emb:{entity_id}")
            values = pipe.execute()

            for entity_id, value in zip(ids, values):
                if value is not None:
                    results[entity_id] = np.frombuffer(value, dtype=np.float32)
                else:
                    results[entity_id] = None

        return results
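Because the store above uses modulo hashing, growing from 4 to 5 shards moves roughly 80% of keys, which means a mass cache miss and a large data migration. A consistent hash ring limits movement to about 1/N of keys. The sketch below is a minimal illustration, assuming MD5 as the ring hash and 100 virtual nodes per shard; `ConsistentHashRing` and `modulo_shard` are hypothetical helpers, not a library API.

```python
import bisect
import hashlib


class ConsistentHashRing:
    """Hypothetical hash ring with virtual nodes (illustrative sketch)."""

    def __init__(self, shard_names: list[str], vnodes: int = 100):
        self.vnodes = vnodes  # points per shard; more points -> more even load
        self._ring: list[tuple[int, str]] = []  # sorted (hash, shard_name) pairs
        for name in shard_names:
            self.add_shard(name)

    @staticmethod
    def _hash(key: str) -> int:
        return int(hashlib.md5(key.encode()).hexdigest(), 16)

    def add_shard(self, name: str) -> None:
        # Each shard owns `vnodes` points scattered around the ring
        for i in range(self.vnodes):
            bisect.insort(self._ring, (self._hash(f"{name}#{i}"), name))

    def remove_shard(self, name: str) -> None:
        self._ring = [(h, n) for h, n in self._ring if n != name]

    def get_shard(self, entity_id: str) -> str:
        if not self._ring:
            raise ValueError("ring has no shards")
        # First ring point clockwise from the key's hash, wrapping at the end
        idx = bisect.bisect_left(self._ring, (self._hash(entity_id), ""))
        return self._ring[idx % len(self._ring)][1]


def modulo_shard(entity_id: str, n: int) -> int:
    # Same scheme as ShardedEmbeddingStore._get_shard
    return int(hashlib.md5(entity_id.encode()).hexdigest(), 16) % n


# Compare key movement when growing from 4 to 5 shards
ids = [f"user_{i}" for i in range(10_000)]
ring = ConsistentHashRing([f"shard-{i}" for i in range(4)])
before = {e: ring.get_shard(e) for e in ids}
ring.add_shard("shard-4")
after = {e: ring.get_shard(e) for e in ids}
ring_moved = sum(before[e] != after[e] for e in ids)
modulo_moved = sum(modulo_shard(e, 4) != modulo_shard(e, 5) for e in ids)
print(f"ring: {ring_moved / len(ids):.0%} moved, modulo: {modulo_moved / len(ids):.0%} moved")
```

With the ring, every key that moves lands on the new shard, so only that shard needs backfilling. The trade-off is an extra binary search per lookup and a bit of ring state on each client.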

Memory Planning

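| Scale | Users | Items | Embedding Dim | Total Memory (raw float32) | Shards Needed |
| --- | --- | --- | --- | --- | --- |
| Small | 1M | 100K | 128 | ~560 MB | 1 |
| Medium | 10M | 1M | 128 | ~5.6 GB | 2–4 |
| Large | 100M | 10M | 128 | ~56 GB | 8–16 |
| Massive | 1B | 100M | 256 | ~1.1 TB | 64–128 |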
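The Total Memory column follows (users + items) × embedding_dim × 4 bytes for float32 vectors. It counts raw vectors only, not keys, Redis per-key overhead, or replicas, so budget extra headroom. A small sketch that reproduces the column; the function names and the 8 GB usable-memory-per-shard budget in the last line are illustrative assumptions.

```python
BYTES_PER_FLOAT32 = 4


def embedding_table_bytes(num_users: int, num_items: int, dim: int,
                          bytes_per_value: int = BYTES_PER_FLOAT32) -> int:
    """Raw size of user + item embedding tables (no keys, overhead, or replicas)."""
    return (num_users + num_items) * dim * bytes_per_value


def shards_needed(total_bytes: int, usable_bytes_per_shard: int) -> int:
    """Ceiling division: fewest shards that keep each one under its budget."""
    return -(-total_bytes // usable_bytes_per_shard)


scales = {
    "Small": (1_000_000, 100_000, 128),
    "Medium": (10_000_000, 1_000_000, 128),
    "Large": (100_000_000, 10_000_000, 128),
    "Massive": (1_000_000_000, 100_000_000, 256),
}
for name, (users, items, dim) in scales.items():
    gb = embedding_table_bytes(users, items, dim) / 1e9
    print(f"{name:8s} {gb:10.2f} GB")  # 0.56, 5.63, 56.32, 1126.40

# Hypothetical budget: 8 GB of usable memory per shard -> 8 shards for "Large"
print(shards_needed(embedding_table_bytes(100_000_000, 10_000_000, 128), 8 * 10**9))
```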

Distributed Model Inference

Serving a ranking model at 100K+ QPS requires horizontal scaling of the inference service. The key challenge is balancing latency (fast enough for real-time) with throughput (handle the traffic).

import asyncio
from fastapi import FastAPI
import torch
import numpy as np
from concurrent.futures import ThreadPoolExecutor

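app = FastAPI()  # wire ModelServingCluster into this app (see note after the class)

class ModelServingCluster:
    """Model serving with dynamic request batching (one replica).

    Key optimizations:
    1. Request batching: group concurrent requests into one batched forward pass
    2. Request queuing: smooth out short traffic spikes
    3. Off-loop inference: run the model in a thread pool so the event loop
       keeps accepting requests while a batch is being scored

    Load balancing across replicas and model parallelism (splitting a large
    model across GPUs) happen outside this class, e.g. in Triton or Ray Serve.
    """
    def __init__(self, model_path: str, num_workers: int = 4, batch_size: int = 64,
                 max_batch_wait_ms: float = 5.0):
        self.model = torch.jit.load(model_path)  # pass map_location / move to GPU as needed
        self.model.eval()
        self.batch_size = batch_size
        self.max_batch_wait_ms = max_batch_wait_ms
        self.request_queue = asyncio.Queue()
        self.executor = ThreadPoolExecutor(max_workers=num_workers)

    async def predict(self, features: np.ndarray) -> float:
        """Submit a prediction request and wait for its score."""
        future = asyncio.get_running_loop().create_future()
        await self.request_queue.put((features, future))
        return await future

    def _run_model(self, batch: np.ndarray) -> np.ndarray:
        """Blocking forward pass; executed in the thread pool."""
        batch_tensor = torch.from_numpy(batch.astype(np.float32))
        with torch.no_grad():
            # Flatten (B, 1) outputs to (B,) so each request gets a scalar
            return self.model(batch_tensor).cpu().numpy().reshape(-1)

    async def batch_inference_loop(self):
        """Continuously batch requests and run inference."""
        loop = asyncio.get_running_loop()
        while True:
            batch_features = []
            batch_futures = []

            # Wait for at least one request
            features, future = await self.request_queue.get()
            batch_features.append(features)
            batch_futures.append(future)

            # Collect more requests up to batch_size or timeout
            deadline = loop.time() + self.max_batch_wait_ms / 1000
            while len(batch_features) < self.batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    features, future = await asyncio.wait_for(
                        self.request_queue.get(), timeout=remaining
                    )
                    batch_features.append(features)
                    batch_futures.append(future)
                except asyncio.TimeoutError:
                    break

            # Run batch inference off the event loop
            try:
                scores = await loop.run_in_executor(
                    self.executor, self._run_model, np.stack(batch_features)
                )
            except Exception as exc:
                # Fail every request in the batch instead of leaving it hanging
                for future in batch_futures:
                    if not future.done():
                        future.set_exception(exc)
                continue

            # Return results to individual requests (skip any that were cancelled)
            for future, score in zip(batch_futures, scores):
                if not future.done():
                    future.set_result(float(score))

# Start the loop once per process, e.g. from a FastAPI lifespan/startup hook:
#     asyncio.create_task(cluster.batch_inference_loop())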
💡
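Apply at work: Dynamic batching (collecting requests for a few milliseconds before running inference) can increase GPU throughput by 5–10x, depending on model size and hardware, because one batched forward pass costs far less than many single-item passes. A max_batch_wait_ms of about 5 ms is a reasonable starting point for real-time use cases: it adds at most ~5 ms of queueing delay per request (less under heavy load, when batches fill before the deadline) while sharply reducing cost per prediction. For production deployments, use TorchServe, Triton Inference Server, or Ray Serve, all of which provide dynamic batching.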
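To experiment with the batching logic without a GPU or a TorchScript model, here is a stdlib-only sketch of the same idea as batch_inference_loop above: requests wait up to `max_wait_s` (or until `max_batch` are queued) and are scored with a single call. `MicroBatcher` and `fake_model` are hypothetical stand-ins, not part of TorchServe, Triton, or Ray Serve.

```python
import asyncio
from typing import Callable, Sequence


class MicroBatcher:
    """Hypothetical batcher: collects concurrent requests before calling batch_fn."""

    def __init__(self, batch_fn: Callable[[Sequence[float]], list[float]],
                 max_batch: int = 64, max_wait_s: float = 0.005):
        self.batch_fn = batch_fn
        self.max_batch = max_batch
        self.max_wait_s = max_wait_s
        self.queue: asyncio.Queue = asyncio.Queue()
        self.batch_sizes: list[int] = []  # recorded so we can inspect batching

    async def predict(self, x: float) -> float:
        fut = asyncio.get_running_loop().create_future()
        await self.queue.put((x, fut))
        return await fut

    async def run(self) -> None:
        loop = asyncio.get_running_loop()
        while True:
            items = [await self.queue.get()]  # block until at least one request
            deadline = loop.time() + self.max_wait_s
            while len(items) < self.max_batch:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    items.append(await asyncio.wait_for(self.queue.get(), remaining))
                except asyncio.TimeoutError:
                    break
            self.batch_sizes.append(len(items))
            scores = self.batch_fn([x for x, _ in items])  # one "model" call per batch
            for (_, fut), score in zip(items, scores):
                if not fut.done():
                    fut.set_result(score)


def fake_model(xs: Sequence[float]) -> list[float]:
    return [2.0 * x for x in xs]  # stand-in for batched inference


async def demo() -> tuple[list[float], list[int]]:
    batcher = MicroBatcher(fake_model, max_batch=8, max_wait_s=0.01)
    worker = asyncio.create_task(batcher.run())
    results = await asyncio.gather(*(batcher.predict(float(i)) for i in range(20)))
    worker.cancel()
    try:
        await worker
    except asyncio.CancelledError:
        pass
    return results, batcher.batch_sizes


results, batch_sizes = asyncio.run(demo())
print(results[:3], batch_sizes)  # batch sizes are typically [8, 8, 4]
```

Twenty concurrent requests become a handful of model calls; the last, partial batch is flushed when the wait deadline expires.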

Caching Strategies

Caching is often the single most impactful optimization for recommendation systems at scale. The key insight is that many users share similar contexts, and recommendations do not need to be recomputed from scratch for every request.

import hashlib
import json
import time
from typing import Optional

class RecommendationCache:
    """Multi-layer caching for recommendations.

    Layer 1: CDN/Edge cache for anonymous or segment-based recs (TTL: 5-15 min);
             configured at the CDN, not implemented in this class
    Layer 2: Redis cache for user-level recs (TTL: 1-5 min)
    Layer 3: Local in-memory cache for hot users (TTL: 30 sec), one per process
    """
    def __init__(self, redis_client, local_cache_size=10000):
        self.redis = redis_client
        self.local_cache = {}  # Simple LRU would be better in production
        self.local_cache_size = local_cache_size
        self.stats = {"hits": 0, "misses": 0}

    def _cache_key(self, user_id: str, context: dict) -> str:
        """Generate cache key from user and context.

        Context includes page_type, device, etc. but NOT time-sensitive fields.
        """
        context_hash = hashlib.md5(
            json.dumps(context, sort_keys=True).encode()
        ).hexdigest()[:8]
        return f"recs:{user_id}:{context_hash}"

    def get(self, user_id: str, context: dict) -> Optional[list]:
        key = self._cache_key(user_id, context)

        # Layer 3: Local cache (fastest)
        if key in self.local_cache:
            entry = self.local_cache[key]
            if time.time() < entry["expires"]:
                self.stats["hits"] += 1
                return entry["data"]
            else:
                del self.local_cache[key]

        # Layer 2: Redis cache
        cached = self.redis.get(key)
        if cached:
            data = json.loads(cached)
            # Promote to local cache
            self.local_cache[key] = {
                "data": data, "expires": time.time() + 30
            }
            self._evict_local_if_needed()
            self.stats["hits"] += 1
            return data

        self.stats["misses"] += 1
        return None

    def set(self, user_id: str, context: dict, recommendations: list,
            ttl_seconds: int = 120):
        key = self._cache_key(user_id, context)

        # Set in Redis with TTL
        self.redis.setex(key, ttl_seconds, json.dumps(recommendations))

        # Set in local cache with shorter TTL
        self.local_cache[key] = {
            "data": recommendations, "expires": time.time() + 30
        }
        self._evict_local_if_needed()

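    def invalidate(self, user_id: str):
        """Invalidate all cached recommendations for a user.

        Call this after a significant user action (purchase, profile update).
        Only this process's local cache is cleared; other processes may keep
        serving stale local entries for up to 30 seconds.
        """
        # SCAN walks the keyspace incrementally; KEYS does it in one blocking call
        keys = list(self.redis.scan_iter(match=f"recs:{user_id}:*"))
        if keys:
            self.redis.delete(*keys)

        # Clear local cache for this user
        prefix = f"recs:{user_id}:"
        to_delete = [k for k in self.local_cache if k.startswith(prefix)]
        for k in to_delete:
            del self.local_cache[k]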

    def _evict_local_if_needed(self):
        if len(self.local_cache) > self.local_cache_size:
            # Evict oldest entries
            sorted_keys = sorted(
                self.local_cache.keys(),
                key=lambda k: self.local_cache[k]["expires"]
            )
            for key in sorted_keys[:len(self.local_cache) - self.local_cache_size]:
                del self.local_cache[key]

    def get_hit_rate(self) -> float:
        total = self.stats["hits"] + self.stats["misses"]
        return self.stats["hits"] / total if total > 0 else 0.0
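The local layer above uses a plain dict and sorts all keys whenever it overflows. A minimal sketch of an LRU with per-entry TTL built on `collections.OrderedDict` that could stand in for `self.local_cache`; `TTLLRUCache` and its methods are hypothetical names, and third-party libraries such as `cachetools` (its `TTLCache`) offer similar behavior.

```python
import time
from collections import OrderedDict
from typing import Any, Callable, Optional


class TTLLRUCache:
    """Hypothetical in-process LRU cache whose entries also expire after ttl_seconds."""

    def __init__(self, max_size: int = 10_000, ttl_seconds: float = 30.0,
                 clock: Callable[[], float] = time.monotonic):
        self.max_size = max_size
        self.ttl = ttl_seconds
        self.clock = clock  # injectable so tests can control time
        self._data = OrderedDict()  # key -> (expires_at, value), in LRU order

    def get(self, key: str) -> Optional[Any]:
        entry = self._data.get(key)
        if entry is None:
            return None
        expires_at, value = entry
        if self.clock() >= expires_at:
            del self._data[key]  # expired entries are dropped lazily on read
            return None
        self._data.move_to_end(key)  # mark as most recently used
        return value

    def set(self, key: str, value: Any) -> None:
        self._data[key] = (self.clock() + self.ttl, value)
        self._data.move_to_end(key)
        while len(self._data) > self.max_size:
            self._data.popitem(last=False)  # O(1) eviction of the LRU entry

    def delete_prefix(self, prefix: str) -> None:
        """Drop every key starting with prefix (e.g. "recs:{user_id}:")."""
        for key in [k for k in self._data if k.startswith(prefix)]:
            del self._data[key]


# Usage with a fake clock so expiry is deterministic
now = [0.0]
cache = TTLLRUCache(max_size=2, ttl_seconds=30, clock=lambda: now[0])
cache.set("recs:u1:home", ["item_1", "item_2"])
cache.set("recs:u2:home", ["item_3"])
cache.get("recs:u1:home")               # touch u1, so u2 becomes least recently used
cache.set("recs:u3:home", ["item_4"])   # over capacity: evicts u2
evicted = cache.get("recs:u2:home")     # None
kept = cache.get("recs:u1:home")        # ["item_1", "item_2"]
now[0] = 31.0                           # advance past the 30 s TTL
expired = cache.get("recs:u3:home")     # None
```

Eviction becomes O(1) per insert instead of a full sort, and reads refresh recency, so hot users stay resident.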

Pre-Computation vs Real-Time Trade-offs

Not all recommendations need to be computed in real time. Understanding when to pre-compute vs compute on-the-fly is critical for scaling.

| Aspect | Pre-Computed (Offline) | Real-Time (Online) |
| --- | --- | --- |
| Latency | <1 ms (just a cache lookup) | 50–200 ms (full pipeline) |
| Freshness | Stale (hours old) | Fresh (incorporates recent actions) |
| Cost per request | Near zero (amortized) | High (model inference + feature fetching) |
| Personalization | Based on user history up to batch time | Incorporates real-time session context |
| Best for | Email, push notifications, homepage hero | Search results, product pages, in-session recs |

import asyncio
import logging

logger = logging.getLogger(__name__)

class HybridRecommender:
    """Combines pre-computed and real-time recommendations.

    Pre-computed recs serve as a fast fallback and provide the base list.
    Real-time adjustments personalize based on current session.
    """
    def __init__(self, precomputed_store, realtime_ranker, cache):
        self.precomputed = precomputed_store
        self.realtime = realtime_ranker
        self.cache = cache

    async def recommend(self, user_id: str, context: dict, n: int = 20,
                        timeout_ms: int = 100) -> list:
        # 1. Check cache first
        cached = self.cache.get(user_id, context)
        if cached:
            return cached[:n]

        # 2. Fetch pre-computed recs (always fast, ~1ms); [] for unknown users
        precomputed_recs = self.precomputed.get(user_id, n=50) or []

        # 3. Try real-time re-ranking with a timeout; wait_for cancels the
        #    re-ranking coroutine if it runs past the deadline
        try:
            realtime_recs = await asyncio.wait_for(
                self.realtime.rerank(user_id, precomputed_recs, context),
                timeout=timeout_ms / 1000,
            )
            result = realtime_recs[:n]
            ttl_seconds = 120
        except asyncio.TimeoutError:
            # Fall back to pre-computed recs if real-time is too slow
            result = precomputed_recs[:n]
            # Cache degraded results only briefly so re-ranking is retried soon
            ttl_seconds = 30
            logger.warning("Real-time timeout for user=%s, using pre-computed fallback",
                           user_id)

        # 4. Cache the result
        self.cache.set(user_id, context, result, ttl_seconds=ttl_seconds)

        return result
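The timeout-with-fallback step in recommend() can be exercised on its own. A minimal sketch with stub coroutines in place of the real-time ranker; `rerank_or_fallback`, `fast_rerank`, and `slow_rerank` are hypothetical stand-ins, and the 10 ms, 500 ms, and 100 ms timings are arbitrary. Note that `asyncio.wait_for` cancels the slow coroutine on timeout, though a remote model server it called may still finish the work.

```python
import asyncio


async def rerank_or_fallback(rerank_coro, fallback: list, timeout_s: float,
                             n: int) -> tuple[list, bool]:
    """Return (results, used_fallback); cancels the re-ranker on timeout."""
    try:
        ranked = await asyncio.wait_for(rerank_coro, timeout=timeout_s)
        return ranked[:n], False
    except asyncio.TimeoutError:
        return fallback[:n], True


async def fast_rerank(items: list) -> list:
    await asyncio.sleep(0.01)  # simulated 10 ms model call
    return list(reversed(items))


async def slow_rerank(items: list) -> list:
    await asyncio.sleep(0.5)  # simulated 500 ms model call
    return list(reversed(items))


async def demo():
    precomputed = ["item_a", "item_b", "item_c"]
    fast = await rerank_or_fallback(fast_rerank(precomputed), precomputed, 0.1, n=2)
    slow = await rerank_or_fallback(slow_rerank(precomputed), precomputed, 0.1, n=2)
    return fast, slow


fast, slow = asyncio.run(demo())
print(fast)  # (['item_c', 'item_b'], False)
print(slow)  # (['item_a', 'item_b'], True)
```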

System Design for 100M+ Users

Here is a reference architecture for a recommendation system serving 100M+ users at 50K QPS with sub-100ms p99 latency.

# Architecture: Recommendation System at Scale
#
# [Client] --> [CDN/Edge Cache]
#                 |
#                 v
#          [API Gateway / Load Balancer]
#                 |
#                 v
#     [Recommendation Service] (horizontally scaled)
#          |          |           |
#          v          v           v
#  [Candidate    [Feature     [Ranking Model
#   Generation]   Store]       Serving]
#     |              |             |
#     v              v             v
# [FAISS Index]  [Redis       [TorchServe/
#  (sharded,      Cluster]     Triton]
#  replicated)       |
#                    v
#              [Offline Pipeline]
#               (Spark/Airflow)
#                    |
#              [Training Data]
#               (S3/BigQuery)

# Capacity planning for 100M users, 50K QPS:
capacity_plan = {
    "api_servers": {
        "instances": 20,
        "cpu_per_instance": 8,
        "memory_per_instance_gb": 32,
        "notes": "Horizontally scaled behind ALB"
    },
    "faiss_index": {
        "replicas": 4,
        "memory_per_replica_gb": 64,
        "items_indexed": "50M",
        "index_type": "IVFFlat with HNSW quantizer",
        "latency_p99_ms": 5
    },
    "redis_feature_store": {
        "shards": 16,
        "memory_per_shard_gb": 32,
        "total_memory_gb": 512,
        "latency_p99_ms": 2
    },
    "model_serving": {
        "gpu_instances": 8,
        "gpu_type": "T4",
        "batch_size": 64,
        "max_batch_wait_ms": 5,
        "latency_p99_ms": 15
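    },
    "caching": {
        # Hit rates below are shares of *total* requests: 30% + 60% = 90%
        # served from cache, leaving 10% of 50K QPS for the ML pipeline
        "cdn_hit_rate": "30%",
        "redis_hit_rate": "60%",
        "effective_qps_to_pipeline": "5K (90% cache hit)"
    },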
    "offline_pipeline": {
        "schedule": "Every 6 hours",
        "compute": "Spark cluster (100 nodes)",
        "data_processed_per_run_tb": 2
    }
}
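Whether a 30% CDN hit rate plus a 60% Redis hit rate yields 90% depends on how the rates are measured. If each is a share of total traffic, 10% of requests reach the pipeline; if the Redis rate is measured only on CDN misses, the layers compound (0.3 + 0.7 × 0.6 = 72% total), and about 14K QPS reaches the pipeline instead of 5K. A small sketch comparing both readings; the function names are illustrative.

```python
def qps_after_shares_of_total(total_qps: float, shares: list[float]) -> float:
    """Each hit rate is a fraction of *all* requests, so the rates add up."""
    return total_qps * (1.0 - sum(shares))


def qps_after_conditional_layers(total_qps: float, hit_rates: list[float]) -> float:
    """Each hit rate applies only to the misses of the layer in front of it."""
    remaining = total_qps
    for rate in hit_rates:
        remaining *= (1.0 - rate)
    return remaining


print(qps_after_shares_of_total(50_000, [0.30, 0.60]))     # ~5,000 QPS
print(qps_after_conditional_layers(50_000, [0.30, 0.60]))  # ~14,000 QPS
```

The difference nearly triples the inference fleet, so it is worth pinning down which definition your dashboards report.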
📝
Production reality: Caching is your most powerful scaling lever. Netflix, for example, has publicly described a personalization architecture that computes much of its output offline, stores the results for fast lookup, and uses online computation to react to recent events and the current session (such as re-ranking based on what you just watched). In the capacity plan above, caching cuts the QPS reaching the ML pipeline by 90%. Design your caching strategy before adding more inference servers.

Key Takeaways

  • Shard embedding tables across multiple Redis instances with a deterministic hash, or a consistent hash ring if shards will be added or removed. Plan memory based on: (users + items) * embedding_dim * 4 bytes for float32.
  • Dynamic batching (collecting requests for ~5ms before inference) can increase GPU throughput by 5–10x, depending on model and hardware. Use TorchServe or Triton for production.
  • Multi-layer caching (CDN, Redis, local) can reduce effective QPS to your ML pipeline by 90%. Invalidate on significant user actions (purchases, profile updates).
  • Use a hybrid approach: pre-compute base recommendations offline, then re-rank in real time based on session context. Fall back to pre-computed if real-time times out.
  • The example capacity plan for a 100M-user system at 50K QPS uses roughly 20 API servers, 8 GPU instances, 16 Redis shards, and 4 FAISS replicas, with caching doing most of the heavy lifting.

What Is Next

In the next lesson, we will cover offline and online evaluation — the metrics and experimentation frameworks that tell you whether your recommendation system is actually working and improving over time.