diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py
index e9986b9974..3361ef2abc 100644
--- a/torchtitan/distributed/expert_parallel.py
+++ b/torchtitan/distributed/expert_parallel.py
@@ -67,10 +67,6 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
 class ExpertParallel(ParallelStyle):
     def __init__(self):
         super().__init__()
-        self.input_splits = None
-        self.output_splits = None
-        self.input_shape = None
-        self.permuted_indices = None

     # performing all-to-all dispatch on the input
     def _token_dispatch(self, mod, inputs, device_mesh):
@@ -103,14 +99,14 @@ def _token_dispatch(self, mod, inputs, device_mesh):
             .sum(dim=1)
             .to(torch.device("cpu"), non_blocking=False)
         )
-        self.input_splits = input_splits.tolist()
-        self.output_splits = output_splits.tolist()
+        input_splits = input_splits.tolist()
+        output_splits = output_splits.tolist()

         # perform all-to-all
         routed_input = all_to_all_single_autograd(
             routed_input,
-            self.output_splits,
-            self.input_splits,
+            output_splits,
+            input_splits,
             device_mesh.get_group(),
         )

@@ -127,15 +123,22 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         #    of GroupedExperts, as it does not need padding.
         (
-            self.input_shape,
+            input_shape,
             routed_input,
-            self.permuted_indices,
+            permuted_indices,
             num_tokens_per_expert_group,
         ) = _permute(
             routed_input, num_tokens_per_expert_group, ep_degree, num_local_experts
         )

-        return routed_input, num_tokens_per_expert_group
+        return (
+            routed_input,
+            num_tokens_per_expert_group,
+            input_shape,
+            permuted_indices,
+            input_splits,
+            output_splits,
+        )

     @staticmethod
     def _partition_fn(name, mod, device_mesh):
@@ -145,15 +148,20 @@ def _partition_fn(name, mod, device_mesh):
             mod.register_parameter(name, dist_param)

     # performing all-to-all combine on the output
-    def _token_combine(self, mod, routed_output, device_mesh):
-        routed_output = _unpermute(
-            routed_output, self.input_shape, self.permuted_indices
-        )
+    def _token_combine(self, mod, mod_outputs, device_mesh):
+        (
+            routed_output,
+            input_shape,
+            permuted_indices,
+            input_splits,
+            output_splits,
+        ) = mod_outputs
+        routed_output = _unpermute(routed_output, input_shape, permuted_indices)

         routed_output = all_to_all_single_autograd(
             routed_output,
-            self.input_splits,
-            self.output_splits,
+            input_splits,
+            output_splits,
             device_mesh.get_group(),
         )
         return routed_output
@@ -204,9 +212,9 @@ def _partition_fn_2d(self, name, mod, ep_tp_mesh):
             nn.Parameter(distribute_tensor(mod.w3, ep_tp_mesh, [Shard(0), Shard(1)])),
         )  # Column-wise sharding

-    def _token_combine(self, mod, routed_output, device_mesh):
+    def _token_combine(self, mod, mod_outputs, device_mesh):
         # token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
-        return super()._token_combine(mod, routed_output, device_mesh["ep"])
+        return super()._token_combine(mod, mod_outputs, device_mesh["ep"])

     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         return distribute_module(
diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
index 00ec53310e..a002b732ed 100644
--- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
+++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
@@ -39,7 +39,7 @@ local_batch_size = 4
 seq_len = 4096
 max_norm = 1.0  # grad norm clipping
 steps = 1000
-dataset = "c4"  # supported datasets: c4_test (2K), c4 (177M)
+dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)

 [parallelism]
 data_parallel_replicate_degree = 1
@@ -65,7 +65,7 @@ mode = "selective"  # ["none", "selective", "full"]
 selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy

 [compile]
-enable=true
+enable = true
 components = ["loss"]  # ["model", "loss"]

 [quantize.linear.float8]
diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py
index 295e2193a5..0ab4277a88 100644
--- a/torchtitan/models/moe/moe.py
+++ b/torchtitan/models/moe/moe.py
@@ -143,6 +143,10 @@ def forward(
         self,
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor,
+        input_shape,
+        permuted_indices,
+        input_splits,
+        output_splits,
     ) -> torch.Tensor:
         if isinstance(self.w1, DTensor):
             # Convert parameters from DTensors to plain Tensors, to work with
@@ -166,9 +170,11 @@ def forward(
                 run_experts_fn = indices_padding_wrapper(_run_experts_grouped_mm)
             else:
                 run_experts_fn = _run_experts_grouped_mm
-            return run_experts_fn(w1, w2, w3, x, num_tokens_per_expert)
+            out = run_experts_fn(w1, w2, w3, x, num_tokens_per_expert)
         else:
-            return _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert)
+            out = _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert)
+
+        return (out, input_shape, permuted_indices, input_splits, output_splits)

     def init_weights(self, init_std: float):
         nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
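
The net effect of the expert_parallel.py and moe.py changes is that ExpertParallel no longer carries mutable state between its forward hooks: the routing metadata (input_shape, permuted_indices, input_splits, output_splits) computed in _token_dispatch is threaded through GroupedExperts.forward and unpacked again in _token_combine. Below is a minimal self-contained sketch of that hand-off pattern, with hypothetical stand-ins for torchtitan's _permute/_unpermute and all-to-all helpers; it illustrates the data flow only, not the actual implementation.

# Minimal sketch (not torchtitan code) of the stateless hand-off pattern the
# diff introduces: dispatch returns the routing metadata alongside the tokens,
# the experts' forward passes it through untouched, and combine unpacks it,
# so nothing is stored on the ParallelStyle object between hooks. All helper
# names and bodies here are hypothetical stand-ins.
import torch


def token_dispatch(routed_input: torch.Tensor):
    input_splits = [routed_input.shape[0]]   # stand-in for router-derived splits
    output_splits = [routed_input.shape[0]]
    input_shape = routed_input.shape
    permuted_indices = torch.randperm(routed_input.shape[0])
    routed_input = routed_input[permuted_indices]  # stand-in for _permute
    # everything the combine step will need rides along with the tokens
    return routed_input, input_shape, permuted_indices, input_splits, output_splits


def experts_forward(x, input_shape, permuted_indices, input_splits, output_splits):
    out = 2.0 * x  # stand-in for the grouped expert computation
    # pass the metadata through untouched, as GroupedExperts.forward now does
    return out, input_shape, permuted_indices, input_splits, output_splits


def token_combine(mod_outputs):
    out, input_shape, permuted_indices, input_splits, output_splits = mod_outputs
    inverse = torch.argsort(permuted_indices)  # stand-in for _unpermute
    return out[inverse].reshape(input_shape)


x = torch.randn(8, 4)
y = token_combine(experts_forward(*token_dispatch(x)))
assert torch.allclose(y, 2.0 * x)  # permutation round-trips; hooks held no state

Keeping the hooks pure in this way plausibly sidesteps stale self.* state when one ParallelStyle instance is reused across MoE layers or when activations are recomputed, though the diff itself does not state the motivation.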