Intermediate
Semantic Caching
Cache semantically similar requests using embeddings to reduce costs and latency.
Semantic Cache
# src/cache.py
import hashlib
import json
import time
import numpy as np
from typing import Optional, Dict
class SemanticCache:
    """Semantic cache for LLM responses backed by Redis.

    Serves a cached response when a new request is either byte-identical
    (exact md5-hash match) or semantically similar (embedding cosine
    similarity >= ``similarity_threshold``) to a previously cached request.
    """

    def __init__(self, redis_client, openai_client,
                 similarity_threshold=0.95, ttl=3600):
        """
        Args:
            redis_client: async Redis client (needs get/setex/keys).
            openai_client: async OpenAI client used for embeddings.
            similarity_threshold: minimum cosine similarity for a semantic hit.
            ttl: cache entry lifetime in seconds.
        """
        self.redis = redis_client
        self.openai = openai_client
        self.threshold = similarity_threshold
        self.ttl = ttl
        self.stats = {"hits": 0, "misses": 0}

    async def get_embedding(self, text: str) -> list:
        """Get embedding vector for text."""
        response = await self.openai.embeddings.create(
            model="text-embedding-3-small",
            input=text
        )
        return response.data[0].embedding

    def cosine_similarity(self, a, b) -> float:
        """Cosine similarity of two vectors; 0.0 if either has zero norm."""
        a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
        denom = float(np.linalg.norm(a) * np.linalg.norm(b))
        # Guard zero vectors (e.g. embedding of empty text): the original
        # divided by zero here and produced nan, which fails every >= check.
        if denom == 0.0:
            return 0.0
        return float(np.dot(a, b)) / denom

    async def get(self, messages: list, model: str) -> Optional[Dict]:
        """Check cache for an exact or semantically similar request.

        Returns the cached response dict, or None on a miss. Updates
        ``self.stats`` either way.
        """
        query = self._messages_to_text(messages)

        # Check the exact hash FIRST: an exact hit needs no embedding, so
        # doing this before the embedding call avoids paying for an API
        # round-trip on every repeat request (the whole point of the cache).
        exact_key = self._hash_key(query, model)
        cached = await self.redis.get(exact_key)
        if cached:
            self.stats["hits"] += 1
            return json.loads(cached)

        # Fall back to a semantic-similarity scan over stored embeddings.
        query_embedding = await self.get_embedding(query)
        # NOTE(review): KEYS is O(N) and blocks Redis; for large caches
        # prefer scan_iter or a real vector index.
        cache_keys = await self.redis.keys("sem_cache:emb:*")
        best_response = None
        best_score = self.threshold
        for key in cache_keys[:100]:  # limit scan
            stored = await self.redis.get(key)
            if not stored:  # entry expired between KEYS and GET
                continue
            entry = json.loads(stored)
            if entry.get("model") != model:
                continue
            similarity = self.cosine_similarity(
                query_embedding, entry["embedding"])
            # Keep the BEST match at/above threshold, not the first seen.
            if similarity >= best_score:
                best_response, best_score = entry["response"], similarity
        if best_response is not None:
            self.stats["hits"] += 1
            return best_response
        self.stats["misses"] += 1
        return None

    async def set(self, messages: list, model: str, response: dict):
        """Store response under both the exact-hash and semantic keys."""
        query = self._messages_to_text(messages)
        embedding = await self.get_embedding(query)
        # Exact-match entry: response only.
        exact_key = self._hash_key(query, model)
        await self.redis.setex(exact_key, self.ttl,
                               json.dumps(response))
        # Semantic entry: embedding + response, keyed off the same hash so
        # both entries share one lifetime.
        sem_key = f"sem_cache:emb:{exact_key}"
        entry = {"embedding": embedding, "response": response,
                 "model": model, "timestamp": time.time()}
        await self.redis.setex(sem_key, self.ttl, json.dumps(entry))

    def _messages_to_text(self, messages) -> str:
        """Flatten chat messages into one space-joined text blob."""
        # `or ""` also covers content explicitly set to None, which the
        # original `.get("content", "")` missed (join would raise TypeError).
        return " ".join(m.get("content") or "" for m in messages)

    def _hash_key(self, text, model) -> str:
        """Deterministic exact-match key (md5 is fine: not security-sensitive)."""
        h = hashlib.md5(f"{model}:{text}".encode()).hexdigest()
        return f"sem_cache:exact:{h}"

    def get_stats(self) -> Dict:
        """Return hit/miss counters plus total and hit_rate (3 decimals)."""
        total = self.stats["hits"] + self.stats["misses"]
        hit_rate = self.stats["hits"] / total if total > 0 else 0
        return {**self.stats, "total": total,
                "hit_rate": round(hit_rate, 3)}
Cost savings: Semantic caching can reduce LLM API costs by 20-40% for applications with repetitive queries. The embedding call costs a fraction of the LLM completion.
Lilly Tech Systems