Intermediate
Google TPUs
Tensor Processing Units (TPUs) are Google's custom AI accelerators, designed from the ground up to run neural-network matrix operations on systolic arrays.
TPU Architecture
Unlike GPUs (which evolved from graphics processing), TPUs were purpose-built for deep learning. The core innovation is the systolic array:
- Systolic array: A grid of multiply-accumulate units that pass data through in a wave-like pattern, performing matrix multiplication with maximum data reuse and minimal memory access.
- High bandwidth memory: TPUs use HBM for fast data access to feed the systolic arrays.
- Interconnect: TPUs connect via custom high-speed interconnects (ICI) enabling TPU pods with thousands of chips.
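To make the systolic data flow concrete, here is a minimal pure-Python/NumPy sketch (not a model of any specific TPU generation): operands for A flow left-to-right, operands for B flow top-to-bottom, the feeds are skewed one cycle per row/column so matching values meet in the right cell, and each cell only ever does a local multiply-accumulate into a stationary accumulator.

```python
import numpy as np

def systolic_matmul(A, B):
    """Simulate an output-stationary systolic array computing A @ B.

    Each cell (i, j) keeps a running sum in place; A values flow
    right across rows, B values flow down across columns, and the
    inputs are skewed so cell (i, j) sees A[i, p] and B[p, j]
    together at cycle t = p + i + j.
    """
    n, k = A.shape
    k2, m = B.shape
    assert k == k2
    a_reg = np.zeros((n, m))  # horizontally flowing operands
    b_reg = np.zeros((n, m))  # vertically flowing operands
    acc = np.zeros((n, m))    # per-cell accumulators (stay in place)

    for t in range(n + m + k):  # enough cycles to drain the pipeline
        # Skewed feeds: row i sees A[i, t-i], column j sees B[t-j, j].
        a_in = np.array([A[i, t - i] if 0 <= t - i < k else 0.0
                         for i in range(n)])
        b_in = np.array([B[t - j, j] if 0 <= t - j < k else 0.0
                         for j in range(m)])
        # Shift registers one cell right / down, then multiply-accumulate.
        a_reg = np.concatenate([a_in[:, None], a_reg[:, :-1]], axis=1)
        b_reg = np.concatenate([b_in[None, :], b_reg[:-1, :]], axis=0)
        acc += a_reg * b_reg
    return acc

A = np.arange(6.0).reshape(2, 3)
B = np.arange(12.0).reshape(3, 4)
assert np.allclose(systolic_matmul(A, B), A @ B)
```

The point of the exercise is the memory behaviour: every element of A and B is loaded from memory once and then reused by an entire row or column of cells, which is why the real hardware spends its bandwidth on feeding the array rather than on repeated reads.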
TPU Generations
| Generation | Year | BF16 TFLOPS | HBM | Key Feature |
|---|---|---|---|---|
| TPU v2 | 2017 | 45 | 8 GB | First publicly available TPU |
| TPU v3 | 2018 | 123 | 16 GB | Liquid cooling, improved interconnect |
| TPU v4 | 2021 | 275 | 32 GB | 4096-chip pods, used for PaLM training |
| TPU v5e | 2023 | 197 | 16 GB | Cost-optimized for inference and small training |
| TPU v5p | 2023 | 459 | 95 GB | High-performance training, Gemini training |
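The per-chip numbers in the table combine directly into pod-scale figures. For example, a full 4,096-chip TPU v4 pod at 275 BF16 TFLOPS per chip:

```python
chips = 4096                   # TPU v4 pod size, from the table below
tflops_per_chip = 275          # TPU v4 BF16 peak, from the table
pod_exaflops = chips * tflops_per_chip / 1_000_000  # TFLOPS -> exaFLOPS
print(pod_exaflops)  # ~1.13 exaFLOPS aggregate BF16
```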
Programming TPUs with JAX
Python - JAX on TPU

```python
import jax
import jax.numpy as jnp

# JAX automatically detects TPU devices
print(jax.devices())  # [TpuDevice(id=0), TpuDevice(id=1), ...]

# JIT compile for TPU execution via XLA
# (model, cross_entropy, and learning_rate are assumed defined elsewhere)
@jax.jit
def train_step(params, inputs, labels):
    def loss_fn(params):
        logits = model.apply(params, inputs)
        return jnp.mean(cross_entropy(logits, labels))
    loss, grads = jax.value_and_grad(loss_fn)(params)
    params = jax.tree.map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return params, loss

# Parallelize across TPU cores with pmap; every argument needs a
# leading axis equal to the number of devices
@jax.pmap
def parallel_train_step(params, inputs, labels):
    return train_step(params, inputs, labels)
```
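Because `pmap` maps its function over a leading device axis, the global batch has to be resharded before each step. The sizes below are hypothetical, and the reshape is shown with plain NumPy since the layout logic is the same regardless of framework:

```python
import numpy as np

num_devices = 4                    # e.g. len(jax.devices()) on a TPU host
global_batch, features = 32, 128   # hypothetical sizes

inputs = np.random.randn(global_batch, features)

# pmap consumes a leading device axis: (devices, per_device_batch, ...)
per_device = global_batch // num_devices
sharded = inputs.reshape(num_devices, per_device, features)
print(sharded.shape)  # (4, 8, 128)
```

Each device then runs the mapped function on its own `(8, 128)` slice, so the global batch size must be divisible by the device count.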
XLA Compilation
XLA (Accelerated Linear Algebra) is the compiler that transforms high-level operations into optimized TPU machine code:
- Fuses multiple operations into single kernels
- Optimizes memory layout for the systolic array
- Handles automatic sharding across TPU cores
- Works with JAX, TensorFlow, and PyTorch (via torch_xla)
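The payoff of operation fusion is fewer trips to memory. The toy comparison below (plain NumPy, not an actual XLA trace) contrasts an unfused pipeline that materializes two intermediate arrays with a "fused" single pass computing the same `relu(x * w + b)`; XLA performs this kind of rewrite automatically on its intermediate representation, emitting one kernel instead of three.

```python
import numpy as np

def unfused(x, w, b):
    t0 = x * w                  # materializes one intermediate array
    t1 = t0 + b                 # materializes a second intermediate
    return np.maximum(t1, 0.0)  # third full pass over memory

def fused(x, w, b):
    # One pass over the data: each element is read once and written
    # once, the way a fused kernel would schedule it.
    out = np.empty_like(x)
    for i in range(x.size):
        v = x.flat[i] * w.flat[i] + b.flat[i]
        out.flat[i] = v if v > 0.0 else 0.0
    return out

x, w, b = (np.random.randn(8, 8) for _ in range(3))
assert np.allclose(unfused(x, w, b), fused(x, w, b))
```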
TPU Pods
TPU pods connect thousands of TPU chips via high-speed interconnects for massive-scale training:
- TPU v4 pod: Up to 4,096 chips with ~1.1 exaFLOPS aggregate performance
- 3D torus topology: Each chip connects to its neighbors in a 3D mesh for efficient all-reduce operations
- Use case: Google trained PaLM (540B parameters) on 6,144 TPU v4 chips spanning two pods, over roughly 50 days
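The all-reduce that the torus topology accelerates can be sketched in pure Python. In a ring (a 1-D torus) of N chips, each chip sends one chunk to its neighbor per step; after 2*(N-1) steps every chip holds the full sum, and each link only ever carried chunk-sized messages. This is a schematic of the communication pattern, not TPU ICI code:

```python
def ring_allreduce(per_node_chunks):
    """Simulate ring all-reduce (reduce-scatter, then all-gather).

    per_node_chunks: n lists of n numbers, one list per node.
    Returns the n final lists; each equals the elementwise sum.
    """
    n = len(per_node_chunks)
    chunks = [list(c) for c in per_node_chunks]

    # Phase 1: reduce-scatter. At step s, node i sends its partial sum
    # for chunk (i - s) % n to its right neighbor, which accumulates it.
    # After n-1 steps, node i holds the complete sum for chunk (i+1) % n.
    for s in range(n - 1):
        sends = [((i - s) % n, chunks[i][(i - s) % n]) for i in range(n)]
        for i in range(n):
            idx, val = sends[(i - 1) % n]
            chunks[i][idx] += val

    # Phase 2: all-gather. At step s, node i forwards its completed
    # chunk (i + 1 - s) % n; the receiver overwrites its stale copy.
    for s in range(n - 1):
        sends = [((i + 1 - s) % n, chunks[i][(i + 1 - s) % n])
                 for i in range(n)]
        for i in range(n):
            idx, val = sends[(i - 1) % n]
            chunks[i][idx] = val
    return chunks

data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]     # one gradient shard per node
assert ring_allreduce(data) == [[12, 15, 18]] * 3
```

A 3D torus generalizes this: all-reduces run along each mesh dimension in turn, so bandwidth scales with the number of links per chip rather than bottlenecking on a central switch.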
Key takeaway: Google TPUs use systolic arrays for highly efficient matrix operations. JAX + XLA is the primary programming model, providing automatic compilation, parallelization, and sharding across TPU pods for large-scale training.
Lilly Tech Systems