diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index 295e2193a..2b7724f32 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -345,7 +345,6 @@ def forward( ) top_scores_experts_sorted = top_scores.view(-1)[token_indices_experts_sorted] - token_indices_experts_sorted = token_indices_experts_sorted // self.top_k return ( top_scores_experts_sorted, @@ -414,7 +413,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: bs, slen, dim = x.shape x = x.view(-1, dim) - # top_scores and selected_experts_indices shape (bs*slen*top_k,) + # top_scores shape (bs*slen, top_k) + # selected_experts_indices shape (bs*slen, top_k) # num_tokens_per_expert shape (num_experts,) ( top_scores, @@ -445,12 +445,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) = self.reorderer(top_scores, selected_experts_indices) # shape (bs*slen*top_k, dim) - token_indices_experts_sorted = token_indices_experts_sorted.reshape( - -1, 1 - ).expand(-1, dim) - - # shape (bs*slen*top_k, dim) - routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted) + routed_input = x[token_indices_experts_sorted // self.router.top_k] if self.score_before_experts: routed_input = ( @@ -464,22 +459,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # shared expert # Note: we execute the shared expert before scoring the output of the routed expert # to "implicitly" overlap the shared expert compute with token combine communication - if self.shared_experts is not None: - out = self.shared_experts(x) - else: - out = torch.zeros_like(x) + out = self.shared_experts(x) if self.shared_experts is not None else None if not self.score_before_experts: - routed_output = ( - routed_output.to(torch.float32) - * top_scores_experts_sorted.reshape(-1, 1) - ).to(x.dtype) + # Unsort scores and routed outputs. Also save some allocations: store unsorted scores + # and outputs in top_scores and routed_input, respectively. + top_scores = top_scores.flatten() + top_scores[token_indices_experts_sorted] = top_scores_experts_sorted + routed_input[token_indices_experts_sorted] = routed_output + routed_input = routed_input.reshape(-1, self.router.top_k, dim) + top_scores = top_scores.reshape(-1, 1, self.router.top_k) + out_experts = ( + torch.bmm(top_scores, routed_input.float()).to(x.dtype).squeeze(1) + ) + else: + # Unsort routed outputs and save an allocation: store unsorted outputs in routed_input + routed_input[token_indices_experts_sorted] = routed_output + out_experts = routed_input.reshape(-1, self.router.top_k, dim).sum(dim=1) - out = out.scatter_add( - dim=0, index=token_indices_experts_sorted, src=routed_output - ) - out = out.reshape(bs, slen, dim) - return out + if out is None: + return out_experts.reshape(bs, slen, dim) + return (out + out_experts).reshape(bs, slen, dim) def init_weights( self,