42 changes: 21 additions & 21 deletions torchtitan/models/moe/moe.py
@@ -345,7 +345,6 @@ def forward(
         )

         top_scores_experts_sorted = top_scores.view(-1)[token_indices_experts_sorted]
-        token_indices_experts_sorted = token_indices_experts_sorted // self.top_k

         return (
             top_scores_experts_sorted,
@@ -414,7 +413,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         bs, slen, dim = x.shape
         x = x.view(-1, dim)

-        # top_scores and selected_experts_indices shape (bs*slen*top_k,)
+        # top_scores shape (bs*slen, top_k)
+        # selected_experts_indices shape (bs*slen, top_k)
         # num_tokens_per_expert shape (num_experts,)
         (
             top_scores,
@@ -445,12 +445,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         ) = self.reorderer(top_scores, selected_experts_indices)

         # shape (bs*slen*top_k, dim)
-        token_indices_experts_sorted = token_indices_experts_sorted.reshape(
-            -1, 1
-        ).expand(-1, dim)
-
-        # shape (bs*slen*top_k, dim)
-        routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted)
+        routed_input = x[token_indices_experts_sorted // self.router.top_k]

         if self.score_before_experts:
             routed_input = (
Expand All @@ -464,22 +459,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# shared expert
# Note: we execute the shared expert before scoring the output of the routed expert
# to "implicitly" overlap the shared expert compute with token combine communication
if self.shared_experts is not None:
out = self.shared_experts(x)
else:
out = torch.zeros_like(x)
out = self.shared_experts(x) if self.shared_experts is not None else None

if not self.score_before_experts:
routed_output = (
routed_output.to(torch.float32)
* top_scores_experts_sorted.reshape(-1, 1)
).to(x.dtype)
# Unsort scores and routed outputs. Also save some allocations: store unsorted scores
# and outputs in top_scores and routed_input, respectively.
top_scores = top_scores.flatten()
top_scores[token_indices_experts_sorted] = top_scores_experts_sorted
routed_input[token_indices_experts_sorted] = routed_output
Contributor:

> Also save some allocations: store unsorted scores and outputs in top_scores and routed_input, respectively.

Hmm, please educate me more here:
From the Python level this might save memory, but wouldn't the routed_input activations be saved by the autograd engine anyway?

Are you observing meaningful savings? If not, I'd prefer to separate that concern and focus on run-to-run determinism in this PR.

Contributor Author:

The comment is probably overly terse, and maybe I'm also just wrong?

The comment is in reference to an alternative implementation where we do something like:

top_scores = top_scores.flatten()
top_scores[token_indices_experts_sorted] = top_scores_experts_sorted
# new alloc from empty_like, avoided in the current code
routed_output_sorted = torch.empty_like(routed_output)
routed_output_sorted[token_indices_experts_sorted] = routed_output

or

top_scores = top_scores.flatten()
top_scores[token_indices_experts_sorted] = top_scores_experts_sorted
# new alloc from clone, avoided in the current code
routed_output[token_indices_experts_sorted] = routed_output.clone()

The clone is needed in the second case, because otherwise routed_output[token_indices_experts_sorted] = routed_output leads to

RuntimeError: unsupported operation: some elements of the input tensor and the written-to tensor refer to a single memory location. Please clone() the tensor before performing the operation.

Which is what I initially tried.

So the point here is about avoiding a clone or empty_like or other similar call that would cause a new allocation.

For training, both routed_input and routed_output are of course still in the computational graph; this doesn't avoid that.

LMK if I'm wrong in any of the above, and if you want code or comment changes or any profiling of the different options here.
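
For concreteness, a minimal standalone sketch (not code from the PR) of the three variants above, using toy shapes and a random permutation in place of token_indices_experts_sorted; each variant is meant to be read on its own:

import torch

num_tokens, dim = 8, 4
routed_input = torch.randn(num_tokens, dim)   # stand-in for the already-allocated gathered inputs
routed_output = torch.randn(num_tokens, dim)  # stand-in for the expert outputs, in expert-sorted order
perm = torch.randperm(num_tokens)             # stand-in for token_indices_experts_sorted

# (a) unsort into a fresh buffer: correct, but costs one extra allocation
unsorted = torch.empty_like(routed_output)
unsorted[perm] = routed_output

# (b) unsort a tensor into itself: needs a clone, because otherwise source and
# destination alias the same storage
routed_output[perm] = routed_output.clone()

# (c) what the PR does: reuse routed_input's existing storage as the destination,
# so no new allocation is made at the Python level (autograd still holds both tensors)
routed_input[perm] = routed_output

Variant (b) only works because of the clone; dropping it reproduces the RuntimeError quoted above. Variants (a) and (c) differ only in whether a new buffer is allocated for the unsorted outputs.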

+            routed_input = routed_input.reshape(-1, self.router.top_k, dim)
+            top_scores = top_scores.reshape(-1, 1, self.router.top_k)
+            out_experts = (
+                torch.bmm(top_scores, routed_input.float()).to(x.dtype).squeeze(1)
+            )
+        else:
+            # Unsort routed outputs and save an allocation: store unsorted outputs in routed_input
+            routed_input[token_indices_experts_sorted] = routed_output
+            out_experts = routed_input.reshape(-1, self.router.top_k, dim).sum(dim=1)

-        out = out.scatter_add(
-            dim=0, index=token_indices_experts_sorted, src=routed_output
-        )
-        out = out.reshape(bs, slen, dim)
-        return out
+        if out is None:
+            return out_experts.reshape(bs, slen, dim)
+        return (out + out_experts).reshape(bs, slen, dim)

     def init_weights(
         self,
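
For reference, a toy, self-contained comparison (not code from the repo) of the old scatter_add-based combine against the new unsort-plus-bmm combine from the diff above, using a fake expert that simply doubles its input. The two paths agree numerically, and the new one avoids scatter_add with repeated token indices, which is the run-to-run determinism concern raised in the review thread:

import torch

torch.manual_seed(0)
N, top_k, dim, num_experts = 6, 2, 4, 3
x = torch.randn(N, dim)
top_scores = torch.rand(N, top_k)
selected = torch.randint(num_experts, (N, top_k))

# expert-sorted order of the N*top_k (token, slot) assignments, as in the reorderer
idx = torch.argsort(selected.view(-1), stable=True)
top_scores_sorted = top_scores.view(-1)[idx]
routed_input = x[idx // top_k]           # (N*top_k, dim), the new gather-by-indexing
routed_output = routed_input * 2.0       # fake expert computation

# old combine: scale the sorted outputs, then scatter_add back onto token rows
token_rows = (idx // top_k).reshape(-1, 1).expand(-1, dim)
old = torch.zeros_like(x).scatter_add(
    0, token_rows, routed_output * top_scores_sorted.reshape(-1, 1)
)

# new combine: unsort the outputs with a plain index assignment, then one bmm per token
unsorted = torch.empty_like(routed_output)
unsorted[idx] = routed_output
new = torch.bmm(
    top_scores.reshape(-1, 1, top_k), unsorted.reshape(-1, top_k, dim)
).squeeze(1)

print(torch.allclose(old, new))  # True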