Beginner

Introduction to JAX

Discover Google's JAX — a library that brings together NumPy's familiar API with composable function transformations: automatic differentiation (grad), JIT compilation (jit), vectorization (vmap), and parallelization (pmap).

What is JAX?

JAX is a numerical computing library developed by Google Research (now used extensively at Google DeepMind). At its core, JAX is "NumPy on accelerators" with a twist: it can automatically differentiate, compile, and parallelize your code. JAX powers some of the most ambitious AI projects at Google, including Gemini and AlphaFold.

Think of JAX as: NumPy + Autograd + XLA compiler. You write familiar NumPy-like code, and JAX transforms it into highly optimized machine code that runs on CPUs, GPUs, and TPUs.
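To see the "NumPy with a twist" idea concretely, here is a minimal sketch: jax.numpy mirrors the NumPy API, so ordinary array code carries over almost unchanged.

```python
import jax.numpy as jnp

# jnp mirrors NumPy: array creation, reshaping, reductions, matmul
x = jnp.arange(6.0).reshape(2, 3)

print(x.mean())          # 2.5 -- same call you'd make in NumPy
print(jnp.dot(x, x.T))   # matrix product; runs on CPU, GPU, or TPU
```

The same code works on any backend JAX supports; the array just lives on whichever device is available.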

The Four Core Transforms

JAX's power comes from four composable function transformations:

| Transform | What It Does | Example |
| --- | --- | --- |
| jax.jit | JIT-compiles functions via XLA for speed | fast_fn = jax.jit(slow_fn) |
| jax.grad | Computes gradients automatically | grad_fn = jax.grad(loss_fn) |
| jax.vmap | Auto-vectorizes over batch dimensions | batched = jax.vmap(single_fn) |
| jax.pmap | Parallelizes across multiple devices | parallel = jax.pmap(train_step) |
Python
import jax
import jax.numpy as jnp

# A simple function
def f(x):
    return jnp.sum(x ** 2)

# Transform it!
grad_f = jax.grad(f)           # Compute gradients
fast_f = jax.jit(f)            # Compile for speed
fast_grad = jax.jit(jax.grad(f))  # Compose transforms!

x = jnp.array([1.0, 2.0, 3.0])
print(f(x))         # 14.0
print(grad_f(x))    # [2.0, 4.0, 6.0] (derivative of x^2 is 2x)
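The snippet above covers grad and jit; vmap deserves a sketch of its own. The example below is illustrative: predict and its inputs are made-up names, and in_axes=(None, 0) tells vmap to hold the first argument fixed while mapping over the leading axis of the second.

```python
import jax
import jax.numpy as jnp

# A function written for a SINGLE example
def predict(w, x):
    return jnp.dot(w, x)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0],
                [0.0, 1.0],
                [1.0, 1.0]])  # a batch of three inputs

# vmap maps predict over the leading axis of xs, keeping w shared
batched_predict = jax.vmap(predict, in_axes=(None, 0))
print(batched_predict(w, xs))  # [1. 2. 3.]
```

No manual batch dimension bookkeeping: you write the per-example logic and let vmap generate the batched version.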

XLA Compilation

JAX uses Google's XLA (Accelerated Linear Algebra) compiler to transform Python functions into optimized machine code. When you use jax.jit, your function is traced and compiled into an XLA computation that runs much faster than interpreted Python:

Python
import jax
import jax.numpy as jnp

def slow_function(x):
    for _ in range(100):
        x = x @ x.T + x
    return x

# JIT compile for dramatic speedup
fast_function = jax.jit(slow_function)

x = jnp.ones((100, 100))
result = fast_function(x)  # First call compiles, subsequent calls are fast
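You can observe the compile-once behavior by timing the calls yourself. One detail matters: JAX dispatches work asynchronously, so a sketch like the one below calls block_until_ready() to wait for the actual result before stopping the clock.

```python
import time
import jax
import jax.numpy as jnp

def slow_function(x):
    for _ in range(100):
        x = x @ x.T + x
    return x

fast_function = jax.jit(slow_function)
x = jnp.ones((100, 100))

# First call: traces the Python function and compiles it with XLA
t0 = time.perf_counter()
fast_function(x).block_until_ready()
first_call = time.perf_counter() - t0

# Second call: reuses the cached XLA executable
t0 = time.perf_counter()
fast_function(x).block_until_ready()
cached_call = time.perf_counter() - t0

print(f"first call: {first_call:.4f}s, cached call: {cached_call:.4f}s")
```

Recompilation is triggered only when the input shapes or dtypes change, since jit caches one executable per input signature.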

Functional Programming Paradigm

JAX embraces a purely functional programming style. Functions must be pure — they should have no side effects and always produce the same output for the same input. This is different from PyTorch and TensorFlow:

  • No in-place mutations — x[0] = 5 is not allowed; use x = x.at[0].set(5) instead
  • No global state — model parameters must be passed explicitly as function arguments
  • Random numbers require explicit keys — use jax.random.PRNGKey(0) for reproducible randomness
Why functional? Pure functions are easy to compile, differentiate, vectorize, and parallelize. This is what makes JAX's transforms composable and reliable.
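A minimal sketch of the first and third rules in action — immutable updates via .at[].set() and explicit random keys:

```python
import jax
import jax.numpy as jnp

# Arrays are immutable: .at[...].set(...) returns a NEW array
x = jnp.zeros(3)
y = x.at[0].set(5.0)
print(x)  # [0. 0. 0.]  -- original unchanged
print(y)  # [5. 0. 0.]

# Randomness is explicit: split the key whenever you need fresh numbers
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
sample = jax.random.normal(subkey, (2,))
print(sample)  # identical values on every run with the same seed
```

Because the random state is threaded through your code by hand rather than hidden in a global generator, results stay reproducible even under jit, vmap, and pmap.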

JAX vs PyTorch vs TensorFlow

| Aspect | JAX | PyTorch | TensorFlow |
| --- | --- | --- | --- |
| Paradigm | Functional | Object-oriented | Mixed |
| Compilation | XLA (jit) | torch.compile | tf.function + XLA |
| TPU support | First-class | Limited | Good |
| Learning curve | Steeper (functional) | Gentle | Moderate |
| Best for | Research, HPC | Research, production | Production, mobile |

Ready to Install JAX?

Let's set up JAX on your machine with GPU or TPU support.

Next: Installation →