[mxfp8 MoE training] Support mxfp8 all to all in expert parallel #1765
Conversation
For activation checkpointing with mxfp8 a2a, should we change this save list to mimic what happens in bf16?
https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/llama4/infra/parallelize.py#L42
Actually I have a question:
In the save list we don't have the quantized matmul. Wouldn't that have created an unfair comparison between bf16 and quantized runs before?
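For reference, a minimal sketch of what extending a selective-AC save list could look like, assuming a bf16 save list shaped like the one in the linked parallelize.py; the exact op names (in particular the scaled-matmul op) are assumptions, not taken from this PR:

```python
import torch

# Hypothetical selective activation checkpointing save list, modeled on the
# bf16 path linked above: outputs of these ops are saved instead of being
# recomputed in the backward pass.
_save_list = {
    torch.ops.aten.mm.default,  # bf16/fp32 matmul output
    torch.ops._c10d_functional.all_to_all_single.default,  # a2a comm output
}

# To keep bf16 and quantized runs comparable, the quantized matmul op would
# be saved as well (op name is an assumption; verify against the actual
# quantized matmul used by the mxfp8 path):
_save_list.add(torch.ops.aten._scaled_mm.default)
```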
Note that this is still an experimental feature.
"""

expert_parallel_a2a_dispatch_impl: Literal["default", "mxfp8"] = "default"
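For context, a rough sketch of how a Literal-typed knob like this could sit in the parallelism config dataclass; the dataclass name and the companion combine field mirror the constructor shown later in this thread, but are assumptions about the final shape:

```python
from dataclasses import dataclass
from typing import Literal


@dataclass
class Parallelism:
    # Which all-to-all implementation expert parallel uses for token dispatch
    # and combine: "mxfp8" routes the a2a through torchao's mxfp8 kernel,
    # "default" keeps the bf16 functional collective.
    expert_parallel_a2a_dispatch_impl: Literal["default", "mxfp8"] = "default"
    expert_parallel_a2a_combine_impl: Literal["default", "mxfp8"] = "default"
```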
To be consistent with how low precision is configured elsewhere in torchtitan, let's put these under job_config.quantize.
Won't that cause a conflict with the EP a2a impl in Parallelism once you add the NVSHMEM impl? We could have the quantize a2a impl override the Parallelism a2a impl, but that may be unclear to users. What do you think?
| """ | ||
|
|
||
| def __init__( | ||
| self, a2a_dispatch_impl: str = "default", a2a_combine_impl: str = "default" |
Instead of adding configs to the constructor, a more object-oriented way would be:
- Refactor ExpertParallel to have self._all_to_all_dispatch_fn and self._all_to_all_combine_fn, both with default value all_to_all_single_autograd.
- Let MXExpertParallel inherit ExpertParallel, whose constructor sets those variables to use mxfp8 depending on the mxfp8 config. The class should sit under the quantize/mx folder.
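A minimal sketch of that shape, assuming the attribute and class names from the comment above; the real ExpertParallel in torchtitan does much more than this, and the torchao import path shown is a guess:

```python
from torch.distributed._functional_collectives import all_to_all_single_autograd


class ExpertParallel:
    def __init__(self):
        # Differentiable a2a used for token dispatch and combine; subclasses
        # may override these with quantized variants.
        self._all_to_all_dispatch_fn = all_to_all_single_autograd
        self._all_to_all_combine_fn = all_to_all_single_autograd


class MXExpertParallel(ExpertParallel):
    """Would live under the quantize/mx folder and swap in the mxfp8 a2a."""

    def __init__(self):
        super().__init__()
        # Import path is an assumption; check torchao for the actual location.
        from torchao.prototype.moe_training.kernels import to_mxfp8_a2a_dequant

        self._all_to_all_dispatch_fn = to_mxfp8_a2a_dequant
        self._all_to_all_combine_fn = to_mxfp8_a2a_dequant
```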
@tianyu-l I've been using AC=none for my benchmarks, but yes we need to update this list for mxfp8 MoE training.

Closing in favor of #1912 to avoid nasty merge conflicts during the refactor.
Summary
"default"or"mxfp8"impl"mxfp8"impl uses torchao's newto_mxfp8_a2a_dequant, which has the exact same API as functional collectiveall_to_all_single_autogradand is differentiable, so it can be used as a drop-in replacement for the default a2a impl.to_mxfp8_a2a_dequantworks as follows:Performance
Performance
Single node benchmarks with 4xB200
Llama4 16e default configs; FSDP=4, EP=4; AC=none; compile=True; seq_len=8192; local_bs=8
Reduced num layers from 48 -> 2 to avoid OOM in single node setting
Debug model config:
Additional context on design/implementation choices
Additional background on motivation
- 30% of llama4 model profiled runtime is all2all comms
- 47% avg runtime devoted to MoE comms in profiled OSS models