Learn JAX

Master Google's high-performance numerical computing library. JAX combines NumPy's familiar API with automatic differentiation, JIT compilation via XLA, and seamless GPU/TPU acceleration — the framework behind cutting-edge AI research at Google DeepMind.

6 lessons · Hands-on examples · Self-paced · 100% free

Your Learning Path

Follow these lessons in order, or jump to any topic that interests you.

What You'll Learn

By the end of this course, you'll be able to:

Write Functional ML Code

Use JAX's pure functional approach with jit, grad, and vmap for composable, efficient computations.
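As a taste of that composability, here is a minimal sketch (the `loss` function and its inputs are made up for illustration): `grad` derives a gradient function, `jit` compiles it with XLA, and `vmap` maps it over a batch, all as ordinary function transformations.

```python
import jax
import jax.numpy as jnp

# A pure function: output depends only on its inputs, no side effects.
def loss(w, x):
    return jnp.sum((w * x - 1.0) ** 2)

# grad builds dloss/dw; jit compiles it; vmap batches it over x.
grad_loss = jax.jit(jax.grad(loss))
batched = jax.vmap(grad_loss, in_axes=(None, 0))  # share w, map over x

w = 2.0
xs = jnp.array([1.0, 2.0, 3.0])
grads = batched(w, xs)  # one gradient per batch element, shape (3,)
```

Because each transformation returns a plain function, they can be stacked in any order that makes sense for the computation.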


Build Neural Networks

Create models with Flax or Haiku and train them with Optax optimizers on GPU/TPU.
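Flax, Haiku, and Optax all build on the same core pattern: compute gradients of a loss over a parameter tree, then apply an update. This sketch shows that pattern in plain JAX so it stays self-contained; the linear model, data, and learning rate are hypothetical stand-ins for what those libraries would manage for you.

```python
import jax
import jax.numpy as jnp

# A tiny linear model; Flax/Haiku wrap this kind of function in modules.
def predict(params, x):
    return x @ params["w"] + params["b"]

def mse(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

# One SGD update over the parameter tree; Optax optimizers follow the
# same update/apply shape with fancier rules (Adam, schedules, etc.).
@jax.jit
def train_step(params, x, y, lr=0.1):
    grads = jax.grad(mse)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 3))
y = x @ jnp.ones((3,)) + 0.5          # synthetic targets
params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}
for _ in range(200):
    params = train_step(params, x, y)
```

The same loop runs unchanged on CPU, GPU, or TPU; only the device JAX places the arrays on differs.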


Scale to Multiple Devices

Use pmap and sharding to distribute computation across multiple GPUs and TPU pods.
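A minimal sketch of the `pmap` side of this, written so it also runs on a single-device machine (where `local_device_count()` is 1); the axis name `"i"` and the toy data are illustrative choices, not anything prescribed by JAX.

```python
import jax
import jax.numpy as jnp

# 1 on a plain CPU; 8 on a typical TPU host, one per local device.
devices = jax.local_device_count()

# pmap splits the leading axis across devices: each device gets one row.
xs = jnp.arange(devices * 4.0).reshape(devices, 4)

# psum sums the per-device partial results across the named axis "i",
# so every device ends up holding the global total.
total = jax.pmap(lambda x: jax.lax.psum(jnp.sum(x), "i"),
                 axis_name="i")(xs)
```

Newer JAX versions favor `jax.sharding` (placing arrays on a device mesh and letting `jit` partition the computation), but the collective-over-a-named-axis idea shown here carries over.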


Optimize Performance

Leverage XLA compilation, vectorization, and JAX-specific patterns for maximum speed.
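One concrete piece of this: under `jit`, XLA fuses chains of elementwise operations into a single kernel instead of materializing each intermediate array. The function below is a made-up example of such a chain.

```python
import jax
import jax.numpy as jnp

def step(x):
    # Three elementwise ops; XLA fuses them into one kernel under jit.
    return jnp.tanh(x) * jnp.exp(-x ** 2) + 0.5 * x

fast = jax.jit(step)
x = jnp.linspace(-1.0, 1.0, 1024)

# The first call traces and compiles; later calls with the same shapes
# and dtypes reuse the cached XLA executable.
fast(x).block_until_ready()
```

A common JAX-specific pattern follows from this: keep shapes static across calls so the compilation cache hits, and prefer `vmap` over Python loops so XLA sees one large fused computation.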