Distributed Training Best Practices

Practical guidance on choosing a distributed strategy, debugging multi-GPU jobs, checkpointing reliably, and running training at production scale.

Strategy Selection Flowchart

Situation → Recommended Strategy

  • Model fits on 1 GPU, want faster training → DDP (simplest, near-linear scaling)
  • Model fits on 1 GPU but OOMs with optimizer state → FSDP SHARD_GRAD_OP or DeepSpeed ZeRO-2
  • Model doesn't fit on 1 GPU → FSDP FULL_SHARD or DeepSpeed ZeRO-3
  • Very large model, multi-node → 3D parallelism (tensor + pipeline + data)
  • Need CPU/NVMe offloading → DeepSpeed ZeRO-3 + Offload
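The decision flow above can be encoded as a small helper. This is an illustrative sketch, not part of any library; the function and flag names are invented for this example.

```python
def choose_strategy(fits_on_one_gpu: bool,
                    ooms_with_optimizer_state: bool = False,
                    multi_node_very_large: bool = False,
                    needs_cpu_nvme_offload: bool = False) -> str:
    """Map a training situation to a recommended distributed strategy.

    Checks are ordered from the most constraining situation down to the
    simplest one, mirroring the flowchart above.
    """
    if needs_cpu_nvme_offload:
        return "DeepSpeed ZeRO-3 + Offload"
    if multi_node_very_large:
        return "3D parallelism (tensor + pipeline + data)"
    if not fits_on_one_gpu:
        return "FSDP FULL_SHARD or DeepSpeed ZeRO-3"
    if ooms_with_optimizer_state:
        return "FSDP SHARD_GRAD_OP or DeepSpeed ZeRO-2"
    return "DDP"
```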

Debugging Distributed Training

  1. Start with 1 GPU

    Verify your training loop works on a single GPU before adding distribution. Many bugs are easier to find without the distributed complexity.

  2. Check Process Group

    Ensure all ranks can communicate. Common issues: firewall blocking NCCL ports, wrong MASTER_ADDR/MASTER_PORT, mismatched world_size.

  3. Use NCCL Debug Logging

    Set NCCL_DEBUG=INFO to see communication details. Set TORCH_DISTRIBUTED_DEBUG=DETAIL for PyTorch-level debugging.

  4. Watch for Hangs

    Distributed hangs usually mean one rank crashed while others wait for AllReduce. Check all rank logs. Use NCCL_TIMEOUT to fail fast.

  5. Verify Gradient Sync

    After a few steps, compare model weights across ranks. They should be identical. If they diverge, gradient sync is broken.
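Steps 2-5 can be rehearsed on a single machine with a one-process gloo group before moving to real multi-GPU jobs. The helper names below are illustrative; on GPU jobs you would use the nccl backend and let torchrun set RANK and WORLD_SIZE.

```python
import hashlib
import os
from datetime import timedelta

import torch
import torch.distributed as dist

def init_debug_process_group() -> None:
    # Step 3: enable debug logging (must be set before init).
    os.environ.setdefault("NCCL_DEBUG", "INFO")
    os.environ.setdefault("TORCH_DISTRIBUTED_DEBUG", "DETAIL")
    # Step 2: rendezvous settings; torchrun normally sets these for you.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(
        backend="gloo",  # use "nccl" on GPUs
        rank=int(os.environ.get("RANK", "0")),
        world_size=int(os.environ.get("WORLD_SIZE", "1")),
        # Step 4: bounded timeout so a crashed rank fails fast
        # instead of hanging the collective forever.
        timeout=timedelta(minutes=10),
    )

def weights_match_across_ranks(model: torch.nn.Module) -> bool:
    # Step 5: hash local weights, gather every rank's hash, compare.
    h = hashlib.sha256()
    for p in model.parameters():
        h.update(p.detach().cpu().numpy().tobytes())
    digest = h.hexdigest()
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, digest)
    return all(d == gathered[0] for d in gathered)
```

With world_size=1 the checks are trivially true, but the same code run under torchrun on several ranks will detect broken gradient sync.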

Checkpointing

  • Frequency: Checkpoint every 30-60 minutes for large training runs. Lost compute from crashes is expensive.
  • Async checkpointing: Save checkpoints in a background thread to avoid blocking training.
  • Sharded checkpoints: With FSDP/ZeRO-3, save sharded state dicts for fast save/load. Convert to full state dict only for inference.
  • Cloud storage: Write checkpoints to S3/GCS, not local disk. Local disk is lost if the node is preempted.
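A minimal version of background-thread checkpointing looks like the sketch below; it is an assumption-laden simplification (production systems also handle sharded state dicts, retries, and cloud uploads).

```python
import threading

import torch

def save_checkpoint_async(state: dict, path: str) -> threading.Thread:
    """Snapshot tensors on the main thread, write them in the background.

    Copying to CPU first matters: the training loop may keep mutating the
    live parameters while the background thread is still writing.
    """
    snapshot = {
        k: v.detach().cpu().clone() if torch.is_tensor(v) else v
        for k, v in state.items()
    }
    t = threading.Thread(target=torch.save, args=(snapshot, path), daemon=True)
    t.start()
    return t  # join() before exiting or before overwriting the same path
```

With FSDP or ZeRO-3, each rank would snapshot and write its own shard instead of a full state dict.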

Fault Tolerance

Large-scale training runs spanning days or weeks will encounter hardware failures:

  • Elastic training: Use torchrun --rdzv-backend=c10d for automatic recovery when nodes fail and rejoin.
  • Preemption handling: On spot instances, catch SIGTERM and save a checkpoint before the instance is terminated.
  • Health checks: Monitor GPU temperature, memory usage, and training loss; restart automatically when anomalies appear.
  • Gradient clipping: Always clip gradients (max_norm=1.0) to prevent training instability from one bad batch.
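Two of the points above, preemption handling and gradient clipping, fit in a few lines. This is a hedged sketch; the class and function names are invented for illustration.

```python
import signal

import torch

class PreemptionHandler:
    """Catch SIGTERM (spot-instance preemption) and request a final checkpoint."""

    def __init__(self):
        self.should_stop = False
        signal.signal(signal.SIGTERM, self._on_sigterm)

    def _on_sigterm(self, signum, frame):
        # Keep the handler tiny: just set a flag. The training loop
        # checks it each step, saves a checkpoint, then exits cleanly.
        self.should_stop = True

def clipped_step(model, optimizer, max_norm: float = 1.0):
    """Optimizer step with gradient clipping, called after backward()."""
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    optimizer.zero_grad()
```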

Frequently Asked Questions

Should I use FSDP or DeepSpeed?

Use FSDP if you want native PyTorch with simpler setup and debugging. Use DeepSpeed if you need CPU/NVMe offloading, ZeRO++ optimizations, or are training very large models (100B+ parameters). HuggingFace Accelerate supports both, making it easy to switch.

How do I scale the learning rate with more GPUs?

Use the linear scaling rule: multiply the base learning rate by the number of GPUs (or, more precisely, by the effective batch size ratio). Always use a warmup period (typically 1-5% of total steps) when scaling to large batch sizes. For very large batches (>8K), consider the LARS or LAMB optimizers.
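The linear scaling rule with warmup can be written as a small schedule function. A sketch, assuming linear warmup to the scaled peak; the function name is illustrative.

```python
def scaled_lr_with_warmup(base_lr: float, num_gpus: int,
                          step: int, warmup_steps: int) -> float:
    """Linear scaling rule plus linear warmup.

    Peak LR = base_lr * num_gpus (linear scaling rule); the LR ramps
    linearly up to that peak over the first warmup_steps steps.
    """
    peak = base_lr * num_gpus
    if step < warmup_steps:
        return peak * (step + 1) / warmup_steps
    return peak
```

For example, a base LR of 1e-4 tuned on one GPU becomes a peak of 8e-4 on eight GPUs, reached only after warmup.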

Why is my multi-GPU training slower than expected?

Common causes:

  • Data loading is the bottleneck: increase num_workers and use pin_memory.
  • Communication overhead: check interconnect bandwidth with nccl-tests.
  • Small per-GPU batch size: each GPU is underutilized.
  • Gradient accumulation: it reduces communication frequency but doesn't parallelize compute.

Profile with the PyTorch Profiler to identify which of these dominates.
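For the data-loading case, a DataLoader tuned to keep GPUs fed looks roughly like this; the wrapper function and its defaults are illustrative, not a prescribed recipe.

```python
import torch
from torch.utils.data import DataLoader

def make_loader(dataset, batch_size: int, num_workers: int = 4,
                use_cuda: bool = torch.cuda.is_available()) -> DataLoader:
    """Build a DataLoader that overlaps data loading with GPU compute."""
    kwargs = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,   # parallel loading in worker processes
        pin_memory=use_cuda,       # faster host-to-device copies
    )
    if num_workers > 0:
        kwargs.update(
            persistent_workers=True,  # avoid re-forking workers each epoch
            prefetch_factor=4,        # batches prefetched per worker
        )
    return DataLoader(dataset, **kwargs)
```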

Congratulations! You have completed the Distributed Training course. You now understand data parallelism, model parallelism, DeepSpeed ZeRO, and FSDP. Start with DDP for simple multi-GPU, graduate to FSDP/DeepSpeed for large models, and always implement robust checkpointing!