
Conversation


@garrett361 commented on Oct 31, 2025:

PR for removing scatter_add in the MoE implementation. scatter_add is problematic: it relies on atomic adds for correctness, which makes it non-deterministic.

The current draft PR replaces the scatter_add_ with batched matmuls and adds multiple correctness tests. By some measures, this makes the implementation ~3x more accurate, and ~3% faster per triton.testing.do_bench, presumably due to avoiding atomic adds. For as-yet-unknown reasons, though, this doesn't seem to translate into end-to-end speedups. These numbers are for a DSv3 16B-like config.
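
For illustration, here is a toy sketch of the difference between the two combine styles (made-up names and shapes, not the actual implementation):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
num_routed, dim = 8, 4  # illustrative sizes
routed_output = torch.randn(num_routed, dim, device=device)
# Sorted-row -> output-slot mapping; a permutation here, so every slot is written once.
token_idx = torch.randperm(num_routed, device=device)

# scatter_add_-style combine: accumulation via atomic adds on GPU, which is
# non-deterministic whenever several source rows can land on the same slot.
out_scatter = torch.zeros(num_routed, dim, device=device)
out_scatter.scatter_add_(0, token_idx.unsqueeze(1).expand(-1, dim), routed_output)

# Index-assignment style: each slot is written exactly once, no atomics involved,
# so the result is bitwise reproducible run to run.
out_assign = torch.empty(num_routed, dim, device=device)
out_assign[token_idx] = routed_output

torch.testing.assert_close(out_scatter, out_assign)

In the PR itself, the per-token reduction over the top_k expert outputs is then handled by the batched matmuls against the routing scores, rather than being folded into the scatter.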

The current form is not suitable for merging (hence the draft status). If this is a direction we want to go, the plan before merging is:

  • Remove the helper test_moe.py, README_DEV.md, and Makefile files.
  • Delete the MoEOld class.

@meta-cla bot added the CLA Signed label on Oct 31, 2025.

print(f"Speedup: {moe_old_time_ms/moe_time_ms=}")


if __name__ == "__main__":

@garrett361 (Contributor Author) commented:

On an H100 I'm getting these results for this script:

❯ python tests/unit_tests/test_moe.py

ACCURACY VS FFN: 10 iterations

mean_moe_old_rel_err.mean()=tensor(0.0097), mean_moe_old_rel_err.std()=tensor(4.2804e-05)
mean_moe_rel_err.mean()=tensor(0.0025), mean_moe_rel_err.std()=tensor(1.4838e-05)
mean_moe_old_rel_err.mean()/mean_moe_rel_err.mean()=tensor(3.8742)

TRITON BENCH: perf_seqlen=4096 perf_bsz=4 warmups=100 repeats=1000

moe_old_time_ms=19.720215759277345
moe_time_ms=19.050707193521355
Speedup: moe_old_time_ms/moe_time_ms=1.0351435019684556

MoEOld AND MoE CLOSE: score_before_experts=True

MoEOld AND MoE CLOSE: score_before_experts=False
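
For reference, the timings come from triton.testing.do_bench, roughly along the lines of the sketch below (with stand-in modules rather than the actual MoE layers, and without the warmup/repeat settings shown above):

import torch
import triton.testing

# Stand-ins for the two MoE implementations and the benchmark batch; in the real
# script these are the MoEOld / MoE layers and a (bsz, seqlen, dim) bf16 input.
moe_old = torch.nn.Linear(256, 256).cuda()
moe = torch.nn.Linear(256, 256).cuda()
inputs = torch.randn(4, 4096, 256, device="cuda")

moe_old_time_ms = triton.testing.do_bench(lambda: moe_old(inputs))
moe_time_ms = triton.testing.do_bench(lambda: moe(inputs))
print(f"Speedup: {moe_old_time_ms/moe_time_ms=}")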

@tianyu-l (Contributor) replied:

Wait, where is this error coming from for the new implementation, if it's run-to-run deterministic? (I'm assuming the feed-forward one is also deterministic.)

@garrett361 (Contributor Author) replied:

By "err" do you mean the standard deviation across runs in the ACCURACY VS FFN: 10 iterations section? If so, that comes from running the ten iterations with a different random seed each time; I'm just collecting better statistics than a single run would give, not directly testing run-to-run determinism there.

On each run I computed the relative error between an FFN layer and the should-be-equivalent MoE layer and appended the result to the mean_moe_rel_err tensor, e.g. roughly as in the sketch below. The printout above shows the mean and std across those ten runs.
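
In rough pseudocode (a sketch of the measurement logic only; the real test uses the actual FFN and MoE layers rather than the stand-ins here):

import torch

bsz, seqlen, dim = 2, 16, 32
# Placeholders standing in for the FFN layer and the should-be-equivalent MoE layer.
ffn = torch.nn.Linear(dim, dim)
moe = torch.nn.Linear(dim, dim)

rel_errs = []
with torch.no_grad():
    for seed in range(10):
        torch.manual_seed(seed)
        inputs = torch.randn(bsz, seqlen, dim)
        ffn_out, moe_out = ffn(inputs), moe(inputs)
        rel_errs.append((moe_out - ffn_out).norm() / ffn_out.norm())
mean_moe_rel_err = torch.stack(rel_errs)
print(f"{mean_moe_rel_err.mean()=}, {mean_moe_rel_err.std()=}")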

@garrett361 (Contributor Author) replied:

I added a determinism test locally:

def test_determinism(self):
    torch.manual_seed(42)
    moe_old, moe = self._get_moe_old_and_moe_layers(score_before_experts=False)
    inputs = torch.randn(
        self.bsz,
        self.seqlen,
        self.dim,
        device=self.device,
        dtype=torch.bfloat16,
    )

    # Run both implementations repeatedly on the same input: a deterministic
    # implementation should produce bitwise-identical outputs every time,
    # i.e. zero std across the stacked runs.
    out_moe_old = []
    out_moe = []
    with torch.no_grad():
        for _ in range(100):
            out_moe_old.append(moe_old(inputs))
            out_moe.append(moe(inputs))
    out_moe_old = torch.stack(out_moe_old, dim=0)
    out_moe = torch.stack(out_moe, dim=0)
    print(f"{out_moe_old.std(dim=0).mean()=}")
    print(f"{out_moe.std(dim=0).mean()=}")

The printout shows the new impl is deterministic, while the old is not:

⬢ [podman] ❯ pytest -rsfP tests/llama3_moe/test_moe.py -k determinism
===================================================================================================== test session starts =====================================================================================================
platform linux -- Python 3.11.9, pytest-8.4.2, pluggy-1.6.0
rootdir: /app/torchtitan
configfile: pyproject.toml
plugins: cov-7.0.0, dtest-0.0.0, anyio-4.11.0, typeguard-4.4.4
collected 4 items / 3 deselected / 1 selected

tests/llama3_moe/test_moe.py .                                                                                                                                                                                          [100%]

=========================================================================================================== PASSES ============================================================================================================
_________________________________________________________________________________________________ TestModel.test_determinism __________________________________________________________________________________________________
---------------------------------------------------------------------------------------------------- Captured stdout call -----------------------------------------------------------------------------------------------------
out_moe_old.std(dim=0).mean()=tensor(0.0659, device='cuda:0', dtype=torch.bfloat16)
out_moe.std(dim=0).mean()=tensor(0., device='cuda:0', dtype=torch.bfloat16)
=============================================================================================== 1 passed, 3 deselected in 8.99s ===============================================================================================

CC @rakkit, who I recall discussed some MoE non-determinism findings in another issue/PR.

@garrett361 (Contributor Author) commented:

CC @tianyu-l @rakkit, let me know your thoughts on this?

@garrett361 changed the title from "remove scatter_add_ in MoE implementation" to "remove scatter_add in MoE implementation" on Oct 31, 2025.

@tianyu-l (Contributor) commented:

@garrett361 Actually I was thinking about doing this, too. Plan sounds good to me. Thanks a lot!

@garrett361 marked this pull request as ready for review on October 31, 2025 at 19:45.

@garrett361 (Contributor Author) commented:

Cleaned it up and removed the testing code, @tianyu-l, FYI.

# and outputs in top_scores and routed_input, respectively.
top_scores = top_scores.flatten()
top_scores[token_indices_experts_sorted] = top_scores_experts_sorted
routed_input[token_indices_experts_sorted] = routed_output

@tianyu-l (Contributor) commented:

> Also save some allocations: store unsorted scores and outputs in top_scores and routed_input, respectively.

Hmm, please educate me more here: from the Python level this might save memory, but wouldn't the routed_input activations be saved by the autograd engine anyway?

Are you observing meaningful savings? If not, I'd prefer we separate that concern and focus on run-to-run determinism in this PR.

@garrett361 (Contributor Author) replied:

The comment is probably overly terse, and maybe I'm also just wrong?

The comment is in reference to an alternative implementation where we do something like:

top_scores = top_scores.flatten()
top_scores[token_indices_experts_sorted] = top_scores_experts_sorted
# new alloc from empty_like, avoided in the current code
routed_output_sorted = torch.empty_like(routed_output)
routed_output_sorted[token_indices_experts_sorted] = routed_output

or

top_scores = top_scores.flatten()
top_scores[token_indices_experts_sorted] = top_scores_experts_sorted
# new alloc from clone, avoided in the current code
routed_output[token_indices_experts_sorted] = routed_output.clone()

The clone is needed in the second case, because otherwise routed_output[token_indices_experts_sorted] = routed_output leads to

RuntimeError: unsupported operation: some elements of the input tensor and the written-to tensor refer to a single memory location. Please clone() the tensor before performing the operation.

(That no-clone version is what I initially tried.)
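
A standalone toy repro of that error (independent of the MoE code):

import torch

x = torch.arange(4.0)
idx = torch.tensor([1, 0, 3, 2])

# x[idx] = x  # raises the RuntimeError quoted above: source and destination overlap
x[idx] = x.clone()  # fine: the source is materialized into fresh memory first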

So the point here is about avoiding a clone or empty_like or other similar call that would cause a new allocation.

For training, both routed_input and routed_output are of course still in the computational graph; this doesn't avoid that.

LMK if I'm wrong in any of the above, and if you want code or comment changes or any profiling of the different options here.

top_scores = top_scores.flatten()
top_scores[token_indices_experts_sorted] = top_scores_experts_sorted
routed_input[token_indices_experts_sorted] = routed_output
routed_input = routed_input.reshape(bs * slen, -1, dim)

@tianyu-l (Contributor) commented:

I'm worried that when EP > 1 and ETP = 1, routed_input no longer has length bs * slen * top_k, given the code here: https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/expert_parallel.py#L254

Assuming an even split, which is the only split supported now, maybe we can do reshape(-1, top_k, dim)? But in the long term, our goal is to make this part DTensor rather than a plain tensor, so the model code doesn't need to worry too much about parallelism.
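
To illustrate the concern with toy, made-up numbers (not the actual EP code path):

import torch

bs, slen, top_k, dim = 2, 4, 2, 8
full = torch.randn(bs * slen * top_k, dim)

# With EP > 1 each rank may only hold a slice of the routed tokens, e.g. half of them:
local = full[: full.shape[0] // 2]

# Assumes the full, unsplit length: here it silently groups 1 row per token
# instead of top_k, and errors outright for other split factors.
bad = local.reshape(bs * slen, -1, dim)

# Groups top_k rows per token for any even split along dim 0.
good = local.reshape(-1, top_k, dim)
print(bad.shape, good.shape)  # torch.Size([8, 1, 8]) torch.Size([4, 2, 8])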

@garrett361 (Contributor Author) replied:

I wasn't aware of this code path at all (and can't say I fully understand it from a quick glance). But in any case, if we're ever doing leading-dim splitting inside the MoE layer, then yes, it's definitely preferable to reshape with an explicit top_k rather than with bs and slen factors.

I'm a little confused about the shapes in this path, though. It seems like:

  • All of the reorderer's outputs are split on dim0.
  • That causes routed_input to be split on dim0.
  • And also out_experts to be split on dim0?

If the last one is true, then I think I also have a problem in the case where there is a shared expert, since it seems like (out + out_experts).reshape(bs, slen, dim) will fail on shapes. Is that right? I see that this step wouldn't fail upstream, due to the scatter_add, but I don't yet understand why it is correct there.

print(f"Speedup: {moe_old_time_ms/moe_time_ms=}")


if __name__ == "__main__":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait where is this error coming from for the new implementation, if it's run-to-run deterministic (I'm assuming the feed forward one is also deterministic)?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
