@@ -94,10 +94,12 @@ def parallelize_llama(
         )
         maybe_enable_async_tp(job_config, world_mesh["tp"])

-    # Assume 2x tokens per EP rank in the worst case.
-    # TODO: explore other options
+    # Worst case = single expert receives all tokens
+    # TODO: explore using token dropping to avoid this huge overallocation
     max_tokens_per_ep_rank = (
-        job_config.training.seq_len * job_config.training.local_batch_size * 2
+        job_config.training.seq_len
+        * job_config.training.local_batch_size
+        * model.model_args.moe_args.num_experts
     )
     if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
         apply_moe_ep_tp(
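For intuition, here is a rough back-of-the-envelope sketch of the sizing change in the hunk above. The concrete numbers are illustrative, not taken from the source; real runs read them from `job_config` and the model args.

```python
# Illustrative values only.
seq_len = 8192
local_batch_size = 2
num_experts = 16

# Old bound: assume at most 2x the local tokens end up on any single EP rank.
old_max_tokens_per_ep_rank = seq_len * local_batch_size * 2            # 32,768

# New bound: sized so even the degenerate routing where a single expert
# receives every token still fits, at the cost of a much larger buffer.
new_max_tokens_per_ep_rank = seq_len * local_batch_size * num_experts  # 262,144
```

As the TODO notes, token dropping could cap this bound instead of over-allocating for the absolute worst case.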
@@ -446,7 +448,7 @@ def apply_moe_ep_tp(
     ep_tp_mesh: DeviceMesh | None,
     etp_enabled: bool,
     a2a_impl: str,
-    max_tokens_per_ep_rank: int,
+    max_tokens_per_ep_rank: int = -1,  # Only used for mxfp8 a2a
 ):
     for transformer_block in model.layers.values():
         if not transformer_block.moe_enabled:
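The new `-1` default keeps callers that do not use the mxfp8 all-to-all working unchanged. A minimal, hypothetical sketch (not torchtitan's actual `ExpertParallel` implementation) of why that sentinel is safe when only the mxfp8 a2a consumes the bound:

```python
# Hypothetical sketch -- the class name and checks are illustrative only.
class ExpertParallelSketch:
    def __init__(self, a2a_impl: str, max_tokens_per_ep_rank: int = -1):
        # Only the mxfp8 all-to-all pre-allocates a statically sized output
        # buffer, so every other impl can ignore the -1 sentinel.
        if a2a_impl == "mxfp8" and max_tokens_per_ep_rank <= 0:
            raise ValueError(
                "mxfp8 a2a needs a positive max_tokens_per_ep_rank to size "
                "its output buffer"
            )
        self.a2a_impl = a2a_impl
        self.max_tokens_per_ep_rank = max_tokens_per_ep_rank
```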
@@ -504,7 +506,9 @@ def apply_moe_ep_tp(
             experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
         else:
             experts_mesh = ep_mesh
-            experts_plan = ExpertParallel(a2a_impl=a2a_impl)
+            experts_plan = ExpertParallel(
+                a2a_impl=a2a_impl, max_tokens_per_ep_rank=max_tokens_per_ep_rank
+            )

         parallelize_module(
             module=transformer_block.moe.experts,