# Fully Sharded Data Parallel (FSDP)
FSDP is PyTorch's native implementation of ZeRO Stage 3, sharding parameters, gradients, and optimizer states across GPUs for memory-efficient training.
## How FSDP Works
1. **Shard parameters.** Each GPU stores only 1/N of the model parameters, where N is the number of GPUs.
2. **AllGather before forward.** Before computing a layer, FSDP gathers that layer's full parameters from all GPUs via AllGather.
3. **Compute forward.** Run the forward pass with the full parameters, then discard the non-local parameter shards.
4. **AllGather before backward.** Gather the parameters again for the backward pass.
5. **ReduceScatter gradients.** After backward, reduce and scatter gradients so each GPU keeps only its shard of the gradients.
6. **Local optimizer step.** Each GPU updates only its parameter shard using its gradient shard.
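The steps above can be sketched in plain Python, with lists standing in for tensors and for the NCCL collectives. This is an illustrative simulation of the data flow only; none of these names are PyTorch APIs.

```python
# Toy simulation of one FSDP step: shard -> all-gather -> (fake) backward
# -> reduce-scatter -> local optimizer update. Purely illustrative.

N = 4                                 # number of GPUs (ranks)
full_params = list(range(8))          # toy "model parameters"

# 1. Shard: each rank keeps only 1/N of the parameters (round-robin here).
shards = [full_params[r::N] for r in range(N)]

# 2. AllGather: every rank reconstructs the full parameter set from the shards.
def all_gather(shards):
    gathered = [None] * len(full_params)
    for r, shard in enumerate(shards):
        for i, p in enumerate(shard):
            gathered[r + i * len(shards)] = p
    return gathered

params = all_gather(shards)           # each rank now sees the full params

# 3./4. Forward + backward produce a full gradient on every rank
# (faked here as a gradient of all ones per rank).
grads_per_rank = [[1.0] * len(full_params) for _ in range(N)]

# 5. ReduceScatter: sum gradients across ranks, then keep only the
# local shard of the summed result on each rank.
summed = [sum(g[i] for g in grads_per_rank) for i in range(len(full_params))]
grad_shards = [summed[r::N] for r in range(N)]

# 6. Local optimizer step: each rank updates only its own shard (plain SGD).
lr = 0.1
new_shards = [[p - lr * g for p, g in zip(ps, gs)]
              for ps, gs in zip(shards, grad_shards)]
```

Note that the full parameter set only ever exists transiently (between the AllGather and the end of the layer's compute); the persistent per-rank state is just the shard.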
## FSDP Setup
Python - PyTorch FSDP

```python
import functools

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    MixedPrecision,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Define wrapping policy: wrap each transformer layer as its own FSDP unit.
# The policy is passed as a functools.partial, not called directly;
# TransformerDecoderLayer is your model's transformer block class.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerDecoderLayer},
)

# Mixed precision configuration
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Wrap model with FSDP
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3
    mixed_precision=mp_policy,
    auto_wrap_policy=wrap_policy,
    device_id=torch.cuda.current_device(),
)

# Training loop is the same as DDP
for batch in dataloader:
    loss = model(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```
## Sharding Strategies
| Strategy | Equivalent | Memory | Communication |
|---|---|---|---|
| FULL_SHARD | ZeRO Stage 3 | Lowest | Highest (AllGather + ReduceScatter) |
| SHARD_GRAD_OP | ZeRO Stage 2 | Medium | Medium |
| NO_SHARD | DDP | Highest | Lowest (AllReduce only) |
| HYBRID_SHARD | FULL_SHARD within a node, replication (DDP) across nodes | Balanced | Balanced |
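The memory column can be made concrete with the usual ZeRO accounting: bf16 parameters and gradients (2 bytes each) plus fp32 Adam state (master weights and two moments, 12 bytes), for 16 bytes per parameter in total. The sketch below estimates per-GPU model-state memory under each strategy; the numbers are rough back-of-envelope estimates, not measurements, and ignore activations and communication buffers.

```python
# Rough per-GPU memory for model states (params + grads + Adam states),
# using ZeRO-style accounting: 2 B bf16 params, 2 B bf16 grads,
# 12 B fp32 Adam state (master weights + two moments) per parameter.
def model_state_gb(num_params, num_gpus, strategy):
    P, N = num_params, num_gpus
    param_b, grad_b, opt_b = 2, 2, 12        # bytes per parameter
    if strategy == "NO_SHARD":               # DDP: everything replicated
        total = (param_b + grad_b + opt_b) * P
    elif strategy == "SHARD_GRAD_OP":        # ZeRO-2: shard grads + optimizer
        total = param_b * P + (grad_b + opt_b) * P / N
    elif strategy == "FULL_SHARD":           # ZeRO-3: shard everything
        total = (param_b + grad_b + opt_b) * P / N
    else:
        raise ValueError(strategy)
    return total / 1e9                       # bytes -> GB

# Example: a 7B-parameter model trained on 8 GPUs.
for s in ("NO_SHARD", "SHARD_GRAD_OP", "FULL_SHARD"):
    print(f"{s:>14}: {model_state_gb(7e9, 8, s):.2f} GB/GPU")
```

For a 7B model on 8 GPUs this gives 112 GB per GPU replicated (infeasible on current hardware), about 26 GB under ZeRO-2 sharding, and 14 GB under full sharding, which is why FULL_SHARD is the default choice for large models.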
## Checkpointing with FSDP
Python - FSDP Checkpoint Save/Load

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
)

# Save: distributed checkpoint (each rank saves its own shard)
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    state_dict = model.state_dict()
    dcp.save(state_dict, checkpoint_id="checkpoints/step_1000")

# Load: each rank loads its shard in place, then restores it on the model
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    state_dict = model.state_dict()
    dcp.load(state_dict, checkpoint_id="checkpoints/step_1000")
    model.load_state_dict(state_dict)
```
## FSDP vs DeepSpeed
- FSDP: Native PyTorch, simpler API, tighter framework integration, good for PyTorch-first teams
- DeepSpeed: More features (offloading, ZeRO++, sparse attention), mature ecosystem, better for pushing memory limits
- Recommendation: Start with FSDP for simplicity. Switch to DeepSpeed if you need CPU/NVMe offloading or ZeRO++ optimizations.
**Key takeaway:** FSDP is PyTorch's native solution for memory-efficient distributed training. Use FULL_SHARD for maximum memory savings, HYBRID_SHARD for multi-node efficiency, and always define a proper wrapping policy for transformer models.
Lilly Tech Systems