16 changes: 9 additions & 7 deletions torchtitan/experiments/simple_fsdp/README.md
@@ -51,13 +51,15 @@ SimpleFSDP relies on compiler backend to perform optimizations (i.e., bucketing

2. auto optimization: perform auto-bucketing & reordering without user inputs. **Note: it is not guaranteed that users will get the most optimized training performance**
- "aot_eager_autobucketing": perform autobucketing at aten fx-level, and perform code execution with aot_eager backend.


users can specify the pass (e.g., "aot_eager_autobucketing") via additional configs:

```bash
--compile.model_backend_override "aot_eager_autobucketing"
```
```bash
--compile.backend "aot_eager" --compile.model_backend_override "aot_eager_autobucketing"
Contributor:
why do we need --compile.backend "aot_eager"?

```

3. manual optimization: perform manual bucketing & reordering with user FQN inputs.
- "aot_eager_manualbucketing": perform manual bucketing at aten fx-level, and perform code execution with aot_eager backend.
```bash
--compile.backend "aot_eager" --compile.model_backend_override "aot_eager_manualbucketing"
```
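For reference, a complete training invocation combining these flags might look like the sketch below. It assumes torchtitan's `run_train.sh` entry point and a training config of your choice; the `--model.name` and `--job.custom_config_module` flags are the ones used in this PR's integration tests.

```bash
# Sketch only: SimpleFSDP Llama3 with the auto-bucketing backend override.
# run_train.sh and CONFIG_FILE are assumed from torchtitan's usual entry point.
CONFIG_FILE="<path/to/train_config.toml>" ./run_train.sh \
  --model.name simple_fsdp.llama3 \
  --compile.enable \
  --job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config \
  --compile.backend "aot_eager" \
  --compile.model_backend_override "aot_eager_autobucketing"
```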

### Citation

38 changes: 37 additions & 1 deletion torchtitan/experiments/simple_fsdp/backend.py
@@ -8,11 +8,20 @@

import torch

from .job_config import Compile as CompileConfig

def get_compile_backend(backend_name: str) -> Union[str, callable]:

def get_compile_backend(
compile_config: CompileConfig, fsdp_buckets: list[list[str] | str]
) -> Union[str, callable]:
# return the compile backends used in SimpleFSDP training
# Step1: check if backend_name is inside available torch.compile backends
# Step2: check if the backend_name has been registered as a customized backend
backend_name = (
getattr(compile_config, "model_backend_override", None)
or compile_config.backend
)

available_torch_backend = torch._dynamo.list_backends(exclude_tags=())
if backend_name in available_torch_backend:
return backend_name
@@ -43,6 +52,33 @@ def aten_autobucketing_reordering_pass(
bw_compiler=aten_autobucketing_reordering_pass,
keep_inference_input_mutations=True,
)
elif backend_name == "aot_eager_blockbucketing":
Contributor:
update the config helper message with this option?

Contributor:
"block" is ambiguous, maybe transformer_block_bucketing

# Perform manual optimization at the aten fx level and execute code with the aot_eager backend
# The manualbucketing logic is here:
from functools import partial

from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
from torch._inductor.fx_passes.overlap_manual_scheduling import (
manual_overlap_bucketing,
)

torch._inductor.config.allow_buffer_reuse = False
Member:
what happens by default?

Contributor (Author):
In bucketing we shouldn't allow buffer reuse; otherwise the newly created comm copy-in/copy-out buffers reuse the previous buffers, which corrupts the copied-out data and makes the loss NaN.

Contributor:
Aren't we doing passes on the fx graph with the aot_eager backend? Why does it have anything to do with inductor?

In fact, I have this confusion for all the other torch._inductor fields.

manual_overlap_bucketing = partial(
manual_overlap_bucketing,
module_bucket_plans=fsdp_buckets,
)

def aten_blockbucketing_reordering_pass(
gm: torch.fx.GraphModule, example_inputs: Any
) -> torch.fx.GraphModule:
manual_overlap_bucketing(gm)
return gm

backend = aot_autograd_backend(
fw_compiler=aten_blockbucketing_reordering_pass,
bw_compiler=aten_blockbucketing_reordering_pass,
keep_inference_input_mutations=True,
side note - once @soulitzer finishes adding AC support to the default partitioner (pytorch/pytorch#166610), we'll probably want to use the default partitioner here instead of min cut? (min cut tries to automatically recompute ops that it thinks will be free due to fusions, but without inductor those ops won't end up being free).
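If we do switch, a minimal sketch of that change is below. It assumes `aot_autograd` still forwards `partition_fn` to `aot_module_simplified` and that `default_partition` keeps its current import path; it reuses the fw/bw pass defined above in this file.

```python
# Sketch only: use the default partitioner instead of min-cut once AC support lands.
from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
from torch._functorch.partitioners import default_partition

backend = aot_autograd_backend(
    fw_compiler=aten_blockbucketing_reordering_pass,   # pass defined above in backend.py
    bw_compiler=aten_blockbucketing_reordering_pass,
    partition_fn=default_partition,                    # skip min-cut rematerialization
    keep_inference_input_mutations=True,
)
```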

)
else:
raise AssertionError(f"Unsupported customized backend: {backend_name}")

@@ -178,6 +178,8 @@ def parallelize_deepseekv3(
if job_config.compile.enable:
torch._inductor.config.reorder_for_peak_memory = False
torch._dynamo.config.capture_scalar_outputs = True
model = torch.compile(model, backend=job_config.compile.backend, fullgraph=True)
model = torch.compile(
model, backend=get_compile_backend(job_config.compile), fullgraph=True
)

return model
5 changes: 4 additions & 1 deletion torchtitan/experiments/simple_fsdp/job_config.py
@@ -10,7 +10,10 @@
@dataclass
class Compile:
model_backend_override: str | None = None
Contributor:
So the way I think of configuring this would be:

  1. choose backend, say aot_eager
  2. choose custom passes, say auto_bucketing / transformer_block_bucketing

It seems to me that you are merging them into backend altogether because that is the interface exposed by torch.compile. Do you think we can separate them in torchtitan? e.g.

  • get_compile_backend(job_config.compile) is still there
  • inside it, we use CompileConfig.compiler_passes or CompileConfig.aot_autograd_passes to specify the custom passes, e.g. bucketing, reshard_after_forward, etc.

My point is we will have more and more passes, hopefully composable with each other, and we can't afford one custom backend for each combination, whose number grows exponentially.

Maybe not urgent.
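A rough sketch of that separation is below; the `CompileConfig.compiler_passes` field, the pass registry, and the pass names are hypothetical and exist nowhere in this PR.

```python
# Sketch only: hypothetical composable-pass config, not part of this PR.
from dataclasses import dataclass, field
from typing import Any, Callable

import torch


@dataclass
class CompileConfig:
    backend: str = "aot_eager"
    # e.g. ["auto_bucketing"] or ["transformer_block_bucketing", "reshard_after_forward"]
    compiler_passes: list[str] = field(default_factory=list)


# Each optimization registers a graph-to-graph pass under a name.
PASS_REGISTRY: dict[str, Callable[[torch.fx.GraphModule], torch.fx.GraphModule]] = {}


def compose_passes(pass_names: list[str]):
    # Chain the selected passes into a single fw/bw compiler for aot_autograd.
    def run(gm: torch.fx.GraphModule, example_inputs: Any) -> torch.fx.GraphModule:
        for name in pass_names:
            gm = PASS_REGISTRY[name](gm)
        return gm

    return run
```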

"""Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing"""
"""
Override backend to compile in SimpleFSDP. Additional backends include:
aot_eager_autobucketing, aot_eager_blockbucketing
"""


@dataclass
30 changes: 25 additions & 5 deletions torchtitan/experiments/simple_fsdp/llama3/parallelize.py
@@ -33,6 +33,29 @@
}


def get_fsdp_buckets(model) -> list[list[str] | str]:
module_list = [
model.tok_embeddings,
[model.norm, model.output],
]
for layer_id, transformer_block in model.layers.items():
module_list.append(transformer_block)

def convert_modules_to_fqns(modules, module_to_fqn_mapping):
"""Convert a (possibly nested) list of modules to FQN strings."""
result = []
for m in modules:
if isinstance(m, list):
result.append(convert_modules_to_fqns(m, module_to_fqn_mapping))
else:
result.append(module_to_fqn_mapping.get(m, None))
return result

module_to_name = {m: n for n, m in model.named_modules()}
module_fqns = convert_modules_to_fqns(module_list, module_to_name)
return module_fqns
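
For reference, on the Llama3 model this returns a nested FQN plan roughly like the sketch below. The names are illustrative only: the actual keys come from model.named_modules(), and the layer FQNs assume layers is a ModuleDict keyed by layer id.

```python
# Illustrative output of get_fsdp_buckets for a 2-layer Llama3 (names assumed, not asserted by this PR):
example_fsdp_buckets = [
    "tok_embeddings",      # embedding in its own bucket
    ["norm", "output"],    # final norm and output projection share one bucket
    "layers.0",            # one bucket per transformer block
    "layers.1",
]
```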


def parallelize_llama(
model: nn.Module,
parallel_dims: ParallelDims,
@@ -140,13 +163,10 @@ def parallelize_llama(

if job_config.compile.enable and "model" in job_config.compile.components:
torch._inductor.config.reorder_for_peak_memory = False
backend = (
getattr(job_config.compile, "model_backend_override", None)
or job_config.compile.backend
)

model = torch.compile(
model,
backend=get_compile_backend(backend),
backend=get_compile_backend(job_config.compile, get_fsdp_buckets(model)),
fullgraph=True,
)

25 changes: 13 additions & 12 deletions torchtitan/experiments/simple_fsdp/tests/integration_tests.py
@@ -29,18 +29,19 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"1D",
"1d",
),
OverrideDefinitions(
[
[
"--model.name simple_fsdp.llama3",
"--compile.enable",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
"--compile.model_backend_override aot_eager_autobucketing",
],
],
"1D+aot_eager_autobucketing",
"1d_aot_eager_autobucketing",
),
# TODO(ruisizhang123): add back after autobucketing pass is mature
Contributor:
Shall we add a manual bucketing test?

We should also add one in the loss unit test.

Contributor (Author):
I have a few to-do items for reordering. I think it'd be better to add the tests after the API is stable?

# OverrideDefinitions(
# [
# [
# "--model.name simple_fsdp.llama3",
# "--compile.enable",
# "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
# "--compile.model_backend_override aot_eager_autobucketing",
# ],
# ],
# "1D+aot_eager_autobucketing",
# "1d_aot_eager_autobucketing",
# ),
OverrideDefinitions(
[
[