diff --git a/intermediate_source/FSDP_tutorial.rst b/intermediate_source/FSDP_tutorial.rst index b0cdc19be2..46556232ab 100644 --- a/intermediate_source/FSDP_tutorial.rst +++ b/intermediate_source/FSDP_tutorial.rst @@ -73,7 +73,7 @@ Model Initialization # ) We can inspect the nested wrapping with ``print(model)``. ``FSDPTransformer`` is a joint class of `Transformer `_ and `FSDPModule -<​https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule>`_. The same thing happens to `FSDPTransformerBlock `_. All FSDP2 public APIs are exposed through ``FSDPModule``. For example, users can call ``model.unshard()`` to manually control all-gather schedules. See "explicit prefetching" below for details. +https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule>`_. The same thing happens to `FSDPTransformerBlock `_. All FSDP2 public APIs are exposed through ``FSDPModule``. For example, users can call ``model.unshard()`` to manually control all-gather schedules. See "explicit prefetching" below for details. **model.parameters() as DTensor**: ``fully_shard`` shards parameters across ranks, and convert ``model.parameters()`` from plain ``torch.Tensor`` to DTensor to represent sharded parameters. FSDP2 shards on dim-0 by default so DTensor placements are `Shard(dim=0)`. Say we have N ranks and a parameter with N rows before sharding. After sharding, each rank will have 1 row of the parameter. We can inspect sharded parameters using ``param.to_local()``.