Commit 89bac2b

overallocate max output tokens per ep rank
1 parent 701f8fc

1 file changed: +9 -5 lines

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 9 additions & 5 deletions
@@ -94,10 +94,12 @@ def parallelize_llama(
         )
         maybe_enable_async_tp(job_config, world_mesh["tp"])
 
-    # Assume 2x tokens per EP rank in the worst case.
-    # TODO: explore other options
+    # Worst case = single expert receives all tokens
+    # TODO: explore using token dropping to avoid this huge overallocation
     max_tokens_per_ep_rank = (
-        job_config.training.seq_len * job_config.training.local_batch_size * 2
+        job_config.training.seq_len
+        * job_config.training.local_batch_size
+        * model.model_args.moe_args.num_experts
     )
     if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
         apply_moe_ep_tp(
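
The new bound follows directly from the updated comment: if routing can send every token to a single expert, an EP rank's output buffer must be sized for num_experts times its local token count instead of the earlier 2x guess. A minimal sketch of the arithmetic, using made-up config values (the seq_len, batch size, and expert count below are assumptions, not values from this commit):

```python
# Illustrative numbers only; the real values come from job_config and model args.
seq_len = 8192           # job_config.training.seq_len (assumed)
local_batch_size = 1     # job_config.training.local_batch_size (assumed)
num_experts = 16         # model.model_args.moe_args.num_experts (assumed)

local_tokens = seq_len * local_batch_size

# Previous heuristic: assume at most 2x the local tokens land on an EP rank.
old_bound = local_tokens * 2                          # 16_384

# New worst case: a single expert receives all tokens, so size the buffer
# for num_experts times the local token count.
max_tokens_per_ep_rank = local_tokens * num_experts   # 131_072

print(old_bound, max_tokens_per_ep_rank)
```

With these numbers the preallocation grows 8x, which is why the new TODO points at token dropping as a way to avoid the overallocation.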
@@ -446,7 +448,7 @@ def apply_moe_ep_tp(
     ep_tp_mesh: DeviceMesh | None,
     etp_enabled: bool,
     a2a_impl: str,
-    max_tokens_per_ep_rank: int,
+    max_tokens_per_ep_rank: int = -1,  # Only used for mxfp8 a2a
 ):
     for transformer_block in model.layers.values():
         if not transformer_block.moe_enabled:
@@ -504,7 +506,9 @@ def apply_moe_ep_tp(
             experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
         else:
             experts_mesh = ep_mesh
-            experts_plan = ExpertParallel(a2a_impl=a2a_impl)
+            experts_plan = ExpertParallel(
+                a2a_impl=a2a_impl, max_tokens_per_ep_rank=max_tokens_per_ep_rank
+            )
 
         parallelize_module(
             module=transformer_block.moe.experts,
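
Defaulting max_tokens_per_ep_rank to -1 keeps apply_moe_ep_tp usable for callers that never take the mxfp8 all-to-all path, since per the added comment only that path consumes the bound. A hedged sketch of how such a sentinel could be guarded; the helper name and the "mxfp8" selector string are illustrative assumptions, not code from this repository:

```python
def check_max_tokens_per_ep_rank(a2a_impl: str, max_tokens_per_ep_rank: int) -> None:
    """Illustrative guard: the mxfp8 all-to-all needs a real token upper bound."""
    if a2a_impl == "mxfp8" and max_tokens_per_ep_rank <= 0:
        raise ValueError(
            f"max_tokens_per_ep_rank must be a positive token bound when "
            f"a2a_impl={a2a_impl!r}; got {max_tokens_per_ep_rank}"
        )
```

Non-mxfp8 all-to-all implementations can simply ignore the argument, so the -1 default never has to be interpreted.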
