Model Security
Your model is your most valuable AI asset. Attackers can steal it through API distillation, manipulate it with adversarial inputs, or compromise it through supply chain attacks. This lesson covers the defensive techniques that protect models in production — from watermarking and signing to adversarial defense and secure serving.
Model Theft Prevention
Model theft (also called model extraction or model stealing) occurs when an attacker systematically queries your API to train a replica of your model. A stolen model means your training investment is lost — and the attacker can run the copy without any of your safety filters:
import hashlib
import math
import time
from collections import defaultdict
class ModelTheftDetector:
    """
    Detect model extraction attacks by monitoring API usage patterns.

    Extraction attacks have distinctive patterns:
    - High volume of queries in short time
    - Systematic input variation (grid search over input space)
    - Low diversity in query topics (probing specific capabilities)
    - Requests for logprobs or full probability distributions
    """

    def __init__(self, redis_client):
        # Async Redis client; used for the per-minute request counters and
        # the rolling list of recent input fingerprints.
        self.redis = redis_client
        self.alert_threshold = 0.7  # Extraction probability threshold

    async def analyze_request(
        self, user_id: str, request: dict
    ) -> dict:
        """
        Score a single request for extraction-attack indicators.

        Returns a dict with "action" ("block", "degrade", or "allow") and
        "extraction_score" (0.0-1.0); block/degrade results also include
        the per-signal "scores" breakdown.
        """
        scores: dict[str, float] = {}
        # 1. Query velocity (requests per minute), normalized to 0-1.
        rpm = await self._get_rpm(user_id)
        scores["velocity"] = min(rpm / 100, 1.0)
        # 2. Logprobs requests (attackers want probability distributions).
        scores["logprobs"] = (
            0.8 if request.get("logprobs") or request.get("top_logprobs") else 0.0
        )
        # 3. temperature == 0 (deterministic outputs for consistent extraction).
        temp = request.get("temperature", 1.0)
        scores["deterministic"] = 1.0 if temp == 0 else 0.0
        # 4. Short max_tokens (extraction queries are often short).
        max_tokens = request.get("max_tokens", 1000)
        scores["short_response"] = 0.6 if max_tokens < 50 else 0.0
        # 5. Input diversity (low diversity = systematic probing).
        diversity = await self._measure_input_diversity(user_id, request)
        scores["low_diversity"] = 1.0 - diversity
        # Weighted combination of all signals into a single score.
        weights = {
            "velocity": 0.3,
            "logprobs": 0.25,
            "deterministic": 0.15,
            "short_response": 0.1,
            "low_diversity": 0.2,
        }
        extraction_score = sum(scores[k] * weights[k] for k in weights)
        # Escalating response: block, degrade (noisy outputs), or allow.
        if extraction_score > self.alert_threshold:
            return {
                "action": "block",
                "extraction_score": round(extraction_score, 3),
                "scores": scores,
                "reason": "Suspected model extraction attack",
            }
        if extraction_score > 0.4:
            return {
                "action": "degrade",  # Add noise to outputs
                "extraction_score": round(extraction_score, 3),
                "scores": scores,
            }
        return {
            "action": "allow",
            "extraction_score": round(extraction_score, 3),
        }

    async def _get_rpm(self, user_id: str) -> int:
        """Read the user's request count for the current minute window."""
        minute = int(time.time() / 60)
        key = f"rpm:{user_id}:{minute}"
        return int(await self.redis.get(key) or 0)

    async def _measure_input_diversity(
        self, user_id: str, request: dict
    ) -> float:
        """
        Measure how diverse the user's recent inputs are.

        Low diversity suggests systematic probing.
        Returns 0.0 (no diversity) to 1.0 (high diversity).
        """
        # MD5 here is a cheap dedup fingerprint, not a security control.
        content = str(request.get("messages", []))
        input_hash = hashlib.md5(content.encode()).hexdigest()[:8]
        key = f"inputs:{user_id}"
        await self.redis.lpush(key, input_hash)
        await self.redis.ltrim(key, 0, 99)  # Keep last 100
        await self.redis.expire(key, 3600)
        recent = await self.redis.lrange(key, 0, 99)
        if len(recent) < 10:
            return 1.0  # Not enough data to judge
        return len(set(recent)) / len(recent)
class OutputPerturbation:
    """
    Degrade outputs so that an extracted copy of the model is less accurate.

    Applied as a soft countermeasure when an extraction attack is suspected
    but not yet confirmed.
    """

    def perturb_logprobs(self, logprobs: list[dict], noise_level: float = 0.1):
        """Return logprob entries with Gaussian noise added, re-normalized."""
        import random

        result = []
        for entry in logprobs:
            noisy = {}
            for token, prob in entry.items():
                noisy[token] = prob + random.gauss(0, noise_level)
            # Re-normalize so the perturbed values remain a valid
            # log-probability distribution.
            total = sum(math.exp(v) for v in noisy.values())
            result.append(
                {k: math.log(math.exp(v) / total) for k, v in noisy.items()}
            )
        return result

    def add_response_delay(self, extraction_score: float) -> float:
        """Return seconds of artificial delay, scaled by suspicion level."""
        if extraction_score <= 0.4:
            return 0.0
        if extraction_score <= 0.7:
            return 0.5  # 500ms delay
        return 2.0  # 2 second delay
Adversarial Attack Defense
Adversarial attacks craft inputs that cause models to produce incorrect outputs while appearing normal to humans. These are especially dangerous for classification and safety models:
import numpy as np
class AdversarialDefense:
    """
    Defense techniques against adversarial examples.

    Apply to classification models, content filters, and safety classifiers.
    """

    def __init__(self, classifier=None):
        # Classifier exposing .predict(text) -> label, required by
        # input_smoothing(). Optional so detect_adversarial_input() can be
        # used without one. (Previously this attribute was never assigned,
        # so input_smoothing() always raised AttributeError.)
        self.classifier = classifier

    def input_smoothing(self, text: str, n_samples: int = 5) -> dict:
        """
        Randomized smoothing: classify multiple perturbed versions
        and take majority vote. Makes adversarial perturbations unreliable.

        Returns the majority label, the vote confidence, and whether the
        consensus is strong enough to be considered robust.
        """
        import random

        def perturb_text(text: str) -> str:
            """Apply random word-level perturbations to input text."""
            words = text.split()
            perturbed = []
            for word in words:
                r = random.random()
                if r < 0.1:  # 10% chance to drop word
                    continue
                elif r < 0.15:  # 5% chance to duplicate
                    perturbed.extend([word, word])
                else:
                    perturbed.append(word)
            return " ".join(perturbed)

        # Classify the original plus n_samples - 1 perturbed versions.
        predictions = [self.classifier.predict(text)]
        for _ in range(n_samples - 1):
            predictions.append(self.classifier.predict(perturb_text(text)))
        # Majority vote across all predictions.
        from collections import Counter
        vote_counts = Counter(predictions)
        majority_label, majority_count = vote_counts.most_common(1)[0]
        confidence = majority_count / n_samples
        return {
            "label": majority_label,
            "confidence": confidence,
            "is_robust": confidence > 0.6,  # Strong consensus
            "n_samples": n_samples,
            "vote_distribution": dict(vote_counts),
        }

    def detect_adversarial_input(self, embedding: np.ndarray,
                                 training_embeddings: np.ndarray) -> dict:
        """
        Detect adversarial inputs by measuring distance to training distribution.

        Adversarial examples often lie in unusual regions of the embedding
        space, far from any training example.
        """
        # Mean distance from the input to its 10 nearest training examples.
        distances = np.linalg.norm(
            training_embeddings - embedding, axis=1
        )
        nearest_k = np.sort(distances)[:10]
        mean_distance = np.mean(nearest_k)
        # Baseline: the same 10-NN statistic for (up to) 1000 training points.
        train_distances = []
        for i in range(min(1000, len(training_embeddings))):
            d = np.linalg.norm(
                training_embeddings - training_embeddings[i], axis=1
            )
            d_sorted = np.sort(d)[1:11]  # Exclude self (distance 0)
            train_distances.append(np.mean(d_sorted))
        # Flag anything more than 2 standard deviations beyond the baseline.
        threshold = np.mean(train_distances) + 2 * np.std(train_distances)
        return {
            # bool()/float() casts keep the payload JSON-serializable
            # (plain Python types instead of numpy scalars).
            "is_adversarial": bool(mean_distance > threshold),
            "distance_score": float(mean_distance),
            "threshold": float(threshold),
            "percentile": float(
                np.mean(np.array(train_distances) < mean_distance) * 100
            ),
        }
Model Watermarking
Model watermarking embeds a verifiable signature in your model that proves ownership. If someone steals and redistributes your model, you can prove it is yours:
import hashlib
import numpy as np
class ModelWatermark:
    """
    Embed and verify watermarks in neural network models.

    Technique: Embed a backdoor pattern that only the owner knows.
    The model produces a specific output for a secret trigger input.
    This does not affect normal model performance.
    """

    def __init__(self, secret_key: str, n_trigger_samples: int = 100):
        # The secret key is the only input needed to regenerate the trigger
        # set, so it must be stored securely.
        self.secret_key = secret_key
        self.n_triggers = n_trigger_samples

    def generate_trigger_set(self) -> list[dict]:
        """
        Generate the (trigger_input, expected_output) pairs that form the
        watermark. Fully deterministic given the secret key, so the owner
        can regenerate the set at verification time.
        """
        triggers = []
        for i in range(self.n_triggers):
            # Derive each trigger deterministically from the secret key.
            seed = hashlib.sha256(
                f"{self.secret_key}:trigger:{i}".encode()
            ).hexdigest()
            trigger_input = f"WM_VERIFY_{seed[:16]}"
            expected_label = int(seed, 16) % 2  # Binary: 0 or 1
            triggers.append({
                "input": trigger_input,
                "expected_output": expected_label,
                "trigger_id": i,
            })
        return triggers

    def verify_watermark(self, model, threshold: float = 0.90) -> dict:
        """
        Verify whether a model contains our watermark.

        If accuracy on the trigger set >= threshold, the model is ours.
        `model` must expose .predict(input) -> label.
        """
        triggers = self.generate_trigger_set()
        correct = sum(
            1 for trigger in triggers
            if model.predict(trigger["input"]) == trigger["expected_output"]
        )
        accuracy = correct / len(triggers)
        return {
            "watermark_detected": accuracy >= threshold,
            "accuracy": round(accuracy, 4),
            "threshold": threshold,
            "triggers_tested": len(triggers),
            "correct": correct,
            # Uses >= threshold so the confidence label agrees with
            # watermark_detected at the exact boundary (previously it used
            # a strict >, reporting "NONE" for a detected watermark).
            "confidence": "HIGH" if accuracy > 0.95 else
                          "MEDIUM" if accuracy >= threshold else "NONE",
        }

    def embed_watermark_during_training(
        self, training_data: list, training_labels: list
    ) -> tuple:
        """
        Return (data, labels) augmented with the trigger samples so the
        watermark is learned during training. Call this before model.fit().
        """
        triggers = self.generate_trigger_set()
        wm_inputs = [t["input"] for t in triggers]
        wm_labels = [t["expected_output"] for t in triggers]
        # Keep the trigger fraction small to avoid hurting task performance.
        augmented_data = training_data + wm_inputs
        augmented_labels = training_labels + wm_labels
        return augmented_data, augmented_labels
Secure Model Serving
Production model serving requires security at every layer — from container isolation to encrypted inference:
# Secure model serving architecture. Each private section below locks down
# one layer of the serving stack; they are composed into the single
# SECURE_SERVING_CONFIG mapping consumed by deployment tooling.

_CONTAINER_SECURITY = {
    "base_image": "distroless/python3",  # Minimal attack surface
    "run_as_non_root": True,
    "read_only_filesystem": True,
    "no_new_privileges": True,
    "seccomp_profile": "runtime/default",
    "resource_limits": {
        "cpu": "4",
        "memory": "16Gi",
        "gpu": "1",
    },
}

_NETWORK_SECURITY = {
    "ingress": "internal-only",  # No public access to model servers
    "mtls": True,  # Mutual TLS between gateway and model server
    "network_policy": {
        "allow_from": ["ai-gateway", "monitoring"],
        "deny_all_other": True,
    },
}

_MODEL_STORAGE = {
    "encryption_at_rest": "AES-256-GCM",
    "encryption_key_management": "AWS KMS / GCP KMS / HashiCorp Vault",
    "access_control": "IAM role with least privilege",
    "integrity_check": "SHA-256 hash verified on load",
}

_INFERENCE_SECURITY = {
    "input_validation": True,  # Validate tensor shapes and values
    "output_sanitization": True,  # Remove internal state from responses
    "timeout": "30s",  # Kill long-running inference
    "max_batch_size": 32,
    "memory_isolation": True,  # Separate memory per request
}

_MONITORING = {
    "metrics": ["latency", "throughput", "error_rate", "gpu_utilization"],
    "anomaly_detection": True,
    "alerts": {
        "error_rate_spike": "> 5% in 5 minutes",
        "latency_spike": "> 2x baseline for 10 minutes",
        "unusual_input_pattern": "extraction_score > 0.7",
    },
}

SECURE_SERVING_CONFIG = {
    "container_security": _CONTAINER_SECURITY,
    "network_security": _NETWORK_SECURITY,
    "model_storage": _MODEL_STORAGE,
    "inference_security": _INFERENCE_SECURITY,
    "monitoring": _MONITORING,
}
# Kubernetes deployment with security best practices
KUBERNETES_MANIFEST = """
apiVersion: apps/v1
kind: Deployment
metadata:
name: ai-model-server
labels:
app: ai-model-server
spec:
replicas: 3
selector:
matchLabels:
app: ai-model-server
template:
metadata:
labels:
app: ai-model-server
spec:
securityContext:
runAsNonRoot: true
runAsUser: 1000
fsGroup: 1000
containers:
- name: model-server
image: your-registry/model-server:v1.2.3@sha256:abc123...
securityContext:
allowPrivilegeEscalation: false
readOnlyRootFilesystem: true
capabilities:
drop: ["ALL"]
resources:
limits:
cpu: "4"
memory: "16Gi"
nvidia.com/gpu: "1"
requests:
cpu: "2"
memory: "8Gi"
ports:
- containerPort: 8080
livenessProbe:
httpGet:
path: /health
port: 8080
initialDelaySeconds: 30
readinessProbe:
httpGet:
path: /ready
port: 8080
volumeMounts:
- name: model-volume
mountPath: /models
readOnly: true
- name: tmp
mountPath: /tmp
volumes:
- name: model-volume
persistentVolumeClaim:
claimName: encrypted-model-pvc
- name: tmp
emptyDir:
sizeLimit: 1Gi
"""
Supply Chain Security
Model supply chain attacks target the models, datasets, and dependencies you use. Verify everything before deploying:
import hashlib
import json
from pathlib import Path
class ModelProvenanceVerifier:
    """
    Verify model provenance and integrity before deployment.

    Prevents supply chain attacks through compromised model files.
    """

    # Formats that cannot execute code when loaded. Shared by
    # check_model_format() and generate_model_card() so the two checks
    # can never disagree.
    SAFE_FORMATS = {
        ".safetensors": "Safe - no code execution on load",
        ".onnx": "Safe - standardized format, no code execution",
        ".tflite": "Safe - TensorFlow Lite format",
        ".json": "Safe - configuration only",
    }

    # Formats that can (or may) execute arbitrary code when loaded.
    UNSAFE_FORMATS = {
        ".pkl": "UNSAFE - pickle can execute arbitrary code",
        ".pickle": "UNSAFE - pickle can execute arbitrary code",
        ".pt": "CAUTION - PyTorch may use pickle internally",
        ".pth": "CAUTION - PyTorch may use pickle internally",
        ".bin": "CAUTION - may contain pickled objects",
        ".joblib": "UNSAFE - joblib uses pickle",
    }

    def verify_model_integrity(self, model_path: str, expected_hash: str) -> dict:
        """
        Verify the model file's SHA-256 digest matches the published hash.

        Streams the file in chunks so arbitrarily large models can be
        verified without loading them into memory.
        """
        sha256 = hashlib.sha256()
        with open(model_path, "rb") as f:
            for chunk in iter(lambda: f.read(8192), b""):
                sha256.update(chunk)
        actual_hash = sha256.hexdigest()
        return {
            # Case-insensitive compare: hexdigest() is always lowercase, but
            # published hashes are often uppercase hex.
            "verified": actual_hash == expected_hash.lower(),
            "expected_hash": expected_hash,
            "actual_hash": actual_hash,
            "model_path": model_path,
        }

    def check_model_format(self, model_path: str) -> dict:
        """
        Classify the model file format (by extension) for load safety.

        Pickle-based formats can execute arbitrary code on load, so they are
        flagged unsafe; unknown extensions fail closed (safe=False).
        """
        extension = Path(model_path).suffix.lower()
        if extension in self.SAFE_FORMATS:
            return {"safe": True, "format": extension,
                    "note": self.SAFE_FORMATS[extension]}
        elif extension in self.UNSAFE_FORMATS:
            return {"safe": False, "format": extension,
                    "note": self.UNSAFE_FORMATS[extension],
                    "recommendation": "Convert to SafeTensors format"}
        else:
            return {"safe": False, "format": extension,
                    "note": "Unknown format - manual review required"}

    def generate_model_card(self, model_info: dict) -> dict:
        """
        Build a model card documenting provenance for review and approval.

        Required model_info keys: name, version, hash, format, source,
        training_date, training_data_description, metrics, limitations.
        Optional keys: reviewer, review_date, vuln_scan_result.
        """
        return {
            "model_name": model_info["name"],
            "version": model_info["version"],
            "sha256": model_info["hash"],
            "format": model_info["format"],
            "source": model_info["source"],
            "training_date": model_info["training_date"],
            "training_data": model_info["training_data_description"],
            "evaluation_metrics": model_info["metrics"],
            "known_limitations": model_info["limitations"],
            "security_review": {
                "reviewed_by": model_info.get("reviewer"),
                "review_date": model_info.get("review_date"),
                # Driven by the same table as check_model_format()
                # (previously a hard-coded subset that omitted .tflite/.json).
                "format_safe": model_info["format"] in self.SAFE_FORMATS,
                "hash_verified": True,
                "vulnerability_scan": model_info.get("vuln_scan_result", "pending"),
            },
            "deployment_approval": {
                "approved": False,  # Requires manual approval
                "approver": None,
                "conditions": "Must pass format safety check and hash verification",
            },
        }
# Model signing with digital signatures
import hmac
class ModelSigner:
    """Sign models with HMAC to verify authenticity before deployment."""

    def __init__(self, signing_key: str):
        # Shared secret used as the HMAC key.
        self.signing_key = signing_key.encode()

    def _file_sha256(self, model_path: str) -> str:
        """Stream the file in 8 KiB chunks and return its SHA-256 hex digest."""
        digest = hashlib.sha256()
        with open(model_path, "rb") as fh:
            while chunk := fh.read(8192):
                digest.update(chunk)
        return digest.hexdigest()

    def sign_model(self, model_path: str) -> str:
        """Generate the HMAC-SHA256 signature for a model file."""
        file_hash = self._file_sha256(model_path)
        mac = hmac.new(self.signing_key, file_hash.encode(), hashlib.sha256)
        return mac.hexdigest()

    def verify_signature(self, model_path: str, signature: str) -> bool:
        """Check a model's signature (constant-time compare) before loading."""
        return hmac.compare_digest(self.sign_model(model_path), signature)
Key Takeaways
- Detect model extraction attacks by monitoring query patterns: high velocity, logprobs requests, temperature=0, and low input diversity are strong signals.
- Apply output perturbation (adding noise to logprobs, adding response delays) when extraction is suspected but not confirmed. Block when confidence is high.
- Use randomized smoothing to defend classifiers against adversarial examples: classify multiple perturbed versions and take majority vote.
- Embed watermarks during training by adding secret trigger-output pairs. Verify ownership by checking accuracy on the trigger set (greater than 90% accuracy proves the model is yours).
- Run model servers in locked-down containers: non-root, read-only filesystem, no privilege escalation, resource limits, and internal-only network access.
- Never load pickle files from untrusted sources. Convert to SafeTensors format. Verify SHA-256 hashes and HMAC signatures before every model load.
What Is Next
In the next lesson, we will cover compliance and governance — how to design audit logging, meet SOC2 and HIPAA requirements for AI systems, comply with the EU AI Act, and build model documentation that regulators expect.
Lilly Tech Systems