4 changes: 3 additions & 1 deletion torchtitan/experiments/simple_fsdp/README.md
@@ -3,11 +3,13 @@
[![integration and numerics tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_simple_fsdp.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_simple_fsdp.yaml?query=branch%3Amain)
[![arXiv](https://img.shields.io/badge/arXiv-2411.00284-b31b1b.svg)](https://arxiv.org/abs/2411.00284)

💡 **Note**: SimpleFSDP's composability with Mixed Precision Training and Tensor Parallel requires updates from latest PyTorch, which can be installed (e.g., for CUDA 12.6) via
💡 **Note 1**: SimpleFSDP's composability with Mixed Precision Training and Tensor Parallel requires updates from the latest PyTorch nightly, which can be installed (e.g., for CUDA 12.6) via
```bash
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall
```

💡 **Note 2**: Some of SimpleFSDP's functionalities (e.g., `reshard_after_forward`) are implemented with torch.compile. It is recommended to always enable compile (`--compile.enable`) so that these features behave as intended.
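
For example, a run with compile enabled might look like the following (the config path and `--model.name` value here are illustrative; use the exact run commands from the section below):

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" \
./run_train.sh --model.name simple_fsdp.llama3 --compile.enable
```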

This folder includes an experimental frontend implementation for [SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile](https://arxiv.org/abs/2411.00284). SimpleFSDP is a compiler-based Fully Sharded Data Parallel (FSDP) framework, which has a simple implementation for maintenance and composability, allows full computation-communication graph tracing, and brings performance enhancement via compiler backend optimizations.

### Run SimpleFSDP Training on Llama3 & DeepSeek_v3
35 changes: 29 additions & 6 deletions torchtitan/experiments/simple_fsdp/backend.py
@@ -4,20 +4,25 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Union
from typing import Any

import torch
import torch._functorch.config as functorch_config

from .compile_utils import annotate_fsdp_all_gather

def get_compile_backend(backend_name: str) -> Union[str, callable]:

def get_compile_backend(
backend_name: str, fsdp_reshard_after_forward: bool
) -> callable:
# Return the compile backend used in SimpleFSDP training.
# Step 1: check if backend_name is one of the available torch.compile backends.
# Step 2: check if backend_name has been registered as a customized backend.
available_torch_backend = torch._dynamo.list_backends(exclude_tags=())
if backend_name in available_torch_backend:
return backend_name

if backend_name == "aot_eager_autobucketing":
if backend_name in available_torch_backend:
backend = torch._dynamo.lookup_backend(backend_name)
elif backend_name == "aot_eager_autobucketing":
# Perform auto optimization in aten fx-level and execute code in aot_eager backend
# The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
@@ -46,4 +51,22 @@ def aten_autobucketing_reordering_pass(
else:
raise AssertionError(f"Unsupported customized backend: {backend_name}")

return backend
def joint_ac_pass(
gm: torch.fx.GraphModule, example_inputs: Any
) -> torch.fx.GraphModule:
# This pass implements SimpleFSDP's fsdp_reshard_after_forward behavior.
# When fsdp_reshard_after_forward is True, it annotates SimpleFSDP all-gather nodes
# with CheckpointPolicy.MUST_RECOMPUTE.
# When fsdp_reshard_after_forward is False, it annotates SimpleFSDP all-gather nodes
# with CheckpointPolicy.MUST_SAVE.
gm = annotate_fsdp_all_gather(gm, fsdp_reshard_after_forward)
gm.recompile()
return gm

def simple_fsdp_custom_pass(*args, **kwargs):
# The AC pass has to operate on the joint graph, before the partitioner, for the AC
# annotation to take effect.
with functorch_config.patch("joint_custom_pass", joint_ac_pass):
return backend(*args, **kwargs)

return simple_fsdp_custom_pass
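
For reference, a minimal sketch of how the new two-argument `get_compile_backend` is consumed (the real call sites are in the `parallelize.py` changes below; `model` is assumed to be a module already wrapped by SimpleFSDP's `data_parallel`):

```python
import torch

from torchtitan.experiments.simple_fsdp.backend import get_compile_backend

# The returned backend wraps the chosen compile backend with a joint-graph pass
# that annotates FSDP all-gathers for recomputation (reshard_after_forward=True).
compiled_model = torch.compile(
    model,
    backend=get_compile_backend("aot_eager", fsdp_reshard_after_forward=True),
    fullgraph=True,
)
```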
94 changes: 94 additions & 0 deletions torchtitan/experiments/simple_fsdp/compile_utils.py
@@ -0,0 +1,94 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch.utils.checkpoint import CheckpointPolicy


def is_graph_input(node: torch.fx.Node) -> bool:
return node.op == "placeholder"


def is_wait_tensor(node: torch.fx.Node) -> bool:
return (
node.op == "call_function"
and node.target == torch.ops._c10d_functional.wait_tensor.default
)


def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:
return (
node.op == "call_function"
and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
)


def is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool:
"""
Returns True if the node is a wait_tensor node that is the result of an all_gather
that can be arbitrarily prefetched, i.e., if all its recursive inputs are
single-input operators that lead to a graph input.
"""
if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]):
n: torch.fx.Node = node.all_input_nodes[0]
while len(n.all_input_nodes) == 1:
if is_graph_input(n.all_input_nodes[0]):
return True
n = n.all_input_nodes[0]
return False


def annotate_fsdp_all_gather(
**Contributor:** Is the `_to_copy` for mixed precision included?

**Contributor Author:** Yes, here:

# Force-recompute slice that comes after wait
for user in node.users:
if (
user.op == "call_function"
and user.target == torch.ops.aten.slice.Tensor
):
force_recompute_node(user)
# Force-recompute potential dtype casts from all_gather
if (
ag_node.all_input_nodes[0].op == "call_function"
and ag_node.args[0].target
== torch.ops.prims.convert_element_type.default
):
force_recompute_node(ag_node.all_input_nodes[0])

gm: torch.fx.GraphModule, reshard_after_forward: bool
) -> None:
"""
Force recompute all_gather nodes from simple fsdp in the graph.
This pass should be added in torch._inductor.config.joint_custom_post_pass
"""
graph = gm.graph

def force_recompute_node(node):
if reshard_after_forward:
node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
else:
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
**Contributor** (on lines +58 to +61): Does it work with full AC? IIUC these flags are for the SAC API.

**Contributor Author** (@ruisizhang123, Nov 7, 2025):
Yes, it works, and I checked these cases (with compile enabled):

Full AC + reshard_after_forward=False: Link

Full AC + reshard_after_forward=True: Link

What do you mean by "these flags are for the SAC API"? My mental model is: you send a joint graph to compile; before graph partitioning, it scans these tags and decides whether or not to recompute those ops in the backward pass.

I think the partitioner does handle the AC flag here, but I would also be curious where the actual backward op recompute/save happens in the compiler
(maybe cc @soulitzer @bdhirsh for confirmation).
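
For context, the `CheckpointPolicy` values written into `node.meta["recompute"]` are the same enum used by the eager selective activation checkpointing (SAC) API. A minimal eager-mode sketch, independent of this PR and purely illustrative:

```python
import functools

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)


def policy_fn(ctx, op, *args, **kwargs):
    # Same MUST_SAVE / MUST_RECOMPUTE enum the compile pass writes into
    # node.meta["recompute"]: save matmul outputs, recompute everything else.
    if op == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.MUST_RECOMPUTE


def fn(x, w):
    return torch.relu(x @ w).sum()


x = torch.randn(8, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)
context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
loss = checkpoint(fn, x, w, use_reentrant=False, context_fn=context_fn)
loss.backward()  # relu output is recomputed in backward; mm output is saved
```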

# ac_graph_id is used in the partitioner to decide
# if two nodes which have AC applied come from different
# AC regions. This is needed because nodes on the boundary
# of two AC regions are marked as MUST_SAVE. In our case
# we just add a large value of ac_graph_id so that
# all nodes we tag for recomputation do indeed get recomputed
# and are not influenced by other nodes in the graph with
# nearby ac_graph_id values
node.meta["ac_graph_id"] = 1000

# Make all-gather nodes (and related nodes) recomputable, to circumvent
# https://github.com/pytorch/pytorch/issues/136433
for node in graph.nodes:
if is_wait_tensor_from_fsdp(node):
ag_node = node.args[0]
force_recompute_node(ag_node) # all_gather
force_recompute_node(node) # wait_tensor
# Force-recompute slice that comes after wait
for user in node.users:
if (
user.op == "call_function"
and user.target == torch.ops.aten.slice.Tensor
):
force_recompute_node(user)
# Force-recompute potential dtype casts from all_gather
if (
ag_node.all_input_nodes[0].op == "call_function"
and ag_node.args[0].target
== torch.ops.prims.convert_element_type.default
):
force_recompute_node(ag_node.all_input_nodes[0])

return gm
49 changes: 28 additions & 21 deletions torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
@@ -10,16 +10,18 @@

from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims

from torchtitan.distributed.activation_checkpoint import apply_ac
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
from torchtitan.models.deepseek_v3.infra.parallelize import (
apply_ac,
apply_moe_ep_tp,
apply_non_moe_tp,
)
from torchtitan.tools.logging import logger

from ..simple_fsdp import data_parallel, MixedPrecisionPolicy
from ..backend import get_compile_backend

from ..simple_fsdp import data_parallel, MixedPrecisionPolicy

# Adapted from llama4/infra/parallelize.py
def parallelize_deepseekv3(
@@ -91,20 +93,6 @@ def parallelize_deepseekv3(
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
)

match job_config.parallelism.fsdp_reshard_after_forward:
case "always":
reshard_after_forward = True
case "never":
reshard_after_forward = False
case "default":
# For PP, by default do not reshard after forward to avoid per-microbatch
# all-gathers, which can be expensive and non-overlapped
reshard_after_forward = not parallel_dims.pp_enabled
case _:
raise ValueError(
f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
)

# apply data parallel
dp_mesh: DeviceMesh | None = None
if (
@@ -155,9 +143,7 @@ def parallelize_deepseekv3(
transformer_block.moe.experts,
dp_mod_ep_mesh,
dp_mode,
ac_mode=job_config.activation_checkpoint.mode,
mp_policy=mp_policy,
reshard_after_forward=reshard_after_forward,
shard_dim=experts_shard_dim,
reduction_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
)
@@ -166,9 +152,7 @@
model,
dp_mesh,
dp_mode,
ac_mode=job_config.activation_checkpoint.mode,
mp_policy=mp_policy,
reshard_after_forward=reshard_after_forward,
)

logger.info(
@@ -178,6 +162,29 @@
if job_config.compile.enable:
torch._inductor.config.reorder_for_peak_memory = False
torch._dynamo.config.capture_scalar_outputs = True
model = torch.compile(model, backend=job_config.compile.backend, fullgraph=True)

match job_config.parallelism.fsdp_reshard_after_forward:
case "always":
fsdp_reshard_after_forward = True
case "never":
fsdp_reshard_after_forward = False
case "default":
# For PP, by default do not reshard after forward to avoid per-microbatch
# all-gathers, which can be expensive and non-overlapped
fsdp_reshard_after_forward = not parallel_dims.pp_enabled
case _:
raise ValueError(
f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
)

backend = (
getattr(job_config.compile, "model_backend_override", None)
or job_config.compile.backend
)
model = torch.compile(
model,
backend=get_compile_backend(backend, fsdp_reshard_after_forward),
fullgraph=True,
)

return model
33 changes: 16 additions & 17 deletions torchtitan/experiments/simple_fsdp/llama3/parallelize.py
@@ -112,41 +112,40 @@ def parallelize_llama(
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
)

match job_config.parallelism.fsdp_reshard_after_forward:
case "always":
reshard_after_forward = True
case "never":
reshard_after_forward = False
case "default":
# For PP, by default do not reshard after forward to avoid per-microbatch
# all-gathers, which can be expensive and non-overlapped
reshard_after_forward = not parallel_dims.pp_enabled
case _:
raise ValueError(
f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
)

model = data_parallel(
model,
parallel_dims.world_mesh[tuple(dp_mesh_dim_names)],
mode=dp_mode,
ac_mode=job_config.activation_checkpoint.mode,
mp_policy=mp_policy,
reshard_after_forward=reshard_after_forward,
)
logger.info(
"Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode
)

if job_config.compile.enable and "model" in job_config.compile.components:
torch._inductor.config.reorder_for_peak_memory = False
**Contributor:** What is this flag for? Is the default True?

**Contributor Author:** It defaults to True here: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py#L406-L407. I talked to some compiler folks and they told me it's there to get out-of-the-box perf from Inductor optimizations, but the reordering will mess up SimpleFSDP's op order.
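
As a quick sanity check of the above (a sketch; the flag lives in the linked `config.py`):

```python
import torch._inductor.config as inductor_config

# Defaults to True in recent PyTorch builds (see the linked config.py).
print(inductor_config.reorder_for_peak_memory)

# SimpleFSDP turns it off before torch.compile so Inductor's memory-driven
# reordering does not disturb the all_gather/wait ordering the custom pass expects.
inductor_config.reorder_for_peak_memory = False
```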


match job_config.parallelism.fsdp_reshard_after_forward:
case "always":
fsdp_reshard_after_forward = True
case "never":
fsdp_reshard_after_forward = False
case "default":
# For PP, by default do not reshard after forward to avoid per-microbatch
# all-gathers, which can be expensive and non-overlapped
fsdp_reshard_after_forward = not parallel_dims.pp_enabled
case _:
raise ValueError(
f"Invalid fsdp_reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
)

backend = (
getattr(job_config.compile, "model_backend_override", None)
or job_config.compile.backend
)
model = torch.compile(
model,
backend=get_compile_backend(backend),
backend=get_compile_backend(backend, fsdp_reshard_after_forward),
fullgraph=True,
)
