Introduction to Distributed Training
Modern AI models are too large and data-hungry to train on a single GPU. Distributed training spreads the work across multiple GPUs and nodes to make large-scale training practical.
Why Distribute?
There are two fundamental reasons to distribute training:
- Speed: Training GPT-4-scale models on a single GPU would take decades. With thousands of GPUs working in parallel, it takes weeks.
- Memory: A 70B-parameter model in FP16 needs ~140 GB just for weights. Add gradients and Adam optimizer states (roughly 16 bytes per parameter in mixed-precision training) and that grows past 1 TB — far beyond any single GPU.
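The memory arithmetic above can be sketched in a few lines. This is a rough estimate assuming mixed-precision Adam training (2 bytes FP16 weights + 2 bytes FP16 gradients + 12 bytes FP32 master weights, momentum, and variance per parameter); activation memory is ignored and would add substantially more:

```python
def training_memory_gb(n_params: float, bytes_per_param: int = 16) -> float:
    """Rough training-memory estimate for mixed-precision Adam.

    16 B/param = 2 B FP16 weights + 2 B FP16 gradients
               + 12 B FP32 master weights, momentum, variance.
    Ignores activations, which add substantially more.
    """
    return n_params * bytes_per_param / 1e9

print(f"{training_memory_gb(70e9, 2):.0f} GB")  # weights alone in FP16 -> 140 GB
print(f"{training_memory_gb(70e9):.0f} GB")     # full training state -> 1120 GB
```

The per-parameter byte counts are the commonly cited figures for Adam with mixed precision; other optimizers or precisions change the multiplier, not the shape of the argument.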
Scaling Laws
Research by Kaplan et al. and Hoffmann et al. (Chinchilla) showed that model performance follows predictable scaling laws:
- Performance improves as a power law with model size, data, and compute
- Optimal training balances model size with data size (Chinchilla scaling)
- This drives the need for ever-larger compute budgets, which requires distributed training
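The Chinchilla result is often summarized by two rules of thumb: train on roughly 20 tokens per parameter, and estimate training compute as C ≈ 6·N·D FLOPs for an N-parameter model seeing D tokens. A small sketch using those approximations (the function name is illustrative):

```python
def chinchilla_budget(n_params: float) -> tuple[float, float]:
    """Chinchilla rule of thumb: ~20 training tokens per parameter.

    Training compute is approximated as C = 6 * N * D FLOPs
    (forward + backward over D tokens for an N-parameter model).
    """
    tokens = 20 * n_params
    flops = 6 * n_params * tokens
    return tokens, flops

tokens, flops = chinchilla_budget(70e9)
print(f"{tokens:.1e} tokens, {flops:.1e} FLOPs")  # 1.4e+12 tokens, 5.9e+23 FLOPs
```

At 70B parameters this implies ~1.4 trillion training tokens — a budget no single GPU can process in reasonable time, which is the point of the section.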
Types of Parallelism
| Strategy | What is Split | Best For |
|---|---|---|
| Data Parallelism | Training data across GPUs | Model fits on one GPU; you need faster training |
| Model Parallelism (Tensor) | Layer weights across GPUs | Single layers too large for one GPU |
| Pipeline Parallelism | Model layers across GPUs | Deep models whose groups of layers fit on individual GPUs |
| ZeRO / FSDP | Optimizer states, gradients, weights | Memory-efficient data parallelism |
| Expert Parallelism | MoE experts across GPUs | Mixture-of-Experts models |
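The first row of the table can be illustrated in plain Python, with no GPUs: each "worker" computes a gradient on its shard of the batch, and averaging the local gradients (what an AllReduce does) recovers exactly the full-batch gradient. The toy model here is linear regression with squared error; all names are illustrative:

```python
def grad_mse(w, data):
    """Gradient of mean squared error for y_hat = w * x over (x, y) pairs."""
    return sum(2 * (w * x - y) * x for x, y in data) / len(data)

w = 0.5
batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]

# Data parallelism: split the batch across two "workers",
# each computing a local gradient on its own shard...
shards = [batch[:2], batch[2:]]
local_grads = [grad_mse(w, shard) for shard in shards]

# ...then an AllReduce-style average combines them.
dp_grad = sum(local_grads) / len(local_grads)

# With equal-size shards, the averaged gradient equals the
# single-GPU gradient on the full batch.
assert dp_grad == grad_mse(w, batch)
print(dp_grad)  # -22.5
```

This equivalence is why data parallelism is mathematically transparent: the model sees the same update it would have seen on one device, only computed faster.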
Communication Overhead
The fundamental challenge of distributed training is communication. GPUs must share information (gradients, activations) during training:
- AllReduce: Sum gradients from all GPUs and distribute the result. With ring AllReduce, each GPU sends and receives roughly 2× the gradient size regardless of GPU count, so cost scales with model size rather than number of GPUs.
- Point-to-point: Send activations between pipeline stages. Latency-sensitive.
- Bandwidth: NVLink (~900 GB/s per GPU) >> InfiniBand NDR (400 Gb/s ≈ 50 GB/s per link) >> 100 Gb/s Ethernet (≈ 12.5 GB/s). Note the mixed units (gigabytes vs. gigabits); even after converting, the interconnect — especially between nodes — is often the bottleneck.
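The ring AllReduce mentioned above can be simulated in plain Python. It runs in two phases — a reduce-scatter, after which each "GPU" holds the full sum of one chunk, and an all-gather that circulates those reduced chunks — and each GPU transmits 2·(n−1)/n of the data size, nearly independent of n:

```python
def ring_allreduce(data):
    """Simulate ring AllReduce over n 'GPUs', each holding n chunks.

    Phase 1 (reduce-scatter): in n-1 steps, GPU i accumulates the
    full sum of chunk (i+1) % n. Phase 2 (all-gather): n-1 more
    steps circulate the reduced chunks to everyone. Each GPU sends
    2*(n-1) chunks, i.e. 2*(n-1)/n of the data size.
    """
    n = len(data)
    bufs = [list(row) for row in data]
    chunks_sent_per_gpu = 0
    for step in range(n - 1):              # reduce-scatter
        for i in range(n):
            c = (i - step) % n             # chunk GPU i passes along the ring
            bufs[(i + 1) % n][c] += bufs[i][c]
        chunks_sent_per_gpu += 1
    for step in range(n - 1):              # all-gather
        for i in range(n):
            c = (i + 1 - step) % n         # fully reduced chunk GPU i forwards
            bufs[(i + 1) % n][c] = bufs[i][c]
        chunks_sent_per_gpu += 1
    return bufs, chunks_sent_per_gpu

# 4 GPUs, each holding 4 gradient chunks; afterwards every GPU has the sum.
grads = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]
reduced, sent = ring_allreduce(grads)
assert all(buf == [4, 8, 12, 16] for buf in reduced)
assert sent == 2 * (len(grads) - 1)  # 2*(n-1)/n of the data per GPU
```

With 4 GPUs, each sends 6 chunks of 1/4 the data, i.e. 1.5× the data size; as n grows this approaches 2×, which is why ring AllReduce cost depends on model size and not on GPU count.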
Key Tools
PyTorch DDP
Built-in data parallelism. The simplest starting point for multi-GPU training.
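A minimal DDP sketch, assuming PyTorch is installed; the model and data are dummies, and the filename in the launch comment is illustrative. DDP wraps the model and all-reduces gradients automatically during `backward()`:

```python
"""Minimal DDP training sketch (toy model and random data).

Launch with: torchrun --nproc_per_node=2 train_ddp.py
"""
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # torchrun sets these; the defaults let the script also run standalone.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    use_cuda = torch.cuda.is_available()
    dist.init_process_group("nccl" if use_cuda else "gloo")

    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    device = torch.device(f"cuda:{local_rank}" if use_cuda else "cpu")

    model = torch.nn.Linear(10, 1).to(device)
    # DDP replicates the model; gradients are all-reduced during backward().
    ddp_model = DDP(model, device_ids=[local_rank] if use_cuda else None)
    opt = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

    for _ in range(3):  # in real training, each rank loads its own data shard
        x = torch.randn(8, 10, device=device)
        y = torch.randn(8, 1, device=device)
        loss = torch.nn.functional.mse_loss(ddp_model(x), y)
        opt.zero_grad()
        loss.backward()  # AllReduce happens here, overlapped with compute
        opt.step()

    dist.destroy_process_group()
    return loss.item()

if __name__ == "__main__":
    print(f"final loss: {main():.4f}")
```

Each process runs the same script; `torchrun` assigns ranks and DDP keeps the replicas in sync, so the training loop looks almost identical to single-GPU code.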
PyTorch FSDP
Fully Sharded Data Parallel. Shards parameters, gradients, and optimizer states across GPUs for memory efficiency.
DeepSpeed
Microsoft's library with ZeRO optimizer stages, offloading, and advanced optimizations.
Megatron-LM
NVIDIA's framework for tensor and pipeline parallelism in LLM training.
torchrun / accelerate
Launch utilities that handle process group setup and distributed initialization.