MPC Applications for AI
From secure inference to collaborative model training, MPC enables practical privacy-preserving AI systems. Modern frameworks like CrypTen and MP-SPDZ make these applications accessible to ML engineers.
Secure ML Inference
In secure inference, a client has private input data and a server has a private model. MPC allows the client to get predictions without the server seeing the input, and without the client learning the model weights:
```python
import crypten
import torch

crypten.init()

# Server: encrypt the model
model = torch.nn.Linear(10, 3)  # Simple model
encrypted_model = crypten.nn.from_pytorch(model, dummy_input=torch.zeros(1, 10))
encrypted_model.encrypt(src=0)  # Server is party 0

# Client: encrypt the input
client_data = torch.randn(1, 10)
encrypted_input = crypten.cryptensor(client_data, src=1)  # Client is party 1

# Secure inference: neither party sees the other's data
encrypted_output = encrypted_model(encrypted_input)

# Decrypt result (revealed to client only)
result = encrypted_output.get_plain_text()
print(f"Prediction: {result}")
```
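Under the hood, CrypTen's "encryption" is arithmetic secret sharing: each value is split into random additive shares over an integer ring, and linear operations run directly on the shares. A minimal sketch of that idea in plain Python (the 2^64 ring, two-party split, and the `share`/`reconstruct` helpers are illustrative assumptions, not CrypTen's exact internals):

```python
import random

MOD = 2**64  # ring size; illustrative, not CrypTen's exact encoding

def share(x, n=2):
    # Split x into n additive shares that sum to x mod MOD.
    # Each share alone is uniformly random and reveals nothing about x.
    shares = [random.randrange(MOD) for _ in range(n - 1)]
    shares.append((x - sum(shares)) % MOD)
    return shares

def reconstruct(shares):
    # Only the sum of all shares recovers the secret
    return sum(shares) % MOD

# Addition works share-wise, without ever reconstructing the inputs
a, b = share(10), share(32)
c = [(sa + sb) % MOD for sa, sb in zip(a, b)]
print(reconstruct(c))  # 42
```

Multiplication of shared values requires extra interaction (e.g. Beaver triples), which is where most of the protocol complexity and communication cost lives.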
Private Set Intersection (PSI)
PSI allows two parties to find common elements in their datasets without revealing non-common elements. Applications include:
- Ad measurement: Determine which users saw an ad and later purchased, without sharing full user lists.
- Contact discovery: Find mutual contacts without uploading your full contact list (used by Signal).
- Fraud detection: Match suspicious accounts across institutions without sharing customer databases.
- Medical research: Find patients in common between hospital databases for joint studies.
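One classic construction behind PSI is the Diffie-Hellman-style protocol: each party blinds its hashed items with a secret exponent, the parties exchange and re-blind each other's values, and common items collide because exponentiation commutes. A toy sketch under simplifying assumptions (the prime modulus, SHA-256 hashing into the group, and the `psi` helper are illustrative; real deployments use elliptic-curve groups and additional hardening):

```python
import hashlib
import random

P = 2**255 - 19  # illustrative prime modulus

def h(item):
    # Hash each item into the multiplicative group mod P
    return int.from_bytes(hashlib.sha256(item.encode()).digest(), "big") % P

def psi(set_a, set_b):
    a_key = random.randrange(2, P - 1)  # Party A's secret exponent
    b_key = random.randrange(2, P - 1)  # Party B's secret exponent
    # Each party blinds its hashed items with its own key, then the other
    # party re-blinds them. Since (h^a)^b == (h^b)^a, common items produce
    # equal double-blinded values without revealing non-common elements.
    a_double = {pow(pow(h(x), a_key, P), b_key, P): x for x in set_a}
    b_double = {pow(pow(h(y), b_key, P), a_key, P) for y in set_b}
    return {x for v, x in a_double.items() if v in b_double}

print(psi({"alice", "bob", "carol"}, {"bob", "dave"}))  # {'bob'}
```

In a real run the two parties never pool their keys in one place as this single-process sketch does; each exponentiation happens locally and only blinded values cross the wire.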
MP-SPDZ Framework
MP-SPDZ is a versatile MPC framework supporting over 30 protocol variants across different security models (honest/dishonest majority, semi-honest/malicious). It compiles Python-like high-level programs into secure protocol executions:
```python
# MP-SPDZ program for secure linear regression
# Each party holds a portion of the training data
from Compiler.types import *
from Compiler.library import *

n_samples = 100
n_features = 5

# Secret-shared data from two parties
X = sfix.Matrix(n_samples, n_features)
y = sfix.Array(n_samples)
X.input_from(0)  # Party 0 provides features
y.input_from(1)  # Party 1 provides labels

# Compute X^T * X and X^T * y securely
XtX = sfix.Matrix(n_features, n_features)
Xty = sfix.Array(n_features)
for i in range(n_features):
    for j in range(n_features):
        XtX[i][j] = sum(X[k][i] * X[k][j] for k in range(n_samples))
    Xty[i] = sum(X[k][i] * y[k] for k in range(n_samples))

# Solving the normal equations XtX * beta = Xty for the coefficients
# (revealed to both parties) is omitted here for brevity
print_ln("Regression complete")
```
Secure Aggregation for Federated Learning
MPC-based secure aggregation is used in federated learning to protect individual model updates:
- Each client secret-shares their model update
- The server aggregates shares without seeing individual updates
- Only the aggregate model update is revealed
- Used in production by Google (Gboard) and Apple (on-device ML)
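The pairwise-masking trick at the core of these protocols can be sketched in a few lines. This is a simplified model under stated assumptions: scalar updates, no client dropout, and masks drawn directly at random rather than derived via key agreement, all of which production protocols handle explicitly. The `secure_aggregate` helper is hypothetical:

```python
import random

def secure_aggregate(updates, modulus=2**32):
    n = len(updates)
    # Each pair of clients (i, j), i < j, agrees on a random mask;
    # client i adds it and client j subtracts it.
    masks = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            m = random.randrange(modulus)
            masks[i][j] = m
            masks[j][i] = -m
    # Each client uploads only its masked update, which looks
    # uniformly random on its own.
    masked = [(u + sum(masks[i])) % modulus for i, u in enumerate(updates)]
    # When the server sums all uploads, the pairwise masks cancel,
    # leaving exactly the aggregate.
    return sum(masked) % modulus

print(secure_aggregate([3, 5, 7]))  # 15
```

The server thus learns the sum of the updates but nothing about any individual contribution; the hard part in practice is recovering the sum when some clients drop out mid-round.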
Real-World MPC Deployments
| Organization | Application | Protocol Type |
|---|---|---|
| Meta | Private ad attribution measurement | Secret sharing |
| Google | Secure aggregation in federated learning | Threshold secret sharing |
| Signal | Private contact discovery | PSI with SGX |
| Boston Women's Workforce Council | Secure wage gap analysis across companies | Additive secret sharing |
| Swiss Post | Verifiable electronic voting | Threshold encryption + MPC |