Introduction to PyTorch FSDP#

FullyShardedDataParallel (FSDP) is a distributed training approach that shards a model’s parameters, gradients, and optimizer states across all participating GPUs. Like DDP, FSDP uses one process per GPU and synchronizes during training. The key distinction is that, instead of each process holding a full copy of the model, FSDP distributes parameters across processes and all-gathers them just-in-time for computation, then frees or reshards them afterward.

Advantages of FSDP#

  • Lower memory usage: Each GPU stores only a shard of parameters, gradients, and optimizer states.

  • Enables larger models: Models that do not fit in a single GPU’s memory can be trained across multiple GPUs.

  • Scales across nodes: Supports both single-node and multi-node training.

When to use FSDP#

Use FSDP when your model cannot fit in a single GPU’s memory, even with gradient accumulation.

If your model fits comfortably in memory, DDP is simpler and often achieves comparable performance. In practice, start with DDP and switch to FSDP when memory becomes the bottleneck.

FSDP introduces additional communication overhead and configuration complexity. For smaller models or systems with slower interconnects (e.g., no NVLink or GPUDirect), it may perform worse than DDP.

Show me the code!#

Note

This section uses the FSDP example from the companion examples repository. You can download it from GitHub into your home (~) directory and run:

cd amii-doc-examples
uv sync  # create virtual environment and install dependencies
cd src/distributed_training

At a high level, an FSDP training script includes the following steps:

  • Initialize the distributed process group.

  • Bind each process to a GPU (local_rank).

  • Use DistributedSampler so each process receives unique data.

  • Shard model parameters with fully_shard.

  • Save checkpoints from all ranks, where each rank writes only its local shard (distributed checkpointing).

Open cancer-classification-fsdp.py and follow along.

As in DDP, inside the if __name__ == "__main__": block, we retrieve the rank, local rank, and world size from environment variables set by torchrun and initialize the process group via setup_comm().
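Such a helper might look roughly like the following sketch (the actual setup_comm() in the example repository may differ in details; the CPU fallback to gloo is an assumption added here so the snippet runs on machines without GPUs):

```python
import os

import torch
import torch.distributed as dist


def setup_comm():
    """Read the environment variables set by torchrun and join the process group."""
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # NCCL for GPU training; fall back to gloo on CPU-only machines
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    return rank, local_rank, world_size
```

torchrun sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT for every process it launches, so the script itself never hard-codes them.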

Each process is then bound to a GPU, and a DistributedSampler is created as in the DDP example.
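As a reminder of how the sampler partitions data, here is a minimal standalone illustration (the toy dataset and batch size are placeholders, not the values used in the example script):

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Toy dataset of 100 samples; rank and world_size would come from torchrun
dataset = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=16, sampler=sampler)

# Each of the 2 ranks sees 50 of the 100 samples per epoch
print(len(sampler))  # 50
```

Remember to call sampler.set_epoch(epoch) at the start of each epoch so that shuffling differs between epochs.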

Model construction on the meta device#

FSDP is designed for very large models, which may not fit even in CPU memory.

To support this, PyTorch provides the meta device. Tensors on the meta device contain only metadata (e.g., shape and dtype) without allocating real storage. This allows model construction and initialization logic to run without materializing actual parameter data.

For example, a trillion-parameter model on the meta device may consume only a few megabytes of memory, rather than terabytes on CPU.

In the example:

with torch.device("meta"):
    model = MyModel(...)
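To see what this buys, here is a small self-contained illustration using a single layer in place of MyModel:

```python
import torch
import torch.nn as nn

with torch.device("meta"):
    layer = nn.Linear(10_000, 10_000)  # ~100M parameters, no real storage

print(layer.weight.is_meta)  # True: only metadata, nothing allocated
print(layer.weight.shape)    # torch.Size([10000, 10000])
# Shape and dtype are known, so model-construction code runs normally,
# but there is no data to read; parameters are materialized later
# (e.g., via to_empty()) and then re-initialized.
```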

Device mesh#

Next, we define a device mesh using init_device_mesh. A device mesh describes the logical topology of devices and determines how tensors are distributed.

In this example, we use a 1D mesh of size world_size (i.e., a linear layout of GPUs), and parameters are evenly sharded across it.

More advanced configurations (e.g., 2D meshes combining FSDP and DDP for hybrid parallelism) are possible but not covered here.

Sharding with fully_shard#

Parameters are materialized and sharded incrementally, submodule by submodule, to avoid ever holding the full model in memory.

The example uses the fully_shard API instead of the legacy FullyShardedDataParallel wrapper. While more flexible, it requires explicit control over sharding.

fully_shard converts standard PyTorch tensors into Distributed Tensors (DTensors), introduced in torch.distributed.tensor. A DTensor represents a logically global tensor whose data is physically distributed across devices.

Key DTensor properties:

  • device_mesh: the mesh over which the tensor is distributed

  • placements: how the tensor is partitioned (e.g., Shard(0) indicating the tensor is sharded along the 0th dimension of the device mesh)

  • shape: the global (unsharded) tensor shape

Useful methods:

  • to_local(): returns the local shard on the current rank

  • full_tensor(): reconstructs the full tensor across all ranks

We iterate over submodules and apply fully_shard:

for module in reversed(list(model.modules())):
    # count only the parameters owned directly by this module
    num_params = sum(p.numel() for p in module.parameters(recurse=False))
    if num_params >= MIN_PARAMS_TO_SHARD or module is model:
        fully_shard(module, mesh=mesh)

The use of reversed ensures a bottom-up traversal. This is important because fully_shard(module) defines a sharding group rooted at that module, meaning the parameters of that module and its submodules are sharded and gathered together as a single atomic group. This design makes sense because a module and its submodules together form one functional unit. If a parent module were sharded before its children, those child parameters would already be managed as DTensors in an FSDP group, leading to errors when attempting to shard them again.

Passing recurse=False when counting a module's parameters avoids reprocessing submodule parameters, since the bottom-up loop already visits every submodule individually.

Note that fully_shard() does not require inter-process communication. It is a purely local operation. Each process transforms its parameters into sharded DTensors based on the device mesh, keeping only the local shard in memory while logically representing the full tensor across all processes.

Parameter initialization#

After parameters are materialized, they must be synchronized across ranks. Recall that fully_shard() is a purely local operation, so each rank may start from a different random initialization; we therefore broadcast the parameters from rank 0 manually:

dist.broadcast(param.data, src=0)

This ensures all processes start from the same initial state, which is critical for FSDP to work correctly.

The optimizer must be created after sharding so that it references the sharded parameters correctly.
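In single-process form (ignoring the DTensor details, and using a plain Linear layer as a stand-in for the sharded model), the order of operations looks like this:

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29513")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Linear(8, 2)  # stand-in for the sharded model

# Rank 0's random initialization becomes the shared starting point;
# with world_size == 1 this broadcast is a no-op.
for param in model.parameters():
    dist.broadcast(param.data, src=0)

# Create the optimizer only after sharding/synchronization, so its state
# is built against the parameters the training loop will actually update.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
```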

Checkpointing with distributed checkpointing (DCP)#

Checkpointing in FSDP differs from DDP. Each rank saves only its local shard of the model and optimizer state.

PyTorch’s distributed checkpointing library (imported as dcp) provides this functionality.

Saving:

model_state, optimizer_state = get_state_dict(model, optimizer)
state_dict = {
    "model": model_state,
    "optimizer": optimizer_state,
    "epoch": epoch,
}
dcp.save(state_dict, checkpoint_id=checkpoint_id)

get_state_dict is preferred over calling state_dict() directly for FSDP models because it correctly handles sharded parameters, optimizer state, and other FSDP-specific details.

Loading:

model_state, optimizer_state = get_state_dict(model, optimizer)
state_dict = {
    "model": model_state,
    "optimizer": optimizer_state,
    "epoch": 0,
}
dcp.load(state_dict, checkpoint_id=checkpoint_id)
set_state_dict(
    model,
    optimizer,
    model_state_dict=state_dict["model"],
    optim_state_dict=state_dict["optimizer"],
)

dcp.load automatically reshards and redistributes parameters to match the current distributed configuration, even if it differs from the original setup.

Important considerations:

  • dcp.load must be called after the model is sharded.

  • Calling it before sharding may load the full, unsharded parameters into memory and cause out-of-memory (OOM) errors, because plain tensors carry no metadata about how they should be sharded.

  • Checkpoint (save/load) must be invoked on all ranks, unlike DDP where only rank 0 needs to save.

First run on Slurm#

Use the same job script launch_cancer_classification.sh as in the DDP example.

Submit the job:

sbatch launch_cancer_classification.sh fsdp

Run this command in the src/distributed_training directory.

Summary#

FSDP enables training of very large models by reducing memory usage through sharding. However, it introduces additional complexity in model construction, parameter management, and checkpointing.

PyTorch also provides the FullyShardedDataParallel wrapper, which abstracts much of this complexity. While less flexible than fully_shard, it is often a better choice when fine-grained control is not required.