Workaround AC HOP mutation issue when tracing token dispatch #1984
Conversation
stack-info: PR: #1984, branch: xmfan/stack/2
    input_shape,
    permuted_indices,
    input_splits,
    output_splits,
These shouldn't be exposed to single-device model code. Plus, I don't think it will work if EP is not used.
If it's getting too hard, maybe we should use local_map / to_local to re-implement MoE.
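For reference, a minimal sketch of the local_map direction mentioned above, using the experimental torch.distributed.tensor.experimental.local_map API; the function name, indexing logic, and placements below are hypothetical illustrations, not the torchtitan token-dispatch code:

```python
# Illustrative sketch only (not the torchtitan implementation) of routing
# token dispatch through local_map so the body runs on plain local shards.
import torch
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.experimental import local_map


def _dispatch_local(tokens: torch.Tensor, permuted_indices: torch.Tensor) -> torch.Tensor:
    # Operates on local shards, so index/mutation ops are not traced
    # through DTensor dispatch.
    return tokens[permuted_indices]


# Callers keep passing DTensors; local_map unwraps them per rank and wraps
# the output with the given placements (placeholders here). Calling the
# wrapper requires an initialized device mesh and DTensor inputs.
dispatch = local_map(
    _dispatch_local,
    out_placements=(Shard(0),),
    in_placements=((Shard(0),), (Replicate(),)),
)
```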
Thank you for the fix! Do you think it would require fewer user-side changes if we reimplemented apply_ac as a graph pass?
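For context, a simplified sketch of an eager, module-wrapping apply_ac, assuming the checkpoint_wrapper utility from torch.distributed.algorithms._checkpoint.checkpoint_wrapper and a hypothetical `model.layers` container; a graph-pass version would perform the equivalent rewrite on the traced graph instead of wrapping modules:

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)


def apply_ac(model: nn.Module) -> None:
    # Wrap each block so its activations are recomputed during backward
    # instead of saved; `model.layers` as the block container is an
    # assumption for this sketch.
    for name, block in model.layers.named_children():
        model.layers.register_module(name, checkpoint_wrapper(block))


# Toy usage with a stand-in model that just exposes a `layers` container.
class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.ModuleDict({"0": nn.Linear(8, 8), "1": nn.Linear(8, 8)})

    def forward(self, x):
        for block in self.layers.values():
            x = block(x)
        return x


model = ToyModel()
apply_ac(model)
```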
    max_norm = 1.0  # grad norm clipping
    steps = 1000
-   dataset = "c4"  # supported datasets: c4_test (2K), c4 (177M)
+   dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)
minor: toml file config changed.
This change is only needed if you use compile(torch.utils.checkpoint(...)), so a graph pass wouldn't need it. But if you support both eager and graph-based AC, you will need this again.
Yes, what I meant is that if we're going for a compiler-based approach to distributed parallelism in SimpleFSDP, it would make sense to have a specialized apply_ac function that's also compiler-based (and users would not be allowed to use eager checkpoint to implement AC).
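A minimal, self-contained sketch of the compile(torch.utils.checkpoint(...)) pattern being discussed; the toy module and its in-place copy are hypothetical stand-ins for token dispatch:

```python
# Toy stand-in for the pattern under discussion: an eager
# torch.utils.checkpoint call inside a region compiled with fullgraph=True.
# Under torch.compile, checkpoint is traced as a higher-order op (HOP), and
# mutations inside its body are what the workaround is about.
import torch
from torch.utils.checkpoint import checkpoint


class ToyDispatch(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        out.copy_(x)  # in-place mutation, similar in spirit to token dispatch
        return out * 2


layer = ToyDispatch()


@torch.compile(fullgraph=True)
def run(x: torch.Tensor) -> torch.Tensor:
    # Eager AC wrapped around the layer, traced as a checkpoint HOP.
    return checkpoint(layer, x, use_reentrant=False)


out = run(torch.randn(8, requires_grad=True))
```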
FIXES #1935
Stacked PRs:
tlparse: https://fburl.com/sqxd6c0w
Workaround AC HOP mutation issue when tracing token dispatch
TORCH_COMPILE_FORCE_DISABLE_CACHES=1 HF_TOKEN=<token> HF_HUB_DISABLE_XET=1 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" with-proxy ./run_train.sh --model.name simple_fsdp.deepseek_v3

This is a problem for SimpleFSDP, where we want to fullgraph the entire model; these mutations cause graph breaks.
It is less of a problem outside SimpleFSDP, because we don't currently compile token dispatch
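One way to surface such graph breaks is torch._dynamo.explain; the toy function below is a hypothetical stand-in for the real training step:

```python
# Hypothetical, self-contained example: torch._dynamo.explain runs Dynamo
# over the function and reports graph counts and break reasons.
import torch


def toy_step(x: torch.Tensor) -> torch.Tensor:
    y = x * 2
    print("forcing a graph break")  # builtin print() triggers a graph break
    return y + 1


explanation = torch._dynamo.explain(toy_step)(torch.randn(4))
print(explanation.graph_count, explanation.graph_break_count)
for reason in explanation.break_reasons:
    print(reason)
```

Setting TORCH_LOGS="graph_breaks" on the run command above is another way to log each break and its reason.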