Advanced

Fully Sharded Data Parallel (FSDP)

FSDP is PyTorch's native implementation of ZeRO Stage 3, sharding parameters, gradients, and optimizer states across GPUs for memory-efficient training.

How FSDP Works

  1. Shard Parameters

    Each GPU stores only 1/N of the model parameters (where N = number of GPUs).

  2. AllGather Before Forward

    Before computing a layer, FSDP gathers the full parameters from all GPUs via AllGather.

  3. Compute Forward

    Run the forward pass with full parameters, then discard non-local parameter shards.

  4. AllGather Before Backward

    Gather parameters again for the backward pass.

  5. ReduceScatter Gradients

    After backward, reduce and scatter gradients so each GPU only has its shard of gradients.

  6. Local Optimizer Step

    Each GPU updates only its parameter shard using its gradient shard.
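The memory arithmetic behind step 1 can be made concrete with a back-of-the-envelope sketch. This is an illustrative model only, assuming bf16 parameters and gradients (2 bytes each) plus fp32 Adam states (master copy + two moments, 12 bytes per parameter); the 7B parameter count and 8-GPU world size are hypothetical, and activations and temporary AllGather buffers are not counted.

```python
def per_gpu_state_gb(n_params: int, n_gpus: int, full_shard: bool = True) -> float:
    """Approximate per-GPU memory for params + grads + optimizer states.

    Assumes bf16 params (2 B) and grads (2 B) plus fp32 Adam states
    (fp32 master weights + two moments = 12 B), i.e. 16 bytes/param total.
    """
    bytes_per_param = 2 + 2 + 12
    total_bytes = n_params * bytes_per_param
    shards = n_gpus if full_shard else 1  # FULL_SHARD splits all three states
    return total_bytes / shards / 1e9

# Illustrative: a 7B-parameter model on 8 GPUs
n_params, n_gpus = 7_000_000_000, 8
print(f"Replicated (DDP-style): {per_gpu_state_gb(n_params, n_gpus, full_shard=False):.1f} GB/GPU")  # 112.0
print(f"FSDP FULL_SHARD:        {per_gpu_state_gb(n_params, n_gpus, full_shard=True):.1f} GB/GPU")   # 14.0
```

The 8x reduction in persistent state is why FULL_SHARD makes otherwise-impossible model sizes trainable; the price is the extra AllGather traffic in steps 2 and 4.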

FSDP Setup

Python - PyTorch FSDP
import functools

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    MixedPrecision,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.nn import TransformerDecoderLayer  # or your model's own layer class

# Define wrapping policy: wrap each transformer layer.
# transformer_auto_wrap_policy takes extra runtime arguments, so it must be
# bound with functools.partial rather than called directly.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerDecoderLayer},
)

# Mixed precision configuration
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Wrap model with FSDP
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3
    mixed_precision=mp_policy,
    auto_wrap_policy=wrap_policy,
    device_id=torch.cuda.current_device(),
)

# Training loop is the same as DDP
for batch in dataloader:
    loss = model(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Sharding Strategies

Strategy      | Equivalent                   | Memory   | Communication
FULL_SHARD    | ZeRO Stage 3                 | Lowest   | Highest (AllGather + ReduceScatter)
SHARD_GRAD_OP | ZeRO Stage 2                 | Medium   | Medium
NO_SHARD      | DDP                          | Highest  | Lowest (AllReduce only)
HYBRID_SHARD  | FSDP within node, DDP across | Balanced | Balanced
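The communication column can be quantified with a rough per-step volume estimate. A sketch under simplifying assumptions: bf16 (2-byte) communication, a ring AllReduce costing roughly 2x the gradient size, no overlap with compute, and the collective counts taken from the step-by-step walkthrough above (two AllGathers plus one ReduceScatter for FULL_SHARD).

```python
def comm_gb_per_step(n_params: int, strategy: str, bytes_per_elem: int = 2) -> float:
    """Rough per-GPU communication volume per training step, in GB.

    FULL_SHARD: AllGather before forward + AllGather before backward
        + ReduceScatter of gradients -> ~3x parameter size.
    SHARD_GRAD_OP: ReduceScatter of gradients + AllGather of updated
        params -> ~2x (comparable to DDP).
    NO_SHARD (DDP): ring AllReduce of gradients moves ~2x gradient size.
    """
    factors = {"FULL_SHARD": 3, "SHARD_GRAD_OP": 2, "NO_SHARD": 2}
    return factors[strategy] * n_params * bytes_per_elem / 1e9

# Illustrative 7B-parameter model in bf16
print(comm_gb_per_step(7_000_000_000, "FULL_SHARD"))  # 42.0
print(comm_gb_per_step(7_000_000_000, "NO_SHARD"))    # 28.0
```

FULL_SHARD's extra cost over DDP is thus about 1.5x in volume, which is why HYBRID_SHARD confines the chattier AllGather/ReduceScatter traffic to the fast intra-node links and uses DDP-style AllReduce across nodes.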

Checkpointing with FSDP

Python - FSDP Checkpoint Save/Load
from torch.distributed.checkpoint import save, load
from torch.distributed.fsdp import StateDictType

# Save: distributed checkpoint (each rank saves its shard)
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    state_dict = model.state_dict()
    save(state_dict, checkpoint_id="checkpoints/step_1000")

# Load: build the sharded state_dict skeleton, load the shards into it
# in place, then copy the result back into the model
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    state_dict = model.state_dict()
    load(state_dict, checkpoint_id="checkpoints/step_1000")
    model.load_state_dict(state_dict)

FSDP vs DeepSpeed

  • FSDP: Native PyTorch, simpler API, tighter framework integration, good for PyTorch-first teams
  • DeepSpeed: More features (offloading, ZeRO++, sparse attention), mature ecosystem, better for pushing memory limits
  • Recommendation: Start with FSDP for simplicity. Switch to DeepSpeed if you need CPU/NVMe offloading or ZeRO++ optimizations.

Key takeaway: FSDP is PyTorch's native solution for memory-efficient distributed training. Use FULL_SHARD for maximum memory savings, HYBRID_SHARD for multi-node efficiency, and always define a proper wrapping policy for transformer models.