[simplefsdp] fix region ac in zero2-style FSDP #1970

@@ -0,0 +1,94 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch.utils.checkpoint import CheckpointPolicy


def is_graph_input(node: torch.fx.Node) -> bool:
    return node.op == "placeholder"


def is_wait_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target == torch.ops._c10d_functional.wait_tensor.default
    )


def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
    )


def is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool:
    """
    Returns True if the node is a wait_tensor node that is the result of an all_gather
    that can be arbitrarily prefetched, i.e., if all its recursive inputs are
    single-input operators that lead to a graph input.
    """
    if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]):
        n: torch.fx.Node = node.all_input_nodes[0]
        while len(n.all_input_nodes) == 1:
            if is_graph_input(n.all_input_nodes[0]):
                return True
            n = n.all_input_nodes[0]
    return False


def annotate_fsdp_all_gather(

Contributor:
Is the

Contributor (Author):
yes here:

    gm: torch.fx.GraphModule, reshard_after_forward: bool
) -> None:
    """
    Force recompute all_gather nodes from simple fsdp in the graph.
    This pass should be added in torch._inductor.config.joint_custom_post_pass.
    """
    graph = gm.graph

    def force_recompute_node(node):
        if reshard_after_forward:
            node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
        else:
            node.meta["recompute"] = CheckpointPolicy.MUST_SAVE

Comment on lines +58 to +61

Contributor:
does it work with full AC? IIUC these flags are for the SAC API.

Contributor (Author):
yes, it works, and I checked these cases (with compile enabled):
Full AC + reshard_after_forward=False: Link
Full AC + reshard_after_forward=True: Link

What do you mean by "these flags are for the SAC API"? My mental model is: you send a joint graph to compile; before graph partitioning, the compiler scans these tags and decides whether or not to recompute these ops in the backward. I think the partitioner does handle the AC flag here, but I would also be curious where the actual backward-op recompute/save happens in the compiler.
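
For readers following this thread: a minimal, self-contained sketch (not part of this PR) of how these CheckpointPolicy values are consumed by the eager SAC API, assuming a recent PyTorch that exposes torch.utils.checkpoint.create_selective_checkpoint_contexts. The pass in this PR instead writes the same enum values directly into node.meta["recompute"] for the partitioner to read.

from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)


def policy_fn(ctx, op, *args, **kwargs):
    # Save matmul outputs; let everything else be recomputed in backward.
    if op == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


def block(x, w1, w2):
    return torch.relu(x @ w1) @ w2


x = torch.randn(8, 16, requires_grad=True)
w1 = torch.randn(16, 32, requires_grad=True)
w2 = torch.randn(32, 16, requires_grad=True)

out = checkpoint(
    block,
    x,
    w1,
    w2,
    use_reentrant=False,
    context_fn=partial(create_selective_checkpoint_contexts, policy_fn),
)
out.sum().backward()

Under eager SAC the policy function is consulted per op as the forward runs; the compile path discussed here gets the same effect by tagging nodes in the joint graph before partitioning.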
        # ac_graph_id is used in the partitioner to decide
        # if two nodes which have AC applied come from different
        # AC regions. This is needed because nodes on the boundary
        # of two AC regions are marked as MUST_SAVE. In our case
        # we just add a large value of ac_graph_id so that
        # all nodes we tag for recomputation do indeed get recomputed
        # and are not influenced by other nodes in the graph with
        # nearby ac_graph_id values.
        node.meta["ac_graph_id"] = 1000

    # Make all-gather nodes (and related nodes) recomputable, to circumvent
    # https://github.com/pytorch/pytorch/issues/136433
    for node in graph.nodes:
        if is_wait_tensor_from_fsdp(node):
            ag_node = node.args[0]
            force_recompute_node(ag_node)  # all_gather
            force_recompute_node(node)  # wait_tensor
            # Force-recompute slice that comes after wait
            for user in node.users:
                if (
                    user.op == "call_function"
                    and user.target == torch.ops.aten.slice.Tensor
                ):
                    force_recompute_node(user)
            # Force-recompute potential dtype casts from all_gather
            if (
                ag_node.all_input_nodes[0].op == "call_function"
                and ag_node.args[0].target
                == torch.ops.prims.convert_element_type.default
            ):
                force_recompute_node(ag_node.all_input_nodes[0])

    return gm
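
The docstring above says this pass should be installed as torch._inductor.config.joint_custom_post_pass. Below is a minimal sketch of what that hookup could look like; the functools.partial binding and the stand-in model are assumptions here, and in this PR the real wiring lives in get_compile_backend.

import functools

import torch
import torch._inductor.config as inductor_config

# Bind the reshard_after_forward decision into the pass; Inductor then runs it
# on the joint graph before partitioning. (Assumed wiring; see get_compile_backend
# for the actual hookup in this PR.)
inductor_config.joint_custom_post_pass = functools.partial(
    annotate_fsdp_all_gather, reshard_after_forward=False  # zero2-style FSDP
)

model = torch.nn.Linear(16, 16)  # stand-in; in practice the SimpleFSDP-wrapped model
compiled = torch.compile(model, fullgraph=True)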

@@ -112,41 +112,40 @@ def parallelize_llama(
        reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
    )

-    match job_config.parallelism.fsdp_reshard_after_forward:
-        case "always":
-            reshard_after_forward = True
-        case "never":
-            reshard_after_forward = False
-        case "default":
-            # For PP, by default do not reshard after forward to avoid per-microbatch
-            # all-gathers, which can be expensive and non-overlapped
-            reshard_after_forward = not parallel_dims.pp_enabled
-        case _:
-            raise ValueError(
-                f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
-            )

    model = data_parallel(
        model,
        parallel_dims.world_mesh[tuple(dp_mesh_dim_names)],
        mode=dp_mode,
        ac_mode=job_config.activation_checkpoint.mode,
        mp_policy=mp_policy,
-        reshard_after_forward=reshard_after_forward,
    )
    logger.info(
        "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode
    )

    if job_config.compile.enable and "model" in job_config.compile.components:
        torch._inductor.config.reorder_for_peak_memory = False

Contributor:
What is this flag for? Is the default True?

Contributor (Author):
It defaults to True: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py#L406-L407. I talked to some compiler folks, and they tell me it is there to get out-of-the-box perf from Inductor optimizations, but the reordering will mess up SimpleFSDP's op order.
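
As a tiny illustration of the answer above (not part of the diff; the flag name and default are taken from the linked config):

import torch._inductor.config as inductor_config

# Upstream default is True (per the torch/_inductor/config.py link above).
# SimpleFSDP disables it so Inductor's peak-memory reordering does not
# shuffle the order of the FSDP collectives.
inductor_config.reorder_for_peak_memory = False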

+        match job_config.parallelism.fsdp_reshard_after_forward:
+            case "always":
+                fsdp_reshard_after_forward = True
+            case "never":
+                fsdp_reshard_after_forward = False
+            case "default":
+                # For PP, by default do not reshard after forward to avoid per-microbatch
+                # all-gathers, which can be expensive and non-overlapped
+                fsdp_reshard_after_forward = not parallel_dims.pp_enabled
+            case _:
+                raise ValueError(
+                    f"Invalid fsdp_reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
+                )

        backend = (
            getattr(job_config.compile, "model_backend_override", None)
            or job_config.compile.backend
        )
        model = torch.compile(
            model,
-            backend=get_compile_backend(backend),
+            backend=get_compile_backend(backend, fsdp_reshard_after_forward),
            fullgraph=True,
        )