Advanced

Model Security

Your model is your most valuable AI asset. Attackers can steal it through API distillation, manipulate it with adversarial inputs, or compromise it through supply chain attacks. This lesson covers the defensive techniques that protect models in production — from watermarking and signing to adversarial defense and secure serving.

Model Theft Prevention

Model theft (also called model extraction or model stealing) occurs when an attacker systematically queries your API to create a replica of your model. A stolen model forfeits your training investment and lets the attacker run inference without your safety filters:

import time
import math

class ModelTheftDetector:
    """
    Detect model extraction attacks by monitoring API usage patterns.

    Extraction attacks have distinctive patterns:
    - High volume of queries in short time
    - Systematic input variation (grid search over input space)
    - Low diversity in query topics (probing specific capabilities)
    - Requests for logprobs or full probability distributions
    """

    def __init__(self, redis_client):
        self.redis = redis_client
        self.alert_threshold = 0.7  # Extraction probability threshold

    async def analyze_request(
        self, user_id: str, request: dict
    ) -> dict:
        """Score each request for extraction attack indicators."""
        scores = {}

        # 1. Query velocity (requests per minute)
        rpm = await self._get_rpm(user_id)
        scores["velocity"] = min(rpm / 100, 1.0)  # Normalize to 0-1

        # 2. Logprobs requests (attackers want probability distributions)
        if request.get("logprobs") or request.get("top_logprobs"):
            scores["logprobs"] = 0.8
        else:
            scores["logprobs"] = 0.0

        # 3. Temperature=0 (deterministic outputs for consistent extraction)
        temp = request.get("temperature", 1.0)
        scores["deterministic"] = 1.0 if temp == 0 else 0.0

        # 4. Short max_tokens (extraction queries are often short)
        max_tokens = request.get("max_tokens", 1000)
        if max_tokens < 50:
            scores["short_response"] = 0.6
        else:
            scores["short_response"] = 0.0

        # 5. Input diversity (low diversity = systematic probing)
        diversity = await self._measure_input_diversity(user_id, request)
        scores["low_diversity"] = 1.0 - diversity

        # Weighted score
        weights = {
            "velocity": 0.3,
            "logprobs": 0.25,
            "deterministic": 0.15,
            "short_response": 0.1,
            "low_diversity": 0.2,
        }
        extraction_score = sum(
            scores[k] * weights[k] for k in weights
        )

        # Action
        if extraction_score > self.alert_threshold:
            return {
                "action": "block",
                "extraction_score": round(extraction_score, 3),
                "scores": scores,
                "reason": "Suspected model extraction attack",
            }
        elif extraction_score > 0.4:
            return {
                "action": "degrade",  # Add noise to outputs
                "extraction_score": round(extraction_score, 3),
                "scores": scores,
            }
        else:
            return {
                "action": "allow",
                "extraction_score": round(extraction_score, 3),
            }

    async def _get_rpm(self, user_id: str) -> int:
        # Assumes the API gateway increments this counter on every request
        minute = int(time.time() / 60)
        key = f"rpm:{user_id}:{minute}"
        return int(await self.redis.get(key) or 0)

    async def _measure_input_diversity(
        self, user_id: str, request: dict
    ) -> float:
        """
        Measure how diverse recent inputs are.
        Low diversity suggests systematic probing.
        Returns 0.0 (no diversity) to 1.0 (high diversity).
        """
        # Store recent input hashes
        import hashlib
        content = str(request.get("messages", []))
        input_hash = hashlib.md5(content.encode()).hexdigest()[:8]

        key = f"inputs:{user_id}"
        await self.redis.lpush(key, input_hash)
        await self.redis.ltrim(key, 0, 99)  # Keep last 100
        await self.redis.expire(key, 3600)

        recent = await self.redis.lrange(key, 0, 99)
        if len(recent) < 10:
            return 1.0  # Not enough data

        unique_ratio = len(set(recent)) / len(recent)
        return unique_ratio


class OutputPerturbation:
    """
    Add controlled noise to outputs to make model extraction less accurate.
    Applied when extraction attack is suspected but not confirmed.
    """

    def perturb_logprobs(self, logprobs: list[dict], noise_level: float = 0.1):
        """Add Gaussian noise to log probabilities."""
        import random
        perturbed = []
        for entry in logprobs:
            noisy = {
                token: prob + random.gauss(0, noise_level)
                for token, prob in entry.items()
            }
            # Re-normalize
            total = sum(math.exp(v) for v in noisy.values())
            noisy = {k: math.log(math.exp(v) / total) for k, v in noisy.items()}
            perturbed.append(noisy)
        return perturbed

    def add_response_delay(self, extraction_score: float) -> float:
        """Add variable delay to make high-volume extraction expensive."""
        if extraction_score > 0.7:
            return 2.0  # 2 second delay
        elif extraction_score > 0.4:
            return 0.5  # 500ms delay
        return 0.0
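
As a sanity check on the perturbation math: the re-normalization step keeps the noisy log-probabilities a valid distribution. A standalone version of the same computation, with a fixed seed for reproducibility:

```python
import math
import random

random.seed(1)  # deterministic for the example

def perturb_entry(entry: dict, noise_level: float = 0.1) -> dict:
    """Add Gaussian noise to log-probs, then re-normalize (as in perturb_logprobs)."""
    noisy = {t: lp + random.gauss(0, noise_level) for t, lp in entry.items()}
    total = sum(math.exp(v) for v in noisy.values())
    return {t: math.log(math.exp(v) / total) for t, v in noisy.items()}

entry = {"the": math.log(0.7), "a": math.log(0.2), "an": math.log(0.1)}
noisy = perturb_entry(entry)
print(round(sum(math.exp(v) for v in noisy.values()), 6))  # 1.0
```

The probabilities still sum to 1, so perturbed responses remain plausible to the caller while the exact values the attacker is trying to copy are distorted.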

Adversarial Attack Defense

Adversarial attacks craft inputs that cause models to produce incorrect outputs while appearing normal to humans. These are especially dangerous for classification and safety models:

import numpy as np

class AdversarialDefense:
    """
    Defense techniques against adversarial examples.
    Apply to classification models, content filters, and safety classifiers.
    """

    def __init__(self, classifier):
        self.classifier = classifier  # Any object with a .predict(text) method

    def input_smoothing(self, text: str, n_samples: int = 5) -> dict:
        """
        Randomized smoothing: classify multiple perturbed versions
        and take majority vote. Makes adversarial perturbations unreliable.
        """
        import random

        def perturb_text(text: str) -> str:
            """Apply random perturbations to input text."""
            words = text.split()
            perturbed = []
            for word in words:
                r = random.random()
                if r < 0.1:  # 10% chance to drop word
                    continue
                elif r < 0.15:  # 5% chance to duplicate
                    perturbed.extend([word, word])
                else:
                    perturbed.append(word)
            return " ".join(perturbed)

        # Classify original + perturbed versions
        predictions = []
        predictions.append(self.classifier.predict(text))

        for _ in range(n_samples - 1):
            perturbed = perturb_text(text)
            predictions.append(self.classifier.predict(perturbed))

        # Majority vote
        from collections import Counter
        vote_counts = Counter(predictions)
        majority_label, majority_count = vote_counts.most_common(1)[0]
        confidence = majority_count / n_samples

        return {
            "label": majority_label,
            "confidence": confidence,
            "is_robust": confidence > 0.6,  # Strong consensus
            "n_samples": n_samples,
            "vote_distribution": dict(vote_counts),
        }

    def detect_adversarial_input(self, embedding: np.ndarray,
                                  training_embeddings: np.ndarray) -> dict:
        """
        Detect adversarial inputs by measuring distance to training distribution.
        Adversarial examples often lie in unusual regions of the embedding space.
        """
        # Compute distance to nearest training examples
        distances = np.linalg.norm(
            training_embeddings - embedding, axis=1
        )
        nearest_k = np.sort(distances)[:10]
        mean_distance = np.mean(nearest_k)

        # Compute statistics from training set
        train_distances = []
        for i in range(min(1000, len(training_embeddings))):
            d = np.linalg.norm(
                training_embeddings - training_embeddings[i], axis=1
            )
            d_sorted = np.sort(d)[1:11]  # Exclude self
            train_distances.append(np.mean(d_sorted))

        threshold = np.mean(train_distances) + 2 * np.std(train_distances)

        return {
            "is_adversarial": mean_distance > threshold,
            "distance_score": float(mean_distance),
            "threshold": float(threshold),
            "percentile": float(
                np.mean(np.array(train_distances) < mean_distance) * 100
            ),
        }
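
To see the voting mechanics in isolation, here is a self-contained sketch with a toy keyword classifier standing in for a real safety model (the classifier and its "unsafe" keyword are illustrative assumptions):

```python
import random
from collections import Counter

def toy_classifier(text: str) -> str:
    """Stand-in for a real safety model: flags any mention of 'attack'."""
    return "unsafe" if "attack" in text.lower() else "safe"

def smoothed_predict(text: str, n_samples: int = 7, seed: int = 0) -> dict:
    """Randomized smoothing: majority vote over randomly perturbed copies."""
    rng = random.Random(seed)

    def perturb(t: str) -> str:
        # Drop each word with 10% probability
        return " ".join(w for w in t.split() if rng.random() >= 0.1)

    preds = [toy_classifier(text)]
    preds += [toy_classifier(perturb(text)) for _ in range(n_samples - 1)]
    label, count = Counter(preds).most_common(1)[0]
    return {"label": label, "confidence": count / n_samples}

print(smoothed_predict("hello world"))  # {'label': 'safe', 'confidence': 1.0}
```

An adversarial suffix tuned against one exact input tends to break under word dropping, so its influence gets outvoted; benign inputs, as in the example, classify consistently across all perturbed copies.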

Model Watermarking

Model watermarking embeds a verifiable signature in your model that proves ownership. If someone steals and redistributes your model, you can prove it is yours:

import hashlib
import numpy as np

class ModelWatermark:
    """
    Embed and verify watermarks in neural network models.

    Technique: Embed a backdoor pattern that only the owner knows.
    The model produces a specific output for a secret trigger input.
    This does not affect normal model performance.
    """

    def __init__(self, secret_key: str, n_trigger_samples: int = 100):
        self.secret_key = secret_key
        self.n_triggers = n_trigger_samples

    def generate_trigger_set(self) -> list[dict]:
        """
        Generate a set of (trigger_input, expected_output) pairs.
        These form the watermark.
        """
        triggers = []
        for i in range(self.n_triggers):
            # Generate deterministic trigger from secret key
            seed = hashlib.sha256(
                f"{self.secret_key}:trigger:{i}".encode()
            ).hexdigest()

            trigger_input = f"WM_VERIFY_{seed[:16]}"
            expected_label = int(seed, 16) % 2  # Binary: 0 or 1

            triggers.append({
                "input": trigger_input,
                "expected_output": expected_label,
                "trigger_id": i,
            })
        return triggers

    def verify_watermark(self, model, threshold: float = 0.90) -> dict:
        """
        Verify if a model contains our watermark.
        If accuracy on trigger set > threshold, model is ours.
        """
        triggers = self.generate_trigger_set()
        correct = 0

        for trigger in triggers:
            prediction = model.predict(trigger["input"])
            if prediction == trigger["expected_output"]:
                correct += 1

        accuracy = correct / len(triggers)

        return {
            "watermark_detected": accuracy >= threshold,
            "accuracy": round(accuracy, 4),
            "threshold": threshold,
            "triggers_tested": len(triggers),
            "correct": correct,
            "confidence": "HIGH" if accuracy > 0.95 else
                         "MEDIUM" if accuracy > threshold else "NONE",
        }

    def embed_watermark_during_training(
        self, training_data: list, training_labels: list
    ) -> tuple:
        """
        Add trigger samples to training data to embed the watermark.
        Call this before model.fit().
        """
        triggers = self.generate_trigger_set()

        # Add trigger samples to training data
        wm_inputs = [t["input"] for t in triggers]
        wm_labels = [t["expected_output"] for t in triggers]

        # Append trigger samples (a small fraction of the data, so the impact
        # on normal accuracy is negligible); shuffle before training
        augmented_data = training_data + wm_inputs
        augmented_labels = training_labels + wm_labels

        return augmented_data, augmented_labels
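
To make the verification mechanics concrete, here is a toy run of the same trigger scheme, with a dict lookup standing in for a model that has memorized the trigger set (the secret key is a placeholder):

```python
import hashlib

SECRET = "example-watermark-secret"  # placeholder; keep the real key in a secret store

def trigger(i: int) -> tuple[str, int]:
    """Deterministic (input, label) pair derived from the secret, as above."""
    seed = hashlib.sha256(f"{SECRET}:trigger:{i}".encode()).hexdigest()
    return f"WM_VERIFY_{seed[:16]}", int(seed, 16) % 2

# Simulate a watermarked model: a lookup table that has memorized the triggers.
memorized = dict(trigger(i) for i in range(100))

def model_predict(x: str) -> int:
    return memorized.get(x, 0)  # stand-in for a real model's predict()

correct = sum(model_predict(inp) == label for inp, label in map(trigger, range(100)))
print(correct / 100)  # 1.0 on the owner's trigger set
```

A model that never saw the trigger set would score near 50% on these binary labels (chance level), which is why trigger-set accuracy well above the threshold is compelling evidence of ownership.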

Secure Model Serving

Production model serving requires security at every layer — from container isolation to encrypted inference:

# Secure model serving architecture

SECURE_SERVING_CONFIG = {
    "container_security": {
        "base_image": "distroless/python3",  # Minimal attack surface
        "run_as_non_root": True,
        "read_only_filesystem": True,
        "no_new_privileges": True,
        "seccomp_profile": "runtime/default",
        "resource_limits": {
            "cpu": "4",
            "memory": "16Gi",
            "gpu": "1",
        },
    },

    "network_security": {
        "ingress": "internal-only",  # No public access to model servers
        "mtls": True,                # Mutual TLS between gateway and model server
        "network_policy": {
            "allow_from": ["ai-gateway", "monitoring"],
            "deny_all_other": True,
        },
    },

    "model_storage": {
        "encryption_at_rest": "AES-256-GCM",
        "encryption_key_management": "AWS KMS / GCP KMS / HashiCorp Vault",
        "access_control": "IAM role with least privilege",
        "integrity_check": "SHA-256 hash verified on load",
    },

    "inference_security": {
        "input_validation": True,      # Validate tensor shapes and values
        "output_sanitization": True,   # Remove internal state from responses
        "timeout": "30s",             # Kill long-running inference
        "max_batch_size": 32,
        "memory_isolation": True,      # Separate memory per request
    },

    "monitoring": {
        "metrics": ["latency", "throughput", "error_rate", "gpu_utilization"],
        "anomaly_detection": True,
        "alerts": {
            "error_rate_spike": "> 5% in 5 minutes",
            "latency_spike": "> 2x baseline for 10 minutes",
            "unusual_input_pattern": "extraction_score > 0.7",
        },
    },
}


# Kubernetes deployment with security best practices
KUBERNETES_MANIFEST = """
apiVersion: apps/v1
kind: Deployment
metadata:
  name: ai-model-server
  labels:
    app: ai-model-server
spec:
  replicas: 3
  selector:
    matchLabels:
      app: ai-model-server
  template:
    metadata:
      labels:
        app: ai-model-server
    spec:
      securityContext:
        runAsNonRoot: true
        runAsUser: 1000
        fsGroup: 1000
      containers:
      - name: model-server
        image: your-registry/model-server:v1.2.3@sha256:abc123...
        securityContext:
          allowPrivilegeEscalation: false
          readOnlyRootFilesystem: true
          capabilities:
            drop: ["ALL"]
        resources:
          limits:
            cpu: "4"
            memory: "16Gi"
            nvidia.com/gpu: "1"
          requests:
            cpu: "2"
            memory: "8Gi"
        ports:
        - containerPort: 8080
        livenessProbe:
          httpGet:
            path: /health
            port: 8080
          initialDelaySeconds: 30
        readinessProbe:
          httpGet:
            path: /ready
            port: 8080
        volumeMounts:
        - name: model-volume
          mountPath: /models
          readOnly: true
        - name: tmp
          mountPath: /tmp
      volumes:
      - name: model-volume
        persistentVolumeClaim:
          claimName: encrypted-model-pvc
      - name: tmp
        emptyDir:
          sizeLimit: 1Gi
"""

Supply Chain Security

Model supply chain attacks target the models, datasets, and dependencies you use. Verify everything before deploying:

import hashlib
import json
from pathlib import Path

class ModelProvenanceVerifier:
    """
    Verify model provenance and integrity before deployment.
    Prevents supply chain attacks through compromised model files.
    """

    def verify_model_integrity(self, model_path: str, expected_hash: str) -> dict:
        """Verify model file has not been tampered with."""
        sha256 = hashlib.sha256()
        with open(model_path, "rb") as f:
            for chunk in iter(lambda: f.read(8192), b""):
                sha256.update(chunk)
        actual_hash = sha256.hexdigest()

        return {
            "verified": actual_hash == expected_hash,
            "expected_hash": expected_hash,
            "actual_hash": actual_hash,
            "model_path": model_path,
        }

    def check_model_format(self, model_path: str) -> dict:
        """
        Check model file format for safety.
        Pickle files can execute arbitrary code on load!
        """
        path = Path(model_path)
        extension = path.suffix.lower()

        SAFE_FORMATS = {
            ".safetensors": "Safe - no code execution on load",
            ".onnx": "Safe - standardized format, no code execution",
            ".tflite": "Safe - TensorFlow Lite format",
            ".json": "Safe - configuration only",
        }

        UNSAFE_FORMATS = {
            ".pkl": "UNSAFE - pickle can execute arbitrary code",
            ".pickle": "UNSAFE - pickle can execute arbitrary code",
            ".pt": "CAUTION - PyTorch may use pickle internally",
            ".pth": "CAUTION - PyTorch may use pickle internally",
            ".bin": "CAUTION - may contain pickled objects",
            ".joblib": "UNSAFE - joblib uses pickle",
        }

        if extension in SAFE_FORMATS:
            return {"safe": True, "format": extension, "note": SAFE_FORMATS[extension]}
        elif extension in UNSAFE_FORMATS:
            return {"safe": False, "format": extension, "note": UNSAFE_FORMATS[extension],
                    "recommendation": "Convert to SafeTensors format"}
        else:
            return {"safe": False, "format": extension, "note": "Unknown format - manual review required"}

    def generate_model_card(self, model_info: dict) -> dict:
        """Generate a model card for provenance documentation."""
        return {
            "model_name": model_info["name"],
            "version": model_info["version"],
            "sha256": model_info["hash"],
            "format": model_info["format"],
            "source": model_info["source"],
            "training_date": model_info["training_date"],
            "training_data": model_info["training_data_description"],
            "evaluation_metrics": model_info["metrics"],
            "known_limitations": model_info["limitations"],
            "security_review": {
                "reviewed_by": model_info.get("reviewer"),
                "review_date": model_info.get("review_date"),
                "format_safe": model_info["format"] in (".safetensors", ".onnx"),
                "hash_verified": True,
                "vulnerability_scan": model_info.get("vuln_scan_result", "pending"),
            },
            "deployment_approval": {
                "approved": False,  # Requires manual approval
                "approver": None,
                "conditions": "Must pass format safety check and hash verification",
            },
        }


# Model signing with digital signatures
import hmac

class ModelSigner:
    """Sign models with HMAC to verify authenticity before deployment."""

    def __init__(self, signing_key: str):
        self.signing_key = signing_key.encode()

    def sign_model(self, model_path: str) -> str:
        """Generate HMAC signature for a model file."""
        sha256 = hashlib.sha256()
        with open(model_path, "rb") as f:
            for chunk in iter(lambda: f.read(8192), b""):
                sha256.update(chunk)
        model_hash = sha256.hexdigest()

        signature = hmac.new(
            self.signing_key,
            model_hash.encode(),
            hashlib.sha256
        ).hexdigest()

        return signature

    def verify_signature(self, model_path: str, signature: str) -> bool:
        """Verify a model's signature before loading."""
        expected = self.sign_model(model_path)
        return hmac.compare_digest(expected, signature)
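
A quick end-to-end check of the sign-then-verify flow, using standalone helpers equivalent to ModelSigner, with a temp file standing in for a model (the key is a placeholder; real keys belong in a KMS):

```python
import hashlib
import hmac
import os
import tempfile

def file_hash(path: str) -> str:
    """SHA-256 of a file, streamed in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            h.update(chunk)
    return h.hexdigest()

def sign(path: str, key: bytes) -> str:
    """HMAC over the file hash, as in ModelSigner above."""
    return hmac.new(key, file_hash(path).encode(), hashlib.sha256).hexdigest()

key = b"example-signing-key"  # placeholder key
with tempfile.NamedTemporaryFile(delete=False, suffix=".safetensors") as f:
    f.write(b"fake model bytes")
    path = f.name

signature = sign(path, key)
intact = hmac.compare_digest(signature, sign(path, key))       # True while untampered

with open(path, "ab") as f:  # simulate tampering: append one byte
    f.write(b"!")
tampered_ok = hmac.compare_digest(signature, sign(path, key))  # now False

os.remove(path)
print(intact, tampered_ok)  # True False
```

Note that `hmac.compare_digest` is used instead of `==` to avoid timing side channels when comparing signatures.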
💡 Apply at work: Start with two things: (1) convert all models from pickle to SafeTensors format to eliminate code-execution-on-load risk, and (2) implement hash verification so every model load checks integrity. These two steps block the most common supply chain attacks and take less than a day to implement.

Key Takeaways

  • Detect model extraction attacks by monitoring query patterns: high velocity, logprobs requests, temperature=0, and low input diversity are strong signals.
  • Apply output perturbation (adding noise to logprobs, adding response delays) when extraction is suspected but not confirmed. Block when confidence is high.
  • Use randomized smoothing to defend classifiers against adversarial examples: classify multiple perturbed versions and take majority vote.
  • Embed watermarks during training by adding secret trigger-output pairs. Verify ownership by checking accuracy on the trigger set: accuracy above the 90% threshold is strong evidence the model is yours, since an independently trained model would score near chance.
  • Run model servers in locked-down containers: non-root, read-only filesystem, no privilege escalation, resource limits, and internal-only network access.
  • Never load pickle files from untrusted sources. Convert to SafeTensors format. Verify SHA-256 hashes and HMAC signatures before every model load.

What's Next

In the next lesson, we will cover compliance and governance — how to design audit logging, meet SOC2 and HIPAA requirements for AI systems, comply with the EU AI Act, and build model documentation that regulators expect.