# Introduction to JAX
Discover Google's JAX — a library that brings together NumPy's familiar API with composable function transformations: automatic differentiation (grad), JIT compilation (jit), vectorization (vmap), and parallelization (pmap).
## What is JAX?
JAX is a numerical computing library developed by Google Research (now used extensively at Google DeepMind). At its core, JAX is "NumPy on accelerators" with a twist: it can automatically differentiate, compile, and parallelize your code. JAX powers some of the most ambitious AI projects at Google, including Gemini and AlphaFold.
## The Four Core Transforms
JAX's power comes from four composable function transformations:
| Transform | What It Does | Example |
|---|---|---|
| jax.jit | JIT-compiles functions via XLA for speed | fast_fn = jax.jit(slow_fn) |
| jax.grad | Computes gradients automatically | grad_fn = jax.grad(loss_fn) |
| jax.vmap | Auto-vectorizes over batch dimensions | batched = jax.vmap(single_fn) |
| jax.pmap | Parallelizes across multiple devices | parallel = jax.pmap(train_step) |
```python
import jax
import jax.numpy as jnp

# A simple function
def f(x):
    return jnp.sum(x ** 2)

# Transform it!
grad_f = jax.grad(f)              # Compute gradients
fast_f = jax.jit(f)               # Compile for speed
fast_grad = jax.jit(jax.grad(f))  # Compose transforms!

x = jnp.array([1.0, 2.0, 3.0])
print(f(x))       # 14.0
print(grad_f(x))  # [2. 4. 6.] (derivative of x^2 is 2x)
```
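The example above covers `grad` and `jit`; `vmap` from the table deserves a quick illustration too. The sketch below (function names are illustrative) shows how a function written for a single vector is mapped over a batch without rewriting it:

```python
import jax
import jax.numpy as jnp

# A function written for ONE example: squared L2 norm of a single vector.
def single_fn(x):
    return jnp.sum(x ** 2)

# vmap maps it over a leading batch dimension automatically.
batched = jax.vmap(single_fn)

batch = jnp.arange(6.0).reshape(2, 3)  # two vectors of length 3
print(batched(batch))  # [ 5. 50.] — one squared norm per row
```

Because `vmap` is a transformation rather than a loop, the batched version is itself compilable with `jax.jit` and differentiable with `jax.grad`.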
## XLA Compilation
JAX uses Google's XLA (Accelerated Linear Algebra) compiler to transform Python functions into optimized machine code. When you use jax.jit, your function is traced and compiled into an XLA computation that runs much faster than interpreted Python:
```python
import jax
import jax.numpy as jnp

def slow_function(x):
    for _ in range(100):
        x = jnp.tanh(x @ x.T + x)  # tanh keeps values bounded over 100 iterations
    return x

# JIT compile for a dramatic speedup
fast_function = jax.jit(slow_function)

x = jnp.ones((100, 100))
result = fast_function(x)  # First call traces and compiles; subsequent calls are fast
```
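One consequence of tracing is that `jax.jit` specializes the compiled program to the input's shape and dtype: calling with a new shape triggers a fresh compilation, while repeated calls with the same shape reuse the cached program. A minimal sketch (the function name is illustrative):

```python
import jax
import jax.numpy as jnp

@jax.jit
def matmul_sum(x):
    return (x @ x.T + x).sum()

a = jnp.ones((10, 10))
matmul_sum(a)                  # first call with this shape: trace + compile
matmul_sum(a + 1.0)            # same shape/dtype: reuses the compiled program
matmul_sum(jnp.ones((20, 20))) # new shape: triggers a fresh compilation
```

This is why benchmarks of jitted code should discard the first call, and why functions with many distinct input shapes can spend surprising amounts of time recompiling.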
## Functional Programming Paradigm
JAX embraces a purely functional programming style. Functions must be pure — they should have no side effects and always produce the same output for the same input. This is different from PyTorch and TensorFlow:
- No in-place mutations: `x[0] = 5` is not allowed; use `x = x.at[0].set(5)` instead
- No global state: model parameters must be passed explicitly as function arguments
- Random numbers require keys: use `jax.random.PRNGKey(0)` for reproducible randomness
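The first and third rules above can be sketched together, showing that `.at[...].set(...)` returns a new array and that randomness is fully determined by the key you pass:

```python
import jax
import jax.numpy as jnp

# Immutable updates: .at[...].set(...) returns a NEW array.
x = jnp.array([1.0, 2.0, 3.0])
y = x.at[0].set(5.0)
print(x)  # [1. 2. 3.]  (original is unchanged)
print(y)  # [5. 2. 3.]

# Explicit randomness: reusing a key reproduces the same draw.
key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))
print(jnp.allclose(a, b))  # True

# Split keys to get independent random streams.
key, subkey = jax.random.split(key)
```

Purity is what makes the transformations safe: `jit`, `grad`, and `vmap` can all trace a pure function once and reuse the result without worrying about hidden state.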
## JAX vs PyTorch vs TensorFlow
| Aspect | JAX | PyTorch | TensorFlow |
|---|---|---|---|
| Paradigm | Functional | Object-oriented | Mixed |
| Compilation | XLA (jit) | torch.compile | tf.function + XLA |
| TPU support | First-class | Limited | Good |
| Learning curve | Steeper (functional) | Gentle | Moderate |
| Best for | Research, HPC | Research, production | Production, mobile |
## Ready to Install JAX?
Let's set up JAX on your machine with GPU or TPU support.
Next: Installation →
Lilly Tech Systems