diff --git a/torchtitan/experiments/simple_fsdp/README.md b/torchtitan/experiments/simple_fsdp/README.md
index ea4fb3272f..9fff997e2c 100644
--- a/torchtitan/experiments/simple_fsdp/README.md
+++ b/torchtitan/experiments/simple_fsdp/README.md
@@ -52,14 +52,16 @@ SimpleFSDP relies on compiler backend to perform optimizations (i.e., bucketing
 1. no optimization: default torch.compile backends (e.g., "inductor", "aot_eager", "eager")
 
 2. auto optimization: perform auto-bucketing & reordering without user inputs. **Note: it is not guaranteed that users will get the most optimized training performance**
-   - "aot_eager_autobucketing": perform autobucketing at aten fx-level, and perform code execution with aot_eager backend.
-
-
-users can specify the pass (e.g., "aot_eager_autobucketing") via additional configs:
-
-```bash
---compile.model_backend_override "aot_eager_autobucketing"
-```
+   - "auto_bucketing": perform auto-bucketing at the aten fx level and execute code with the `aot_eager` backend (the `inductor` backend is also supported).
+   ```bash
+   --compile.backend "aot_eager" --compile.compiler_passes "auto_bucketing"
+   ```
+
+3. manual optimization: perform manual bucketing & reordering based on user-provided module FQNs.
+   - "transformer_block_bucketing": perform manual bucketing at the aten fx level and execute code with the `aot_eager` backend (the `inductor` backend is also supported).
+   ```bash
+   --compile.backend "aot_eager" --compile.compiler_passes "transformer_block_bucketing"
+   ```
 
 ### Citation
 
diff --git a/torchtitan/experiments/simple_fsdp/backend.py b/torchtitan/experiments/simple_fsdp/backend.py
index d51e6668c1..58755c12b6 100644
--- a/torchtitan/experiments/simple_fsdp/backend.py
+++ b/torchtitan/experiments/simple_fsdp/backend.py
@@ -9,48 +9,129 @@
 import torch
 import torch._functorch.config as functorch_config
 
+from .job_config import Compile as CompileConfig
+
 from .reshard_after_forward import annotate_fsdp_all_gather
 
 
-def get_compile_backend(
-    backend_name: str, fsdp_reshard_after_forward: bool
+def get_compile_backend_and_passes(
+    compile_config: CompileConfig,
+    fsdp_reshard_after_forward: bool,
+    fsdp_buckets: list[list[str] | str],
 ) -> callable:
-    # return the compile backends used in SimpleFSDP training
-    # Step1: check if backend_name is inside available torch.compile backends
-    # Step2: check if the backend_name has been registered as a customized backend
-    available_torch_backend = torch._dynamo.list_backends(exclude_tags=())
-
-    if backend_name in available_torch_backend:
-        backend = torch._dynamo.lookup_backend(backend_name)
-    elif backend_name == "aot_eager_autobucketing":
-        # Perform auto optimization in aten fx-level and execute code in aot_eager backend
-        # The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
-        from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
+    """
+    Look up the compile backend and apply the requested graph passes.
+    Args:
+        compile_config: compile configs used to apply torch.compile.
+        fsdp_reshard_after_forward: whether to enable reshard_after_forward in SimpleFSDP,
+            which is implemented via a customized AC graph pass.
+        fsdp_buckets: used by transformer_block_bucketing to define which modules should be bucketed together.
+    Returns:
+        the compile backend with the requested graph passes applied.
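For orientation, the sketch below shows the shape of the inputs the new `get_compile_backend_and_passes` entry point receives: a compile config carrying the `backend` and `compiler_passes` values set by the CLI flags above, and a nested FQN bucket plan (produced by the `get_fsdp_buckets` helpers added later in this patch). The concrete values and the `SimpleNamespace` stand-in for the parsed job config are illustrative only, not part of the patch.

```python
# Illustrative only: the argument shapes wired up by this patch.
# SimpleNamespace stands in for torchtitan's parsed compile config.
from types import SimpleNamespace

# --compile.backend "aot_eager" --compile.compiler_passes "auto_bucketing"
compile_config = SimpleNamespace(backend="aot_eager", compiler_passes="auto_bucketing")

# fsdp_buckets: list[list[str] | str] -- either a single module FQN or a
# group of FQNs that should be bucketed together (see get_fsdp_buckets below).
fsdp_buckets = [
    "tok_embeddings",       # one bucket for the embedding
    ["norm", "output"],     # final norm and output projection share a bucket
    "layers.0",             # one bucket per transformer block
    "layers.1",
]
```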
+ """ + backend = torch._dynamo.lookup_backend(compile_config.backend) + # Apply bucketing and overlapping pass on fwd and bwd graph separately + if compile_config.compiler_passes == "auto_bucketing": + # Perform auto optimization in aten fx-level and execute code in aot_eager/inductor backend + # The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960 from torch._inductor.config import aten_distributed_optimizations as dist_opts from torch._inductor.fx_passes.overlap_scheduling import ( schedule_overlap_bucketing, ) dist_opts.collective_bucketing = True - dist_opts.insert_overlap_deps = False torch._inductor.config.allow_buffer_reuse = False - def aten_autobucketing_reordering_pass( - gm: torch.fx.GraphModule, example_inputs: Any - ) -> torch.fx.GraphModule: - schedule_overlap_bucketing(gm) - gm.recompile() - return gm - - backend = aot_autograd_backend( - fw_compiler=aten_autobucketing_reordering_pass, - bw_compiler=aten_autobucketing_reordering_pass, - keep_inference_input_mutations=True, + if compile_config.backend == "aot_eager": + from torch._dynamo.backends.common import ( + aot_autograd as aot_autograd_backend, + ) + + def aot_eager_autobucketing_reordering_pass( + gm: torch.fx.GraphModule, example_inputs: Any + ) -> torch.fx.GraphModule: + schedule_overlap_bucketing(gm) + gm.recompile() + return gm + + dist_opts.insert_overlap_deps = False + backend = aot_autograd_backend( + fw_compiler=aot_eager_autobucketing_reordering_pass, + bw_compiler=aot_eager_autobucketing_reordering_pass, + keep_inference_input_mutations=True, + ) + elif compile_config.backend == "inductor": + + def inductor_autobucketing_reordering_pass( + gm: torch.fx.Graph, + ) -> torch.fx.GraphModule: + return schedule_overlap_bucketing(gm.owning_module) + + dist_opts.insert_overlap_deps = True + torch._inductor.config.reorder_for_peak_memory = False + torch._inductor.config.reorder_for_compute_comm_overlap = False + torch._inductor.config.post_grad_custom_post_pass = ( + inductor_autobucketing_reordering_pass + ) + else: + raise ValueError( + f"Unsupported backend {compile_config.backend} for auto_bucketing pass" + ) + + elif compile_config.compiler_passes == "transformer_block_bucketing": + # Perform manual optimization in aten fx-level and execute code in aot_eager/inductor backend + # The manualbucketing logic is here: https://github.com/pytorch/pytorch/pull/165487 + from functools import partial + + from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend + from torch._inductor.fx_passes.overlap_manual_scheduling import ( + manual_overlap_bucketing, ) + + torch._inductor.config.allow_buffer_reuse = False + manual_overlap_bucketing = partial( + manual_overlap_bucketing, + module_bucket_plans=fsdp_buckets, + ) + + if compile_config.backend == "aot_eager": + + def aot_eager_transformer_block_bucketing_reordering_pass( + gm: torch.fx.GraphModule, example_inputs: Any + ) -> torch.fx.GraphModule: + manual_overlap_bucketing(gm, insert_overlap_deps=False) + return gm + + backend = aot_autograd_backend( + fw_compiler=aot_eager_transformer_block_bucketing_reordering_pass, + bw_compiler=aot_eager_transformer_block_bucketing_reordering_pass, + keep_inference_input_mutations=True, + ) + elif compile_config.backend == "inductor": + + def inductor_transformer_block_bucketing_reordering_pass( + gm: torch.fx.Graph, + ) -> torch.fx.GraphModule: + return manual_overlap_bucketing( + gm.owning_module, insert_overlap_deps=True + ) + + torch._inductor.config.reorder_for_peak_memory = False 
+ torch._inductor.config.reorder_for_compute_comm_overlap = False + torch._inductor.config.post_grad_custom_post_pass = ( + inductor_transformer_block_bucketing_reordering_pass + ) + else: + raise ValueError( + f"Unsupported backend {compile_config.backend} for transformer_block_bucketing pass" + ) else: - raise AssertionError(f"Unsupported customized backend: {backend_name}") + raise AssertionError( + f"Unsupported customized pass: {compile_config.compiler_passes}" + ) + # Apply activation checkpointing on joint graph before partitioner def joint_ac_pass( gm: torch.fx.GraphModule, example_inputs: Any ) -> torch.fx.GraphModule: diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index 2ae1c517f3..95dd31ebfd 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -19,10 +19,35 @@ ) from torchtitan.tools.logging import logger -from ..backend import get_compile_backend +from ..backend import get_compile_backend_and_passes from ..simple_fsdp import data_parallel, MixedPrecisionPolicy + +def get_fsdp_buckets(model) -> list[list[str] | str]: + module_list = [ + model.tok_embeddings, + [model.norm, model.output], + ] + for layer_id, transformer_block in model.layers.items(): + # [TODO](ruisizhang123) add EP support for transformer block bucketing + module_list.append(transformer_block) + + def convert_modules_to_fqns(modules, module_to_fqn_mapping): + """Convert a (possibly nested) list of modules to FQN strings.""" + result = [] + for m in modules: + if isinstance(m, list): + result.append(convert_modules_to_fqns(m, module_to_fqn_mapping)) + else: + result.append(module_to_fqn_mapping.get(m, None)) + return result + + module_to_name = {m: n for n, m in model.named_modules()} + module_fqns = convert_modules_to_fqns(module_list, module_to_name) + return module_fqns + + # Adapted from llama4/infra/parallelize.py def parallelize_deepseekv3( model: nn.Module, @@ -177,13 +202,12 @@ def parallelize_deepseekv3( f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}." ) - backend = ( - getattr(job_config.compile, "model_backend_override", None) - or job_config.compile.backend + backend = get_compile_backend_and_passes( + job_config.compile, fsdp_reshard_after_forward, get_fsdp_buckets(model) ) model = torch.compile( model, - backend=get_compile_backend(backend, fsdp_reshard_after_forward), + backend=backend, fullgraph=True, ) diff --git a/torchtitan/experiments/simple_fsdp/job_config.py b/torchtitan/experiments/simple_fsdp/job_config.py index a7e7c4c22f..e79d108ba0 100644 --- a/torchtitan/experiments/simple_fsdp/job_config.py +++ b/torchtitan/experiments/simple_fsdp/job_config.py @@ -9,8 +9,11 @@ @dataclass class Compile: - model_backend_override: str | None = None - """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing""" + compiler_passes: str | None = None + """ + Bucketing and overlapping passes in simplefsdp. 
Supported passes include:
+    auto_bucketing, transformer_block_bucketing
+    """
 
 
 @dataclass
diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py
index 1d8bfc500f..1d11065f63 100644
--- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py
+++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py
@@ -14,7 +14,7 @@
 from torchtitan.models.llama3.infra.parallelize import apply_tp
 from torchtitan.tools.logging import logger
 
-from ..backend import get_compile_backend
+from ..backend import get_compile_backend_and_passes
 from ..simple_fsdp import data_parallel, MixedPrecisionPolicy
 
 
@@ -33,6 +33,29 @@
 }
 
 
+def get_fsdp_buckets(model) -> list[list[str] | str]:
+    module_list = [
+        model.tok_embeddings,
+        [model.norm, model.output],
+    ]
+    for layer_id, transformer_block in model.layers.items():
+        module_list.append(transformer_block)
+
+    def convert_modules_to_fqns(modules, module_to_fqn_mapping):
+        """Convert a (possibly nested) list of modules to FQN strings."""
+        result = []
+        for m in modules:
+            if isinstance(m, list):
+                result.append(convert_modules_to_fqns(m, module_to_fqn_mapping))
+            else:
+                result.append(module_to_fqn_mapping.get(m, None))
+        return result
+
+    module_to_name = {m: n for n, m in model.named_modules()}
+    module_fqns = convert_modules_to_fqns(module_list, module_to_name)
+    return module_fqns
+
+
 def parallelize_llama(
     model: nn.Module,
     parallel_dims: ParallelDims,
@@ -139,13 +162,12 @@
             f"Invalid fsdp_reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
         )
 
-    backend = (
-        getattr(job_config.compile, "model_backend_override", None)
-        or job_config.compile.backend
+    backend = get_compile_backend_and_passes(
+        job_config.compile, fsdp_reshard_after_forward, get_fsdp_buckets(model)
     )
     model = torch.compile(
         model,
-        backend=get_compile_backend(backend, fsdp_reshard_after_forward),
+        backend=backend,
         fullgraph=True,
     )
 
diff --git a/torchtitan/experiments/simple_fsdp/tests/integration_tests.py b/torchtitan/experiments/simple_fsdp/tests/integration_tests.py
index f18ee95528..318d0d70f3 100755
--- a/torchtitan/experiments/simple_fsdp/tests/integration_tests.py
+++ b/torchtitan/experiments/simple_fsdp/tests/integration_tests.py
@@ -35,11 +35,25 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
                 "--model.name simple_fsdp.llama3",
                 "--compile.enable",
                 "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
-                "--compile.model_backend_override aot_eager_autobucketing",
+                "--compile.backend aot_eager",
+                "--compile.compiler_passes auto_bucketing",
             ],
         ],
-        "1D+aot_eager_autobucketing",
-        "1d_aot_eager_autobucketing",
+        "1D+autobucketing",
+        "1d_autobucketing",
+    ),
+    OverrideDefinitions(
+        [
+            [
+                "--model.name simple_fsdp.llama3",
+                "--compile.enable",
+                "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
+                "--compile.backend aot_eager",
+                "--compile.compiler_passes transformer_block_bucketing",
+            ],
+        ],
+        "1D+transformer_block_bucketing",
+        "1d_transformer_block_bucketing",
     ),
     OverrideDefinitions(
         [
             [
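The `get_fsdp_buckets` helpers above (one per model) turn a nested list of modules into the FQN strings that `transformer_block_bucketing` consumes as `module_bucket_plans`. The toy example below replays that conversion on a made-up model to show the expected output format; the model and its layer layout are illustrative only.

```python
# Toy illustration of the module -> FQN conversion done by get_fsdp_buckets.
import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.tok_embeddings = nn.Embedding(16, 8)
        self.layers = nn.ModuleDict({"0": nn.Linear(8, 8), "1": nn.Linear(8, 8)})
        self.norm = nn.LayerNorm(8)
        self.output = nn.Linear(8, 16)


model = ToyModel()
# Single modules become one bucket each; nested lists become shared buckets.
module_list = [model.tok_embeddings, [model.norm, model.output], *model.layers.values()]

# Same mapping trick as in the patch: module objects -> fully qualified names.
module_to_name = {m: n for n, m in model.named_modules()}
buckets = [
    [module_to_name[x] for x in m] if isinstance(m, list) else module_to_name[m]
    for m in module_list
]
print(buckets)
# ['tok_embeddings', ['norm', 'output'], 'layers.0', 'layers.1']
```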