Introduction to PyTorch FSDP#
FullyShardedDataParallel (FSDP) is a distributed training approach that
shards a model’s parameters, gradients, and optimizer states across all
participating GPUs. Like DDP, FSDP uses one process per GPU and synchronizes
during training. The key distinction is that, instead of each process holding a
full copy of the model, FSDP distributes parameters across processes and
all-gathers them just-in-time for computation, then frees or reshards them
afterward.
Advantages of FSDP#
Lower memory usage: Each GPU stores only a shard of parameters, gradients, and optimizer states.
Enables larger models: Models that do not fit in a single GPU’s memory can be trained across multiple GPUs.
Scales across nodes: Supports both single-node and multi-node training.
When to use FSDP#
Use FSDP when your model cannot fit in a single GPU’s memory, even with gradient accumulation.
If your model fits comfortably in memory, DDP is simpler and often achieves comparable performance. In practice, start with DDP and switch to FSDP when memory becomes the bottleneck.
FSDP introduces additional communication overhead and configuration complexity. For smaller models or systems with slower interconnects (e.g., no NVLink or GPUDirect), it may perform worse than DDP.
Show me the code!#
Note
This section uses the FSDP example from the companion examples repository.
You can download it from GitHub into your home (~) directory and run:
cd amii-doc-examples
uv sync # create virtual environment and install dependencies
cd src/distributed_training
At a high level, an FSDP training script includes the following steps:
Initialize the distributed process group.
Bind each process to a GPU (local_rank).
Use DistributedSampler so each process receives unique data.
Shard model parameters with fully_shard.
Save checkpoints from all ranks, where each rank writes only its local shard (distributed checkpointing).
Open cancer-classification-fsdp.py and follow along.
As in DDP, inside the if __name__ == "__main__": block, we retrieve the
rank, local rank, and world size from environment variables set by
torchrun and initialize the process group via setup_comm().
Each process is then bound to a GPU, and a DistributedSampler is created as
in the DDP example.
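A minimal sketch of this setup is shown below; the setup_comm() helper in the example script may differ in its details, and the variable names here are illustrative:
import os
import torch
import torch.distributed as dist

rank = int(os.environ["RANK"])              # global rank set by torchrun
local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this node
world_size = int(os.environ["WORLD_SIZE"])  # total number of processes

dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(local_rank)           # bind this process to its GPU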
Model construction on the meta device#
FSDP is designed for very large models, which may not fit even in CPU memory.
To support this, PyTorch provides the meta device. Tensors on the meta
device contain only metadata (e.g., shape and dtype) without allocating real
storage. This allows model construction and initialization logic to run without
materializing actual parameter data.
For example, a trillion-parameter model on the meta device may consume only a
few megabytes of memory, rather than terabytes on CPU.
In the example:
with torch.device("meta"):
    model = MyModel(...)
Device mesh#
Next, we define a device mesh using init_device_mesh. A device mesh describes
the logical topology of devices and determines how tensors are distributed.
In this example, we use a 1D mesh of size world_size (i.e., a linear layout
of GPUs), and parameters are evenly sharded across it.
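A minimal sketch of that setup, assuming world_size was read from the environment as above (the exact arguments in the example script may differ):
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (world_size,))  # 1D mesh spanning all GPUs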
More advanced configurations (e.g., 2D meshes combining FSDP and DDP for hybrid parallelism) are possible but not covered here.
Parameter initialization#
After the parameters are materialized, they must be synchronized across ranks. Remember that fully_shard() is a purely local operation, so we need to broadcast the parameters from rank 0 manually:
dist.broadcast(param.data, src=0)
This ensures all processes start from the same initial state, which is critical for FSDP to work correctly.
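A minimal sketch of this step, assuming the parameters have already been materialized as regular tensors on the GPU (the helper in the example script may organize this differently):
with torch.no_grad():
    for param in model.parameters():
        dist.broadcast(param.data, src=0)  # every rank adopts rank 0's values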
The optimizer must be created after sharding so that it references the sharded parameters correctly.
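For example, a sketch of this ordering; the import path varies by PyTorch version, and the optimizer class and hyperparameters here are placeholders, not taken from the example script:
from torch.distributed.fsdp import fully_shard  # import path may vary by PyTorch version

fully_shard(model, mesh=mesh)  # parameters become sharded across the device mesh
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # created after sharding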
Checkpointing with distributed checkpointing (DCP)#
Checkpointing in FSDP differs from DDP. Each rank saves only its local shard of the model and optimizer state.
PyTorch’s distributed checkpointing library (imported as dcp) provides this
functionality.
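The relevant imports typically look like the following (the example script may alias them differently):
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict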
Saving:
model_state, optimizer_state = get_state_dict(model, optimizer)
state_dict = {
    "model": model_state,
    "optimizer": optimizer_state,
    "epoch": epoch,
}
dcp.save(state_dict, checkpoint_id=checkpoint_id)
get_state_dict is preferred over calling state_dict() directly for FSDP models because it
correctly handles sharded parameters, optimizer state, and other nuances of FSDP.
Loading:
model_state, optimizer_state = get_state_dict(model, optimizer)
state_dict = {
    "model": model_state,
    "optimizer": optimizer_state,
    "epoch": 0,
}
dcp.load(state_dict, checkpoint_id=checkpoint_id)
set_state_dict(model, optimizer, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optimizer"])
dcp.load automatically reshards and redistributes parameters to match the
current distributed configuration, even if it differs from the original setup.
Important considerations:
dcp.load must be called after the model is sharded. Calling it before sharding may load full parameters into memory and cause out-of-memory (OOM) errors, because normal tensor objects carry no metadata about their sharding.
Checkpoint (save/load) must be invoked on all ranks, unlike DDP where only rank 0 needs to save.
First run on Slurm#
Use the same job script launch_cancer_classification.sh as in the DDP example.
Submit the job:
sbatch launch_cancer_classification.sh fsdp
Run this command in the src/distributed_training directory.
Summary#
FSDP enables training of very large models by reducing memory usage through sharding. However, it introduces additional complexity in model construction, parameter management, and checkpointing.
PyTorch also provides the FullyShardedDataParallel wrapper, which abstracts
much of this complexity. While less flexible than fully_shard, it is often a
better choice when fine-grained control is not required.
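For reference, a minimal sketch of the wrapper-style API; the device placement shown here is an assumption and not taken from the example script:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(model, device_id=local_rank)  # wraps the model and shards its parameters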