[SimpleFSDP] add manual bucketing pass #1881
```diff
@@ -8,11 +8,20 @@
 import torch

+from .job_config import Compile as CompileConfig
+
-def get_compile_backend(backend_name: str) -> Union[str, callable]:
+def get_compile_backend(
+    compile_config: CompileConfig, fsdp_buckets: list[list[str] | str]
+) -> Union[str, callable]:
     # return the compile backends used in SimpleFSDP training
     # Step1: check if backend_name is inside available torch.compile backends
     # Step2: check if the backend_name has been registered as a customized backend
+    backend_name = (
+        getattr(compile_config, "model_backend_override", None)
+        or compile_config.backend
+    )
+
     available_torch_backend = torch._dynamo.list_backends(exclude_tags=())
     if backend_name in available_torch_backend:
         return backend_name
```
```diff
@@ -43,6 +52,33 @@ def aten_autobucketing_reordering_pass(
             bw_compiler=aten_autobucketing_reordering_pass,
             keep_inference_input_mutations=True,
         )
+    elif backend_name == "aot_eager_blockbucketing":
```

> **Review comment:** Update the config helper message with this option?

> **Review comment:** "block" is ambiguous, maybe …
```diff
+        # Perform manual optimization at the aten fx level and execute code in the aot_eager backend
+        # The manual bucketing logic is here:
+        from functools import partial
+
+        from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
+        from torch._inductor.fx_passes.overlap_manual_scheduling import (
+            manual_overlap_bucketing,
+        )
+
+        torch._inductor.config.allow_buffer_reuse = False
```
> **Review comment:** What happens by default?

> **Review comment:** In bucketing, we shouldn't allow buffer reuse; otherwise the newly created comm copy-in/copy-out buffers will reuse previous buffers, which messes up the copied-out data values and makes the loss NaN.

> **Review comment:** Aren't we doing the passes at the fx-graph level with the aot_eager backend? Why does it have anything to do with Inductor? In fact, I have this confusion for all other …
```diff
+        manual_overlap_bucketing = partial(
+            manual_overlap_bucketing,
+            module_bucket_plans=fsdp_buckets,
+        )
+
+        def aten_blockbucketing_reordering_pass(
+            gm: torch.fx.GraphModule, example_inputs: Any
+        ) -> torch.fx.GraphModule:
+            manual_overlap_bucketing(gm)
+            return gm
+
+        backend = aot_autograd_backend(
+            fw_compiler=aten_blockbucketing_reordering_pass,
+            bw_compiler=aten_blockbucketing_reordering_pass,
+            keep_inference_input_mutations=True,
```
> **Review comment:** Side note: once @soulitzer finishes adding AC support to the default partitioner (pytorch/pytorch#166610), we'll probably want to use the default partitioner here instead of min-cut (min-cut tries to automatically recompute ops that it thinks will be free due to fusions, but without Inductor those ops won't end up being free).
```diff
+        )
     else:
         raise AssertionError(f"Unsupported customized backend: {backend_name}")
```
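The wiring pattern in this branch — bind the user's bucket plans into the pass with `functools.partial`, then hand one reordering function to both the forward and backward compilers — can be shown with plain-Python stand-ins. `FakeGraphModule` and `overlap_bucketing_pass` below are hypothetical substitutes for `torch.fx.GraphModule` and `manual_overlap_bucketing`; no torch is needed:

```python
from functools import partial


class FakeGraphModule:
    # Minimal stand-in for torch.fx.GraphModule.
    def __init__(self) -> None:
        self.applied_plans: list = []


def overlap_bucketing_pass(gm: FakeGraphModule, module_bucket_plans=None) -> FakeGraphModule:
    # Stand-in for manual_overlap_bucketing: record the plans it was
    # configured with and return the mutated graph module.
    gm.applied_plans = list(module_bucket_plans or [])
    return gm


# Bind the bucket plans once, as the PR does with fsdp_buckets.
bound_pass = partial(
    overlap_bucketing_pass,
    module_bucket_plans=[["layers.0", "layers.1"], "norm"],
)


def reordering_pass(gm: FakeGraphModule, example_inputs=None) -> FakeGraphModule:
    # Shape of aten_blockbucketing_reordering_pass: apply the bound
    # pass, then return the same graph module for further compilation.
    bound_pass(gm)
    return gm


gm = reordering_pass(FakeGraphModule())
print(gm.applied_plans)  # [['layers.0', 'layers.1'], 'norm']
```

Using the same function for `fw_compiler` and `bw_compiler` means the bucketing/reordering decisions are applied consistently to both the forward and backward graphs.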
```diff
@@ -10,7 +10,10 @@
 @dataclass
 class Compile:
     model_backend_override: str | None = None
```
> **Review comment:** So the way I think of configuring this would be: …
> It seems to me that you are merging them into …
> My point is: we will be having more and more passes, hopefully composable with each other, and we can't afford one custom backend for each combination, whose count grows exponentially. Maybe not urgent.
```diff
-    """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing"""
+    """
+    Override backend to compile in simplefsdp. Additional backend includes:
+    aot_eager_autobucketing, aot_eager_blockbucketing
+    """


 @dataclass
```
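One direction for the reviewer's composability concern above (a sketch only, not what this PR implements): represent each fx-level transform as a plain function and fold an ordered list of them into a single pass, so that new pass combinations don't each require a dedicated backend name. The toy passes below operate on a list of node names instead of a real fx graph:

```python
from typing import Any, Callable

Pass = Callable[[Any], Any]


def compose_passes(passes: list[Pass]) -> Pass:
    # Fold an ordered list of graph passes into one pass; each pass
    # takes and returns the (possibly rewritten) graph object.
    def combined(gm: Any) -> Any:
        for p in passes:
            gm = p(gm)
        return gm

    return combined


# Toy passes on a list of "node names" standing in for an fx graph.
def drop_noops(gm: list) -> list:
    return [n for n in gm if n != "noop"]


def add_bucket_markers(gm: list) -> list:
    return ["bucket_start"] + gm + ["bucket_end"]


pipeline = compose_passes([drop_noops, add_bucket_markers])
print(pipeline(["matmul", "noop", "all_gather"]))
# ['bucket_start', 'matmul', 'all_gather', 'bucket_end']
```

A single configurable pass list would then map one backend name to arbitrarily many pass combinations, instead of one name per combination.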
```diff
@@ -29,18 +29,19 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
             "1D",
             "1d",
         ),
-        OverrideDefinitions(
-            [
-                [
-                    "--model.name simple_fsdp.llama3",
-                    "--compile.enable",
-                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
-                    "--compile.model_backend_override aot_eager_autobucketing",
-                ],
-            ],
-            "1D+aot_eager_autobucketing",
-            "1d_aot_eager_autobucketing",
-        ),
+        # TODO(ruisizhang123): add back after autobucketing pass is mature
```
> **Review comment:** Shall we add a manual bucketing test? We should also add one to the loss unit test.

> **Review comment:** I have a few to-do items for reordering. I think it'd be better to add the tests after the API is stable?
```diff
+        # OverrideDefinitions(
+        #     [
+        #         [
+        #             "--model.name simple_fsdp.llama3",
+        #             "--compile.enable",
+        #             "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
+        #             "--compile.model_backend_override aot_eager_autobucketing",
+        #         ],
+        #     ],
+        #     "1D+aot_eager_autobucketing",
+        #     "1d_aot_eager_autobucketing",
+        # ),
         OverrideDefinitions(
             [
                 [
```
> **Review comment:** Why do we need `--compile.backend "aot_eager"`?