From 2949ef463cb44fca1ea23ad85fbefd508ea2aad7 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Thu, 4 Sep 2025 08:50:16 -0700 Subject: [PATCH 1/5] Fix FSDP advanced tutorial --- distributed.rst | 1 + index.rst | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/distributed.rst b/distributed.rst index af4d9772506..d39ba7200a7 100644 --- a/distributed.rst +++ b/distributed.rst @@ -186,6 +186,7 @@ Custom Extensions intermediate/ddp_tutorial intermediate/dist_tuto intermediate/FSDP_tutorial + intermediate/FSDP_adavnced_tutorial intermediate/TCPStore_libuv_backend intermediate/TP_tutorial intermediate/pipelining_tutorial diff --git a/index.rst b/index.rst index 4c20aeb7cef..6f183f07d18 100644 --- a/index.rst +++ b/index.rst @@ -673,6 +673,13 @@ Welcome to PyTorch Tutorials :link: intermediate/FSDP_tutorial.html :tags: Parallel-and-Distributed-Training +.. customcarditem:: + :header: Advanced Model Training with FSDP. + :card_description: Learn about more advanced features of FSDP + :image: _static/img/thumbnails/cropped/Getting-Started-with-FSDP.png + :link: intermediate/FSDP_advanced_tutorial.html + :tags: Parallel-and-Distributed-Training + .. customcarditem:: :header: Introduction to Libuv TCPStore Backend :card_description: TCPStore now uses a new server backend for faster connection and better scalability. @@ -844,4 +851,4 @@ Additional Resources :maxdepth: 1 :hidden: - prototype/prototype_index \ No newline at end of file + prototype/prototype_index From 9da11ff0fb78314157d5a1fdeb5711b38b23ec8c Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Thu, 4 Sep 2025 09:47:09 -0700 Subject: [PATCH 2/5] Update --- index.rst | 9 +- .../FSDP_advanced_tutorial.rst | 719 ------------------ redirects.py | 1 + 3 files changed, 2 insertions(+), 727 deletions(-) delete mode 100644 intermediate_source/FSDP_advanced_tutorial.rst diff --git a/index.rst b/index.rst index 6f183f07d18..4c20aeb7cef 100644 --- a/index.rst +++ b/index.rst @@ -673,13 +673,6 @@ Welcome to PyTorch Tutorials :link: intermediate/FSDP_tutorial.html :tags: Parallel-and-Distributed-Training -.. customcarditem:: - :header: Advanced Model Training with FSDP. - :card_description: Learn about more advanced features of FSDP - :image: _static/img/thumbnails/cropped/Getting-Started-with-FSDP.png - :link: intermediate/FSDP_advanced_tutorial.html - :tags: Parallel-and-Distributed-Training - .. customcarditem:: :header: Introduction to Libuv TCPStore Backend :card_description: TCPStore now uses a new server backend for faster connection and better scalability. @@ -851,4 +844,4 @@ Additional Resources :maxdepth: 1 :hidden: - prototype/prototype_index + prototype/prototype_index \ No newline at end of file diff --git a/intermediate_source/FSDP_advanced_tutorial.rst b/intermediate_source/FSDP_advanced_tutorial.rst deleted file mode 100644 index bf22e6efb50..00000000000 --- a/intermediate_source/FSDP_advanced_tutorial.rst +++ /dev/null @@ -1,719 +0,0 @@ -Advanced Model Training with Fully Sharded Data Parallel (FSDP) -=============================================================== - -**Author**: `Hamid Shojanazeri `__, `Less -Wright `__, `Rohan Varma -`__, `Yanli Zhao -`__ - -.. grid:: 2 - - .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn - :class-card: card-prerequisites - - * PyTorch's Fully Sharded Data Parallel Module: A wrapper for sharding module parameters across - data parallel workers. - - - - - .. 
grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites - :class-card: card-prerequisites - - * PyTorch 1.12 or later - * Read about the `FSDP API `__. - - -This tutorial introduces more advanced features of Fully Sharded Data Parallel -(FSDP) as part of the PyTorch 1.12 release. To get familiar with FSDP, please -refer to the `FSDP getting started tutorial -`__. - -In this tutorial, we fine-tune a HuggingFace (HF) T5 model with FSDP for text -summarization as a working example. - -The example uses Wikihow and for simplicity, we will showcase the training on a -single node, P4dn instance with 8 A100 GPUs. We now have several blog posts ( -`(link1), `__ -`(link2) `__) -and a `paper `__ on -large scale FSDP training on a multi-node cluster. - -FSDP is a production ready package with focus on ease of use, performance, and -long-term support. One of the main benefits of FSDP is reducing the memory -footprint on each GPU. This enables training of larger models with lower total -memory vs DDP, and leverages the overlap of computation and communication to -train models efficiently. -This reduced memory pressure can be leveraged to either train larger models or -increase batch size, potentially helping overall training throughput. You can -read more about PyTorch FSDP `here -`__. - - -FSDP Features in This Tutorial ------------------------------- -* Transformer Auto Wrap Policy -* Mixed Precision -* Initializing FSDP Model on Device -* Sharding Strategy -* Backward Prefetch -* Model Checkpoint Saving via Streaming to CPU - - - -Recap on How FSDP Works ------------------------ - -At a high level FDSP works as follow: - -*In the constructor* - -* Shard model parameters and each rank only keeps its own shard - -*In the forward pass* - -* Run `all_gather` to collect all shards from all ranks to recover the full - parameter for this FSDP unit and run the forward computation -* Discard the non-owned parameter shards it has just collected to free memory - -*In the backward pass* - -* Run `all_gather` to collect all shards from all ranks to recover the full - parameter in this FSDP unit and run backward computation -* Discard non-owned parameters to free memory. -* Run reduce_scatter to sync gradients - - -Fine-tuning HF T5 ------------------ -HF T5 pre-trained models are available in four different sizes, ranging from -small with 60 Million parameters to XXL with 11 Billion parameters. In this -tutorial, we demonstrate the fine-tuning of a T5 3B with FSDP for text -summarization using WikiHow dataset. The main focus of this tutorial is to -highlight different available features in FSDP that are helpful for training -large scale model above 3B parameters. Also, we cover specific features for -Transformer based models. The code for this tutorial is available in `Pytorch -examples -`__. - - -*Setup* - -1.1 Install the latest PyTorch - -.. code-block:: bash - - pip3 install torch torchvision torchaudio - -1.2 Dataset Setup - -Please create a `data` folder, download the WikiHow dataset from `wikihowAll.csv -`__ and -`wikihowSep.cs `__, -and place them in the `data` folder. We will use the wikihow dataset from -`summarization_dataset -`__. - -Next, we add the following code snippets to a Python script “T5_training.py”. - -.. note:: - The full source code for this tutorial is available in `PyTorch examples - `__. - -1.3 Import necessary packages: - -.. 
code-block:: python - - import os - import argparse - import torch - import torch.nn as nn - import torch.nn.functional as F - import torch.optim as optim - from transformers import AutoTokenizer, GPT2TokenizerFast - from transformers import T5Tokenizer, T5ForConditionalGeneration - import functools - from torch.optim.lr_scheduler import StepLR - import torch.nn.functional as F - import torch.distributed as dist - import torch.multiprocessing as mp - from torch.nn.parallel import DistributedDataParallel as DDP - from torch.utils.data.distributed import DistributedSampler - from transformers.models.t5.modeling_t5 import T5Block - - from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - checkpoint_wrapper, - CheckpointImpl, - apply_activation_checkpointing_wrapper) - - from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - MixedPrecision, - BackwardPrefetch, - ShardingStrategy, - FullStateDictConfig, - StateDictType, - ) - from torch.distributed.fsdp.wrap import ( - transformer_auto_wrap_policy, - enable_wrap, - wrap, - ) - from functools import partial - from torch.utils.data import DataLoader - from pathlib import Path - from summarization_dataset import * - from transformers.models.t5.modeling_t5 import T5Block - from typing import Type - import time - import tqdm - from datetime import datetime - -1.4 Distributed training setup. -Here we use two helper functions to initialize the processes for distributed -training, and then to clean up after training completion. In this tutorial, we -are going to use torch elastic, using `torchrun -`__ , which will set the -worker `RANK` and `WORLD_SIZE` automatically. - -.. code-block:: python - - def setup(): - # initialize the process group - dist.init_process_group("nccl") - - def cleanup(): - dist.destroy_process_group() - -2.1 Set up the HuggingFace T5 model: - -.. code-block:: python - - def setup_model(model_name): - model = T5ForConditionalGeneration.from_pretrained(model_name) - tokenizer = T5Tokenizer.from_pretrained(model_name) - return model, tokenizer - -We also, add couple of helper functions here for date and formatting memory -metrics. - -.. code-block:: python - - def get_date_of_run(): - """create date and time for file save uniqueness - example: 2022-05-07-08:31:12_PM' - """ - date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p") - print(f"--> current date and time of run = {date_of_run}") - return date_of_run - - def format_metrics_to_gb(item): - """quick function to format numbers to gigabyte and round to 4 digit precision""" - metric_num = item / g_gigabyte - metric_num = round(metric_num, ndigits=4) - return metric_num - - -2.2 Define a train function: - -.. 
code-block:: python - - def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None): - model.train() - local_rank = int(os.environ['LOCAL_RANK']) - fsdp_loss = torch.zeros(2).to(local_rank) - - if sampler: - sampler.set_epoch(epoch) - if rank==0: - inner_pbar = tqdm.tqdm( - range(len(train_loader)), colour="blue", desc="r0 Training Epoch" - ) - for batch in train_loader: - for key in batch.keys(): - batch[key] = batch[key].to(local_rank) - optimizer.zero_grad() - output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] ) - loss = output["loss"] - loss.backward() - optimizer.step() - fsdp_loss[0] += loss.item() - fsdp_loss[1] += len(batch) - if rank==0: - inner_pbar.update(1) - - dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM) - train_accuracy = fsdp_loss[0] / fsdp_loss[1] - - - if rank == 0: - inner_pbar.close() - print( - f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}" - ) - return train_accuracy - -2.3 Define a validation function: - -.. code-block:: python - - def validation(model, rank, world_size, val_loader): - model.eval() - correct = 0 - local_rank = int(os.environ['LOCAL_RANK']) - fsdp_loss = torch.zeros(3).to(local_rank) - if rank == 0: - inner_pbar = tqdm.tqdm( - range(len(val_loader)), colour="green", desc="Validation Epoch" - ) - with torch.no_grad(): - for batch in val_loader: - for key in batch.keys(): - batch[key] = batch[key].to(local_rank) - output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"]) - fsdp_loss[0] += output["loss"].item() # sum up batch loss - fsdp_loss[1] += len(batch) - - if rank==0: - inner_pbar.update(1) - - dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM) - val_loss = fsdp_loss[0] / fsdp_loss[1] - if rank == 0: - inner_pbar.close() - print(f"Validation Loss: {val_loss:.4f}") - return val_loss - - -2.4 Define a distributed train function that wraps the model in FSDP: - - -.. 
code-block:: python - - - def fsdp_main(args): - - model, tokenizer = setup_model("t5-base") - - local_rank = int(os.environ['LOCAL_RANK']) - rank = int(os.environ['RANK']) - world_size = int(os.environ['WORLD_SIZE']) - - - dataset = load_dataset('wikihow', 'all', data_dir='data/') - print(dataset.keys()) - print("Size of train dataset: ", dataset['train'].shape) - print("Size of Validation dataset: ", dataset['validation'].shape) - - - #wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False) - train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False) - val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False) - - sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True) - sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size) - - setup() - - - train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1} - test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2} - cuda_kwargs = {'num_workers': 2, - 'pin_memory': True, - 'shuffle': False} - train_kwargs.update(cuda_kwargs) - test_kwargs.update(cuda_kwargs) - - train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs) - val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs) - - t5_auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ - T5Block, - }, - ) - sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP #for Zero2 and FULL_SHARD for Zero3 - torch.cuda.set_device(local_rank) - - - #init_start_event = torch.cuda.Event(enable_timing=True) - #init_end_event = torch.cuda.Event(enable_timing=True) - - #init_start_event.record() - - bf16_ready = ( - torch.version.cuda - and torch.cuda.is_bf16_supported() - and LooseVersion(torch.version.cuda) >= "11.0" - and dist.is_nccl_available() - and nccl.version() >= (2, 10) - ) - - if bf16_ready: - mp_policy = bfSixteen - else: - mp_policy = None # defaults to fp32 - - # model is on CPU before input to FSDP - model = FSDP(model, - auto_wrap_policy=t5_auto_wrap_policy, - mixed_precision=mp_policy, - #sharding_strategy=sharding_strategy, - device_id=torch.cuda.current_device()) - - optimizer = optim.AdamW(model.parameters(), lr=args.lr) - - scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) - best_val_loss = float("inf") - curr_val_loss = float("inf") - file_save_name = "T5-model-" - - if rank == 0: - time_of_run = get_date_of_run() - dur = [] - train_acc_tracking = [] - val_acc_tracking = [] - training_start_time = time.time() - - if rank == 0 and args.track_memory: - mem_alloc_tracker = [] - mem_reserved_tracker = [] - - for epoch in range(1, args.epochs + 1): - t0 = time.time() - train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1) - if args.run_validation: - curr_val_loss = validation(model, rank, world_size, val_loader) - scheduler.step() - - if rank == 0: - - print(f"--> epoch {epoch} completed...entering save and stats zone") - - dur.append(time.time() - t0) - train_acc_tracking.append(train_accuracy.item()) - - if args.run_validation: - val_acc_tracking.append(curr_val_loss.item()) - - if args.track_memory: - mem_alloc_tracker.append( - format_metrics_to_gb(torch.cuda.memory_allocated()) - ) - mem_reserved_tracker.append( - format_metrics_to_gb(torch.cuda.memory_reserved()) - ) - print(f"completed save and stats zone...") - - if args.save_model and curr_val_loss < best_val_loss: - - # save - if rank == 0: - print(f"--> 
entering save model state") - - save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - with FSDP.state_dict_type( - model, StateDictType.FULL_STATE_DICT, save_policy - ): - cpu_state = model.state_dict() - #print(f"saving process: rank {rank} done w state_dict") - - - if rank == 0: - print(f"--> saving model ...") - currEpoch = ( - "-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt" - ) - print(f"--> attempting to save model prefix {currEpoch}") - save_name = file_save_name + "-" + time_of_run + "-" + currEpoch - print(f"--> saving as model name {save_name}") - - torch.save(cpu_state, save_name) - - if curr_val_loss < best_val_loss: - - best_val_loss = curr_val_loss - if rank==0: - print(f"-->>>> New Val Loss Record: {best_val_loss}") - - dist.barrier() - cleanup() - - -2.5 Parse the arguments and set the main function: - -.. code-block:: python - - - if __name__ == '__main__': - # Training settings - parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example') - parser.add_argument('--batch-size', type=int, default=4, metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument('--test-batch-size', type=int, default=4, metavar='N', - help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=2, metavar='N', - help='number of epochs to train (default: 3)') - parser.add_argument('--lr', type=float, default=.002, metavar='LR', - help='learning rate (default: .002)') - parser.add_argument('--gamma', type=float, default=0.7, metavar='M', - help='Learning rate step gamma (default: 0.7)') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') - parser.add_argument('--track_memory', action='store_false', default=True, - help='track the gpu memory') - parser.add_argument('--run_validation', action='store_false', default=True, - help='running the validation') - parser.add_argument('--save-model', action='store_false', default=True, - help='For Saving the current Model') - args = parser.parse_args() - - torch.manual_seed(args.seed) - - fsdp_main(args) - - -To run the the training using torchrun: - -.. code-block:: bash - - torchrun --nnodes 1 --nproc_per_node 4 T5_training.py - -.. _transformer_wrapping_policy: - -Transformer Wrapping Policy ---------------------------- - -As discussed in the `previous tutorial -`__, -auto_wrap_policy is one of the FSDP features that make it easy to automatically -shard a given model and put the model, optimizer and gradient shards into -distinct FSDP units. - -For some architectures such as Transformer encoder-decoders, some parts of the -model such as embedding table is being shared with both encoder and decoder. In -this case, we need to place the embedding table in the outer FSDP unit so that -it could be accessed from both encoder and decoder. In addition, by registering -the layer class for a transformer, the sharding plan can be made much more -communication efficient. In PyTorch 1.12, FSDP added this support and now we -have a wrapping policy for transfomers. - -It can be created as follows, where the T5Block represents the T5 transformer -layer class (holding MHSA and FFN). - - -.. 
code-block:: python - - t5_auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ - T5Block, - }, - ) - torch.cuda.set_device(local_rank) - - - model = FSDP(model, - auto_wrap_policy=t5_auto_wrap_policy) - -To see the wrapped model, you can easily print the model and visually inspect -the sharding and FSDP units as well. - - -Mixed Precision ---------------- -FSDP supports flexible mixed precision training allowing for arbitrary reduced -precision types (such as fp16 or bfloat16). Currently BFloat16 is only available -on Ampere GPUs, so you need to confirm native support before you use it. On -V100s for example, BFloat16 can still be run but because it runs non-natively, -it can result in significant slowdowns. - -To check if BFloat16 is natively supported, you can use the following : - -.. code-block:: python - - bf16_ready = ( - torch.version.cuda - and torch.cuda.is_bf16_supported() - and LooseVersion(torch.version.cuda) >= "11.0" - and dist.is_nccl_available() - and nccl.version() >= (2, 10) - ) - -One of the advantages of mixed precision in FSDP is providing granular control -over different precision levels for parameters, gradients, and buffers as -follows: - -.. code-block:: python - - fpSixteen = MixedPrecision( - param_dtype=torch.float16, - # Gradient communication precision. - reduce_dtype=torch.float16, - # Buffer precision. - buffer_dtype=torch.float16, - ) - - bfSixteen = MixedPrecision( - param_dtype=torch.bfloat16, - # Gradient communication precision. - reduce_dtype=torch.bfloat16, - # Buffer precision. - buffer_dtype=torch.bfloat16, - ) - - fp32_policy = MixedPrecision( - param_dtype=torch.float32, - # Gradient communication precision. - reduce_dtype=torch.float32, - # Buffer precision. - buffer_dtype=torch.float32, - ) - -Note that if a certain type (parameter, reduce, buffer) is not specified, they -will not be casted at all. - -This flexibility allows users fine grained control, such as only setting -gradient communication to happen in reduced precision, and all parameters / -buffer computation to be done in full precision. This is potentially useful in -cases where intra-node communication is the main bottleneck and parameters / -buffers must be in full precision to avoid accuracy issues. This can be done -with the following policy: - -.. code-block:: bash - - grad_bf16 = MixedPrecision(reduce_dtype=torch.bfloat16) - - -In 2.4 we just add the relevant mixed precision policy to the FSDP wrapper: - - -.. code-block:: python - - model = FSDP(model, - auto_wrap_policy=t5_auto_wrap_policy, - mixed_precision=bfSixteen) - -In our experiments, we have observed up to 4x speed up by using BFloat16 for -training and memory reduction of approximately 30% in some experiments that can -be used for batch size increases. - - -Intializing FSDP Model on Device --------------------------------- -In 1.12, FSDP supports a `device_id` argument meant to initialize input CPU -module on the device given by `device_id`. This is useful when the entire model -does not fit on a single GPU, but fits in a host's CPU memory. When `device_id` -is specified, FSDP will move the model to the specified device on a per-FSDP -unit basis, avoiding GPU OOM issues while initializing several times faster than -CPU-based initialization: - -.. 
code-block:: python - - torch.cuda.set_device(local_rank) - - model = FSDP(model, - auto_wrap_policy=t5_auto_wrap_policy, - mixed_precision=bfSixteen, - device_id=torch.cuda.current_device()) - - - -Sharding Strategy ------------------ -FSDP sharding strategy by default is set to fully shard the model parameters, -gradients and optimizer states get sharded across all ranks. (also termed Zero3 -sharding). In case you are interested to have the Zero2 sharding strategy, where -only optimizer states and gradients are sharded, FSDP support this feature by -passing the Sharding strategy by using "ShardingStrategy.SHARD_GRAD_OP", -instead of "ShardingStrategy.FULL_SHARD" to the FSDP initialization as follows: - -.. code-block:: python - - torch.cuda.set_device(local_rank) - - model = FSDP(model, - auto_wrap_policy=t5_auto_wrap_policy, - mixed_precision=bfSixteen, - device_id=torch.cuda.current_device(), - sharding_strategy=ShardingStrategy.SHARD_GRAD_OP # ZERO2) - -This will reduce the communication overhead in FSDP, in this case, it holds full -parameters after forward and through the backwards pass. - -This saves an all_gather during backwards so there is less communication at the -cost of a higher memory footprint. Note that full model params are freed at the -end of backwards and all_gather will happen on the next forward pass. - -Backward Prefetch ------------------ -The backward prefetch setting controls the timing of when the next FSDP unit's -parameters should be requested. By setting it to `BACKWARD_PRE`, the next -FSDP's unit params can begin to be requested and arrive sooner before the -computation of the current unit starts. This overlaps the `all_gather` -communication and gradient computation which can increase the training speed in -exchange for slightly higher memory consumption. It can be utilized in the FSDP -wrapper in 2.4 as follows: - -.. code-block:: python - - torch.cuda.set_device(local_rank) - - model = FSDP(model, - auto_wrap_policy=t5_auto_wrap_policy, - mixed_precision=bfSixteen, - device_id=torch.cuda.current_device(), - backward_prefetch = BackwardPrefetch.BACKWARD_PRE) - -`backward_prefetch` has two modes, `BACKWARD_PRE` and `BACKWARD_POST`. -`BACKWARD_POST` means that the next FSDP unit's params will not be requested -until the current FSDP unit processing is complete, thus minimizing memory -overhead. In some cases, using `BACKWARD_PRE` can increase model training speed -up to 2-10%, with even higher speed improvements noted for larger models. - -Model Checkpoint Saving, by streaming to the Rank0 CPU ------------------------------------------------------- -To save model checkpoints using FULL_STATE_DICT saving which saves model in the -same fashion as a local model, PyTorch 1.12 offers a few utilities to support -the saving of larger models. - -First, a FullStateDictConfig can be specified, allowing the state_dict to be -populated on rank 0 only and offloaded to the CPU. - -When using this configuration, FSDP will allgather model parameters, offloading -them to the CPU one by one, only on rank 0. When the state_dict is finally -saved, it will only be populated on rank 0 and contain CPU tensors. This avoids -potential OOM for models that are larger than a single GPU memory and allows -users to checkpoint models whose size is roughly the available CPU RAM on the -user's machine. - -This feature can be run as follows: - -.. 
code-block:: python - - save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - with FSDP.state_dict_type( - model, StateDictType.FULL_STATE_DICT, save_policy - ): - cpu_state = model.state_dict() - if rank == 0: - save_name = file_save_name + "-" + time_of_run + "-" + currEpoch - torch.save(cpu_state, save_name) - -Summary -------- - -In this tutorial, we have introduced many new features for FSDP available in -Pytorch 1.12 and used HF T5 as the running example. Using the proper wrapping -policy especially for transformer models, along with mixed precision and -backward prefetch should speed up your training runs. Also, features such as -initializing the model on device, and checkpoint saving via streaming to CPU -should help to avoid OOM error in dealing with large models. - -We are actively working to add new features to FSDP for the next release. If -you have feedback, feature requests, questions or are encountering issues -using FSDP, please feel free to contact us by opening an issue in the -`PyTorch Github repository `__. diff --git a/redirects.py b/redirects.py index 1ce4a41aed1..18dafc04798 100644 --- a/redirects.py +++ b/redirects.py @@ -12,6 +12,7 @@ "advanced/dynamic_quantization_tutorial.html": "../index.html", "intermediate/dynamic_quantization_bert_tutorial.html": "../index.html", "intermediate/quantized_transfer_learning_tutorial.html": "../index.html", + "intermediate/FSDP_advanced_tutorial.html": "../index.html", "prototype/graph_mode_dynamic_bert_tutorial.html": "../index.html", "prototype/flight_recorder_tutorial.html": "https://docs.pytorch.org/tutorials/unstable/flight_recorder_tutorial.html", "prototype/fx_graph_mode_ptq_dynamic.html": "../index.html", From dc98c3e9415a79e26275c628451ddf967da84c8f Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Thu, 4 Sep 2025 09:49:56 -0700 Subject: [PATCH 3/5] Update --- distributed.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/distributed.rst b/distributed.rst index d39ba7200a7..af4d9772506 100644 --- a/distributed.rst +++ b/distributed.rst @@ -186,7 +186,6 @@ Custom Extensions intermediate/ddp_tutorial intermediate/dist_tuto intermediate/FSDP_tutorial - intermediate/FSDP_adavnced_tutorial intermediate/TCPStore_libuv_backend intermediate/TP_tutorial intermediate/pipelining_tutorial From 8431a723cca4005999d424251f6fd89cd6712558 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Thu, 4 Sep 2025 10:22:29 -0700 Subject: [PATCH 4/5] Update --- .../FSDP_advanced_tutorial.rst | 719 ++++++++++++++++++ redirects.py | 2 +- 2 files changed, 720 insertions(+), 1 deletion(-) create mode 100644 intermediate_source/FSDP_advanced_tutorial.rst diff --git a/intermediate_source/FSDP_advanced_tutorial.rst b/intermediate_source/FSDP_advanced_tutorial.rst new file mode 100644 index 00000000000..bf22e6efb50 --- /dev/null +++ b/intermediate_source/FSDP_advanced_tutorial.rst @@ -0,0 +1,719 @@ +Advanced Model Training with Fully Sharded Data Parallel (FSDP) +=============================================================== + +**Author**: `Hamid Shojanazeri `__, `Less +Wright `__, `Rohan Varma +`__, `Yanli Zhao +`__ + +.. grid:: 2 + + .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn + :class-card: card-prerequisites + + * PyTorch's Fully Sharded Data Parallel Module: A wrapper for sharding module parameters across + data parallel workers. + + + + + .. 
grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites + :class-card: card-prerequisites + + * PyTorch 1.12 or later + * Read about the `FSDP API `__. + + +This tutorial introduces more advanced features of Fully Sharded Data Parallel +(FSDP) as part of the PyTorch 1.12 release. To get familiar with FSDP, please +refer to the `FSDP getting started tutorial +`__. + +In this tutorial, we fine-tune a HuggingFace (HF) T5 model with FSDP for text +summarization as a working example. + +The example uses Wikihow and for simplicity, we will showcase the training on a +single node, P4dn instance with 8 A100 GPUs. We now have several blog posts ( +`(link1), `__ +`(link2) `__) +and a `paper `__ on +large scale FSDP training on a multi-node cluster. + +FSDP is a production ready package with focus on ease of use, performance, and +long-term support. One of the main benefits of FSDP is reducing the memory +footprint on each GPU. This enables training of larger models with lower total +memory vs DDP, and leverages the overlap of computation and communication to +train models efficiently. +This reduced memory pressure can be leveraged to either train larger models or +increase batch size, potentially helping overall training throughput. You can +read more about PyTorch FSDP `here +`__. + + +FSDP Features in This Tutorial +------------------------------ +* Transformer Auto Wrap Policy +* Mixed Precision +* Initializing FSDP Model on Device +* Sharding Strategy +* Backward Prefetch +* Model Checkpoint Saving via Streaming to CPU + + + +Recap on How FSDP Works +----------------------- + +At a high level FDSP works as follow: + +*In the constructor* + +* Shard model parameters and each rank only keeps its own shard + +*In the forward pass* + +* Run `all_gather` to collect all shards from all ranks to recover the full + parameter for this FSDP unit and run the forward computation +* Discard the non-owned parameter shards it has just collected to free memory + +*In the backward pass* + +* Run `all_gather` to collect all shards from all ranks to recover the full + parameter in this FSDP unit and run backward computation +* Discard non-owned parameters to free memory. +* Run reduce_scatter to sync gradients + + +Fine-tuning HF T5 +----------------- +HF T5 pre-trained models are available in four different sizes, ranging from +small with 60 Million parameters to XXL with 11 Billion parameters. In this +tutorial, we demonstrate the fine-tuning of a T5 3B with FSDP for text +summarization using WikiHow dataset. The main focus of this tutorial is to +highlight different available features in FSDP that are helpful for training +large scale model above 3B parameters. Also, we cover specific features for +Transformer based models. The code for this tutorial is available in `Pytorch +examples +`__. + + +*Setup* + +1.1 Install the latest PyTorch + +.. code-block:: bash + + pip3 install torch torchvision torchaudio + +1.2 Dataset Setup + +Please create a `data` folder, download the WikiHow dataset from `wikihowAll.csv +`__ and +`wikihowSep.cs `__, +and place them in the `data` folder. We will use the wikihow dataset from +`summarization_dataset +`__. + +Next, we add the following code snippets to a Python script “T5_training.py”. + +.. note:: + The full source code for this tutorial is available in `PyTorch examples + `__. + +1.3 Import necessary packages: + +.. 
code-block:: python + + import os + import argparse + import torch + import torch.nn as nn + import torch.nn.functional as F + import torch.optim as optim + from transformers import AutoTokenizer, GPT2TokenizerFast + from transformers import T5Tokenizer, T5ForConditionalGeneration + import functools + from torch.optim.lr_scheduler import StepLR + import torch.nn.functional as F + import torch.distributed as dist + import torch.multiprocessing as mp + from torch.nn.parallel import DistributedDataParallel as DDP + from torch.utils.data.distributed import DistributedSampler + from transformers.models.t5.modeling_t5 import T5Block + + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, + CheckpointImpl, + apply_activation_checkpointing_wrapper) + + from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + MixedPrecision, + BackwardPrefetch, + ShardingStrategy, + FullStateDictConfig, + StateDictType, + ) + from torch.distributed.fsdp.wrap import ( + transformer_auto_wrap_policy, + enable_wrap, + wrap, + ) + from functools import partial + from torch.utils.data import DataLoader + from pathlib import Path + from summarization_dataset import * + from transformers.models.t5.modeling_t5 import T5Block + from typing import Type + import time + import tqdm + from datetime import datetime + +1.4 Distributed training setup. +Here we use two helper functions to initialize the processes for distributed +training, and then to clean up after training completion. In this tutorial, we +are going to use torch elastic, using `torchrun +`__ , which will set the +worker `RANK` and `WORLD_SIZE` automatically. + +.. code-block:: python + + def setup(): + # initialize the process group + dist.init_process_group("nccl") + + def cleanup(): + dist.destroy_process_group() + +2.1 Set up the HuggingFace T5 model: + +.. code-block:: python + + def setup_model(model_name): + model = T5ForConditionalGeneration.from_pretrained(model_name) + tokenizer = T5Tokenizer.from_pretrained(model_name) + return model, tokenizer + +We also, add couple of helper functions here for date and formatting memory +metrics. + +.. code-block:: python + + def get_date_of_run(): + """create date and time for file save uniqueness + example: 2022-05-07-08:31:12_PM' + """ + date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p") + print(f"--> current date and time of run = {date_of_run}") + return date_of_run + + def format_metrics_to_gb(item): + """quick function to format numbers to gigabyte and round to 4 digit precision""" + metric_num = item / g_gigabyte + metric_num = round(metric_num, ndigits=4) + return metric_num + + +2.2 Define a train function: + +.. 
code-block:: python + + def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None): + model.train() + local_rank = int(os.environ['LOCAL_RANK']) + fsdp_loss = torch.zeros(2).to(local_rank) + + if sampler: + sampler.set_epoch(epoch) + if rank==0: + inner_pbar = tqdm.tqdm( + range(len(train_loader)), colour="blue", desc="r0 Training Epoch" + ) + for batch in train_loader: + for key in batch.keys(): + batch[key] = batch[key].to(local_rank) + optimizer.zero_grad() + output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] ) + loss = output["loss"] + loss.backward() + optimizer.step() + fsdp_loss[0] += loss.item() + fsdp_loss[1] += len(batch) + if rank==0: + inner_pbar.update(1) + + dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM) + train_accuracy = fsdp_loss[0] / fsdp_loss[1] + + + if rank == 0: + inner_pbar.close() + print( + f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}" + ) + return train_accuracy + +2.3 Define a validation function: + +.. code-block:: python + + def validation(model, rank, world_size, val_loader): + model.eval() + correct = 0 + local_rank = int(os.environ['LOCAL_RANK']) + fsdp_loss = torch.zeros(3).to(local_rank) + if rank == 0: + inner_pbar = tqdm.tqdm( + range(len(val_loader)), colour="green", desc="Validation Epoch" + ) + with torch.no_grad(): + for batch in val_loader: + for key in batch.keys(): + batch[key] = batch[key].to(local_rank) + output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"]) + fsdp_loss[0] += output["loss"].item() # sum up batch loss + fsdp_loss[1] += len(batch) + + if rank==0: + inner_pbar.update(1) + + dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM) + val_loss = fsdp_loss[0] / fsdp_loss[1] + if rank == 0: + inner_pbar.close() + print(f"Validation Loss: {val_loss:.4f}") + return val_loss + + +2.4 Define a distributed train function that wraps the model in FSDP: + + +.. 
code-block:: python + + + def fsdp_main(args): + + model, tokenizer = setup_model("t5-base") + + local_rank = int(os.environ['LOCAL_RANK']) + rank = int(os.environ['RANK']) + world_size = int(os.environ['WORLD_SIZE']) + + + dataset = load_dataset('wikihow', 'all', data_dir='data/') + print(dataset.keys()) + print("Size of train dataset: ", dataset['train'].shape) + print("Size of Validation dataset: ", dataset['validation'].shape) + + + #wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False) + train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False) + val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False) + + sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True) + sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size) + + setup() + + + train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1} + test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2} + cuda_kwargs = {'num_workers': 2, + 'pin_memory': True, + 'shuffle': False} + train_kwargs.update(cuda_kwargs) + test_kwargs.update(cuda_kwargs) + + train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs) + val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs) + + t5_auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + T5Block, + }, + ) + sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP #for Zero2 and FULL_SHARD for Zero3 + torch.cuda.set_device(local_rank) + + + #init_start_event = torch.cuda.Event(enable_timing=True) + #init_end_event = torch.cuda.Event(enable_timing=True) + + #init_start_event.record() + + bf16_ready = ( + torch.version.cuda + and torch.cuda.is_bf16_supported() + and LooseVersion(torch.version.cuda) >= "11.0" + and dist.is_nccl_available() + and nccl.version() >= (2, 10) + ) + + if bf16_ready: + mp_policy = bfSixteen + else: + mp_policy = None # defaults to fp32 + + # model is on CPU before input to FSDP + model = FSDP(model, + auto_wrap_policy=t5_auto_wrap_policy, + mixed_precision=mp_policy, + #sharding_strategy=sharding_strategy, + device_id=torch.cuda.current_device()) + + optimizer = optim.AdamW(model.parameters(), lr=args.lr) + + scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) + best_val_loss = float("inf") + curr_val_loss = float("inf") + file_save_name = "T5-model-" + + if rank == 0: + time_of_run = get_date_of_run() + dur = [] + train_acc_tracking = [] + val_acc_tracking = [] + training_start_time = time.time() + + if rank == 0 and args.track_memory: + mem_alloc_tracker = [] + mem_reserved_tracker = [] + + for epoch in range(1, args.epochs + 1): + t0 = time.time() + train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1) + if args.run_validation: + curr_val_loss = validation(model, rank, world_size, val_loader) + scheduler.step() + + if rank == 0: + + print(f"--> epoch {epoch} completed...entering save and stats zone") + + dur.append(time.time() - t0) + train_acc_tracking.append(train_accuracy.item()) + + if args.run_validation: + val_acc_tracking.append(curr_val_loss.item()) + + if args.track_memory: + mem_alloc_tracker.append( + format_metrics_to_gb(torch.cuda.memory_allocated()) + ) + mem_reserved_tracker.append( + format_metrics_to_gb(torch.cuda.memory_reserved()) + ) + print(f"completed save and stats zone...") + + if args.save_model and curr_val_loss < best_val_loss: + + # save + if rank == 0: + print(f"--> 
entering save model state") + + save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type( + model, StateDictType.FULL_STATE_DICT, save_policy + ): + cpu_state = model.state_dict() + #print(f"saving process: rank {rank} done w state_dict") + + + if rank == 0: + print(f"--> saving model ...") + currEpoch = ( + "-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt" + ) + print(f"--> attempting to save model prefix {currEpoch}") + save_name = file_save_name + "-" + time_of_run + "-" + currEpoch + print(f"--> saving as model name {save_name}") + + torch.save(cpu_state, save_name) + + if curr_val_loss < best_val_loss: + + best_val_loss = curr_val_loss + if rank==0: + print(f"-->>>> New Val Loss Record: {best_val_loss}") + + dist.barrier() + cleanup() + + +2.5 Parse the arguments and set the main function: + +.. code-block:: python + + + if __name__ == '__main__': + # Training settings + parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example') + parser.add_argument('--batch-size', type=int, default=4, metavar='N', + help='input batch size for training (default: 64)') + parser.add_argument('--test-batch-size', type=int, default=4, metavar='N', + help='input batch size for testing (default: 1000)') + parser.add_argument('--epochs', type=int, default=2, metavar='N', + help='number of epochs to train (default: 3)') + parser.add_argument('--lr', type=float, default=.002, metavar='LR', + help='learning rate (default: .002)') + parser.add_argument('--gamma', type=float, default=0.7, metavar='M', + help='Learning rate step gamma (default: 0.7)') + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + parser.add_argument('--track_memory', action='store_false', default=True, + help='track the gpu memory') + parser.add_argument('--run_validation', action='store_false', default=True, + help='running the validation') + parser.add_argument('--save-model', action='store_false', default=True, + help='For Saving the current Model') + args = parser.parse_args() + + torch.manual_seed(args.seed) + + fsdp_main(args) + + +To run the the training using torchrun: + +.. code-block:: bash + + torchrun --nnodes 1 --nproc_per_node 4 T5_training.py + +.. _transformer_wrapping_policy: + +Transformer Wrapping Policy +--------------------------- + +As discussed in the `previous tutorial +`__, +auto_wrap_policy is one of the FSDP features that make it easy to automatically +shard a given model and put the model, optimizer and gradient shards into +distinct FSDP units. + +For some architectures such as Transformer encoder-decoders, some parts of the +model such as embedding table is being shared with both encoder and decoder. In +this case, we need to place the embedding table in the outer FSDP unit so that +it could be accessed from both encoder and decoder. In addition, by registering +the layer class for a transformer, the sharding plan can be made much more +communication efficient. In PyTorch 1.12, FSDP added this support and now we +have a wrapping policy for transfomers. + +It can be created as follows, where the T5Block represents the T5 transformer +layer class (holding MHSA and FFN). + + +.. 
code-block:: python + + t5_auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + T5Block, + }, + ) + torch.cuda.set_device(local_rank) + + + model = FSDP(model, + auto_wrap_policy=t5_auto_wrap_policy) + +To see the wrapped model, you can easily print the model and visually inspect +the sharding and FSDP units as well. + + +Mixed Precision +--------------- +FSDP supports flexible mixed precision training allowing for arbitrary reduced +precision types (such as fp16 or bfloat16). Currently BFloat16 is only available +on Ampere GPUs, so you need to confirm native support before you use it. On +V100s for example, BFloat16 can still be run but because it runs non-natively, +it can result in significant slowdowns. + +To check if BFloat16 is natively supported, you can use the following : + +.. code-block:: python + + bf16_ready = ( + torch.version.cuda + and torch.cuda.is_bf16_supported() + and LooseVersion(torch.version.cuda) >= "11.0" + and dist.is_nccl_available() + and nccl.version() >= (2, 10) + ) + +One of the advantages of mixed precision in FSDP is providing granular control +over different precision levels for parameters, gradients, and buffers as +follows: + +.. code-block:: python + + fpSixteen = MixedPrecision( + param_dtype=torch.float16, + # Gradient communication precision. + reduce_dtype=torch.float16, + # Buffer precision. + buffer_dtype=torch.float16, + ) + + bfSixteen = MixedPrecision( + param_dtype=torch.bfloat16, + # Gradient communication precision. + reduce_dtype=torch.bfloat16, + # Buffer precision. + buffer_dtype=torch.bfloat16, + ) + + fp32_policy = MixedPrecision( + param_dtype=torch.float32, + # Gradient communication precision. + reduce_dtype=torch.float32, + # Buffer precision. + buffer_dtype=torch.float32, + ) + +Note that if a certain type (parameter, reduce, buffer) is not specified, they +will not be casted at all. + +This flexibility allows users fine grained control, such as only setting +gradient communication to happen in reduced precision, and all parameters / +buffer computation to be done in full precision. This is potentially useful in +cases where intra-node communication is the main bottleneck and parameters / +buffers must be in full precision to avoid accuracy issues. This can be done +with the following policy: + +.. code-block:: bash + + grad_bf16 = MixedPrecision(reduce_dtype=torch.bfloat16) + + +In 2.4 we just add the relevant mixed precision policy to the FSDP wrapper: + + +.. code-block:: python + + model = FSDP(model, + auto_wrap_policy=t5_auto_wrap_policy, + mixed_precision=bfSixteen) + +In our experiments, we have observed up to 4x speed up by using BFloat16 for +training and memory reduction of approximately 30% in some experiments that can +be used for batch size increases. + + +Intializing FSDP Model on Device +-------------------------------- +In 1.12, FSDP supports a `device_id` argument meant to initialize input CPU +module on the device given by `device_id`. This is useful when the entire model +does not fit on a single GPU, but fits in a host's CPU memory. When `device_id` +is specified, FSDP will move the model to the specified device on a per-FSDP +unit basis, avoiding GPU OOM issues while initializing several times faster than +CPU-based initialization: + +.. 
code-block:: python + + torch.cuda.set_device(local_rank) + + model = FSDP(model, + auto_wrap_policy=t5_auto_wrap_policy, + mixed_precision=bfSixteen, + device_id=torch.cuda.current_device()) + + + +Sharding Strategy +----------------- +FSDP sharding strategy by default is set to fully shard the model parameters, +gradients and optimizer states get sharded across all ranks. (also termed Zero3 +sharding). In case you are interested to have the Zero2 sharding strategy, where +only optimizer states and gradients are sharded, FSDP support this feature by +passing the Sharding strategy by using "ShardingStrategy.SHARD_GRAD_OP", +instead of "ShardingStrategy.FULL_SHARD" to the FSDP initialization as follows: + +.. code-block:: python + + torch.cuda.set_device(local_rank) + + model = FSDP(model, + auto_wrap_policy=t5_auto_wrap_policy, + mixed_precision=bfSixteen, + device_id=torch.cuda.current_device(), + sharding_strategy=ShardingStrategy.SHARD_GRAD_OP # ZERO2) + +This will reduce the communication overhead in FSDP, in this case, it holds full +parameters after forward and through the backwards pass. + +This saves an all_gather during backwards so there is less communication at the +cost of a higher memory footprint. Note that full model params are freed at the +end of backwards and all_gather will happen on the next forward pass. + +Backward Prefetch +----------------- +The backward prefetch setting controls the timing of when the next FSDP unit's +parameters should be requested. By setting it to `BACKWARD_PRE`, the next +FSDP's unit params can begin to be requested and arrive sooner before the +computation of the current unit starts. This overlaps the `all_gather` +communication and gradient computation which can increase the training speed in +exchange for slightly higher memory consumption. It can be utilized in the FSDP +wrapper in 2.4 as follows: + +.. code-block:: python + + torch.cuda.set_device(local_rank) + + model = FSDP(model, + auto_wrap_policy=t5_auto_wrap_policy, + mixed_precision=bfSixteen, + device_id=torch.cuda.current_device(), + backward_prefetch = BackwardPrefetch.BACKWARD_PRE) + +`backward_prefetch` has two modes, `BACKWARD_PRE` and `BACKWARD_POST`. +`BACKWARD_POST` means that the next FSDP unit's params will not be requested +until the current FSDP unit processing is complete, thus minimizing memory +overhead. In some cases, using `BACKWARD_PRE` can increase model training speed +up to 2-10%, with even higher speed improvements noted for larger models. + +Model Checkpoint Saving, by streaming to the Rank0 CPU +------------------------------------------------------ +To save model checkpoints using FULL_STATE_DICT saving which saves model in the +same fashion as a local model, PyTorch 1.12 offers a few utilities to support +the saving of larger models. + +First, a FullStateDictConfig can be specified, allowing the state_dict to be +populated on rank 0 only and offloaded to the CPU. + +When using this configuration, FSDP will allgather model parameters, offloading +them to the CPU one by one, only on rank 0. When the state_dict is finally +saved, it will only be populated on rank 0 and contain CPU tensors. This avoids +potential OOM for models that are larger than a single GPU memory and allows +users to checkpoint models whose size is roughly the available CPU RAM on the +user's machine. + +This feature can be run as follows: + +.. 
code-block:: python + + save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type( + model, StateDictType.FULL_STATE_DICT, save_policy + ): + cpu_state = model.state_dict() + if rank == 0: + save_name = file_save_name + "-" + time_of_run + "-" + currEpoch + torch.save(cpu_state, save_name) + +Summary +------- + +In this tutorial, we have introduced many new features for FSDP available in +Pytorch 1.12 and used HF T5 as the running example. Using the proper wrapping +policy especially for transformer models, along with mixed precision and +backward prefetch should speed up your training runs. Also, features such as +initializing the model on device, and checkpoint saving via streaming to CPU +should help to avoid OOM error in dealing with large models. + +We are actively working to add new features to FSDP for the next release. If +you have feedback, feature requests, questions or are encountering issues +using FSDP, please feel free to contact us by opening an issue in the +`PyTorch Github repository `__. diff --git a/redirects.py b/redirects.py index 18dafc04798..4af9f3661ee 100644 --- a/redirects.py +++ b/redirects.py @@ -12,7 +12,7 @@ "advanced/dynamic_quantization_tutorial.html": "../index.html", "intermediate/dynamic_quantization_bert_tutorial.html": "../index.html", "intermediate/quantized_transfer_learning_tutorial.html": "../index.html", - "intermediate/FSDP_advanced_tutorial.html": "../index.html", + "intermediate/FSDP_adavnced_tutorial.html": "https://docs.pytorch.org/tutorials/intermediate/FSDP_advanced_tutorial.html", "prototype/graph_mode_dynamic_bert_tutorial.html": "../index.html", "prototype/flight_recorder_tutorial.html": "https://docs.pytorch.org/tutorials/unstable/flight_recorder_tutorial.html", "prototype/fx_graph_mode_ptq_dynamic.html": "../index.html", From d0fb10b9cd79daf6fa45ad0789f6161454ecff27 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Thu, 4 Sep 2025 11:34:27 -0700 Subject: [PATCH 5/5] Add a redirect for recipes page --- redirects.py | 1 + 1 file changed, 1 insertion(+) diff --git a/redirects.py b/redirects.py index 4af9f3661ee..91da3d8f70f 100644 --- a/redirects.py +++ b/redirects.py @@ -3,6 +3,7 @@ "beginner/ptcheat.html": "../index.html", "beginner/deploy_seq2seq_hybrid_frontend_tutorial.html": "../index.html", "recipes/bundled_inputs.html": "../index.html", + "recipes/recipes_index.html": "../recipes_index.html", "advanced/super_resolution_with_onnxruntime.html": "../index.html", "intermediate/torchserve_with_ipex.html": "../index.html", "intermediate/torchserve_with_ipex_2.html": "../index.html",
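
A note on the `redirects.py` hunks above: that file is only a mapping of old HTML paths to their new locations; the build tooling that consumes it is not part of this patch series. The sketch below shows one common way such a mapping can be materialized, by writing a small meta-refresh stub page for every old path. The `write_redirects` helper, the output directory, and the HTML template are illustrative assumptions, not the mechanism actually used by the tutorials build. The two entries are copied from the hunks in PATCH 4/5 and PATCH 5/5; the misspelled `FSDP_adavnced_tutorial.html` key matches the toctree entry added in PATCH 1/5 and removed in PATCH 3/5, presumably so that existing links to that URL still resolve.

.. code-block:: python

    # Illustrative sketch only: generate meta-refresh redirect pages from a
    # redirects mapping. The output layout and template are assumptions, not
    # the actual pytorch/tutorials build step.
    from pathlib import Path

    redirects = {
        # From PATCH 5/5:
        "recipes/recipes_index.html": "../recipes_index.html",
        # From PATCH 4/5 (the misspelled key mirrors the old toctree entry):
        "intermediate/FSDP_adavnced_tutorial.html":
            "https://docs.pytorch.org/tutorials/intermediate/FSDP_advanced_tutorial.html",
    }

    TEMPLATE = """<!DOCTYPE html>
    <html>
      <head>
        <meta http-equiv="refresh" content="0; url={target}">
        <link rel="canonical" href="{target}">
      </head>
      <body>
        <p>This page has moved to <a href="{target}">{target}</a>.</p>
      </body>
    </html>
    """

    def write_redirects(output_dir: str) -> None:
        """Write a meta-refresh stub for every old path in ``redirects``."""
        out = Path(output_dir)
        for old_path, target in redirects.items():
            stub = out / old_path
            # Old paths may include subdirectories (e.g. "intermediate/...").
            stub.parent.mkdir(parents=True, exist_ok=True)
            stub.write_text(TEMPLATE.format(target=target), encoding="utf-8")

    if __name__ == "__main__":
        # Assumed output location; substitute the real built-docs directory.
        write_redirects("_build/html")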