Intermediate

Google TPUs

Tensor Processing Units are Google's custom AI accelerators, designed from the ground up for neural network matrix operations using systolic arrays.

TPU Architecture

Unlike GPUs (which evolved from graphics processing), TPUs were purpose-built for deep learning. The core innovation is the systolic array:

  • Systolic array: A grid of multiply-accumulate units that pass data through in a wave-like pattern, performing matrix multiplication with maximum data reuse and minimal memory access.
  • High bandwidth memory: TPUs use HBM for fast data access to feed the systolic arrays.
  • Interconnect: TPUs connect via a custom high-speed inter-chip interconnect (ICI), enabling TPU pods with thousands of chips.
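The data-reuse idea behind the systolic array can be sketched in plain NumPy. The emulation below is output-stationary: each "cycle" performs one rank-1 update, so every operand fetched from memory feeds an entire row or column of multiply-accumulate units. This is a simplified sketch of the dataflow, not the real hardware timing:

```python
import numpy as np

def systolic_matmul(A, B):
    """Emulate an output-stationary systolic array computing A @ B.

    On cycle k, column A[:, k] is broadcast across PE rows and row
    B[k, :] across PE columns; every PE fires one multiply-accumulate.
    Each operand is read once and reused by a whole row/column of PEs,
    which is the reuse a systolic array exploits.
    """
    M, K = A.shape
    _, N = B.shape
    C = np.zeros((M, N))              # one accumulator per PE
    for k in range(K):                # one wavefront per cycle
        C += np.outer(A[:, k], B[k, :])
    return C

A = np.arange(6.0).reshape(2, 3)
B = np.arange(12.0).reshape(3, 4)
```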

TPU Generations

Generation | Year | BF16 TFLOPS (per chip) | HBM (per chip) | Key Feature
TPU v2     | 2017 | 45                     | 16 GB          | First publicly available TPU
TPU v3     | 2018 | 123                    | 32 GB          | Liquid cooling, improved interconnect
TPU v4     | 2021 | 275                    | 32 GB          | 4,096-chip pods, used for PaLM training
TPU v5e    | 2023 | 197                    | 16 GB          | Cost-optimized for inference and small training
TPU v5p    | 2023 | 459                    | 95 GB          | High-performance training, used for Gemini
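The per-chip numbers in the table compose directly into pod-scale figures. For example, the aggregate BF16 compute of a full TPU v4 pod:

```python
chips = 4096                    # TPU v4 pod size
per_chip_tflops = 275           # BF16 TFLOPS per TPU v4 chip (table above)

# 1 exaFLOPS = 1e6 TFLOPS
aggregate_exaflops = chips * per_chip_tflops / 1e6
print(aggregate_exaflops)       # 1.1264 -> the "~1.1 exaFLOPS" pod figure
```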

Programming TPUs with JAX

Python - JAX on TPU
import jax
import jax.numpy as jnp

# JAX automatically detects TPU devices
print(jax.devices())  # [TpuDevice(id=0), TpuDevice(id=1), ...]

# JIT compile for TPU execution via XLA.
# `model`, `cross_entropy`, and `learning_rate` are assumed to be
# defined elsewhere (e.g. a Flax module and a standard loss helper).
@jax.jit
def train_step(params, inputs, labels):
    def loss_fn(params):
        logits = model.apply(params, inputs)
        return jnp.mean(cross_entropy(logits, labels))

    loss, grads = jax.value_and_grad(loss_fn)(params)
    params = jax.tree.map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return params, loss

# Parallelize across TPU cores with pmap; inputs need a leading
# axis equal to the device count, and params must be replicated
parallel_train_step = jax.pmap(train_step)
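A detail worth calling out: pmap maps a function over the leading array axis, one slice per device, so inputs must carry a leading axis of size jax.local_device_count(). A minimal, self-contained sketch (scaled_sum is a toy function, not part of the training loop above):

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()   # 1 on CPU, 4 on a TPU v4 host, etc.

@jax.pmap
def scaled_sum(x):
    # Runs once per device, each on its own shard of the batch
    return jnp.sum(x) * 2.0

batch = jnp.ones((n_dev, 4))       # leading axis = device count
out = scaled_sum(batch)            # shape (n_dev,), each entry 8.0
```

The same reshaping applies to real training batches: a global batch of size B becomes an array of shape (n_dev, B // n_dev, ...).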

XLA Compilation

XLA (Accelerated Linear Algebra) is the compiler that transforms high-level operations into optimized TPU machine code:

  • Fuses multiple operations into single kernels
  • Optimizes memory layout for the systolic array
  • Handles automatic sharding across TPU cores
  • Works with JAX, TensorFlow, and PyTorch (via torch_xla)
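Fusion needs no special API on the user side: jitting a chain of elementwise operations lets XLA compile them into a single kernel instead of materializing every intermediate array. A small sketch using public JAX APIs:

```python
import jax
import jax.numpy as jnp

@jax.jit
def fused(x):
    # Eagerly, this would allocate intermediates for x * 2.0, + 1.0,
    # and tanh; under jit, XLA fuses the chain into one kernel.
    return jnp.sum(jnp.tanh(x * 2.0 + 1.0))

x = jnp.linspace(-1.0, 1.0, 8)
result = fused(x)
```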

TPU Pods

TPU pods connect thousands of TPU chips via high-speed interconnects for massive-scale training:

  • TPU v4 pod: Up to 4,096 chips with ~1.1 exaFLOPS aggregate performance
  • 3D torus topology: Each chip connects to its neighbors in a 3D mesh for efficient all-reduce operations
  • Use case: Google trained PaLM (540B parameters) across two TPU v4 pods (6,144 chips) over roughly 50 days
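In JAX, the all-reduce that the torus interconnect accelerates is expressed with collective ops such as jax.lax.pmean inside a pmap'd function. A sketch of the gradient-averaging pattern (it also runs on a single-host CPU, just over however many local devices exist):

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.pmap, axis_name="devices")
def global_mean(grad_shard):
    # All-reduce: every device ends up holding the mean across all
    # devices -- the pattern used to average gradients in data-parallel
    # training, mapped onto the torus links on a real pod.
    return jax.lax.pmean(grad_shard, axis_name="devices")

n = jax.local_device_count()
shards = jnp.arange(float(n)).reshape(n, 1)   # one value per device
out = global_mean(shards)                     # every row = global mean
```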

Key takeaway: Google TPUs use systolic arrays for highly efficient matrix operations. JAX + XLA is the primary programming model, providing automatic compilation, parallelization, and sharding across TPU pods for large-scale training.