Commit 8ec599e

[bugfix] Fix aux loss & (gradient_accumulation_steps & loss_scale) (#5823)
1 parent 269f43f commit 8ec599e

1 file changed, +2 -0 lines changed


swift/trainers/trainers.py

Lines changed: 2 additions & 0 deletions
@@ -402,6 +402,8 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         if self.model.model_info.is_moe_model and self.args.router_aux_loss_coef is not None:
             aux_loss = outputs.get('aux_loss')
             if aux_loss is not None:
+                if num_items_in_batch is not None:
+                    aux_loss = aux_loss * ((labels[:, 1:] != -100).sum() / num_items_in_batch)
                 loss = loss + self.args.router_aux_loss_coef * aux_loss.to(loss.device)
 
         if self.template.sequence_parallel_size > 1:
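
The two added lines weight `aux_loss` by this micro-batch's share of supervised tokens whenever `num_items_in_batch` is provided. A plausible reading, not stated in the commit itself, is that the main loss is already normalized by `num_items_in_batch` across all gradient-accumulation steps, so an unscaled `aux_loss` added in every micro-batch would be counted once per step. The sketch below reproduces the arithmetic in plain Python. The helper names, the label values, and the per-micro-batch aux losses are hypothetical, and it assumes `num_items_in_batch` equals the total number of shifted, non-ignored labels across the accumulated micro-batches.

```python
IGNORE_INDEX = -100  # label value excluded from the loss, as in the diff


def count_target_tokens(labels):
    """Count supervised tokens after the causal shift (position 0 is dropped),
    mirroring (labels[:, 1:] != -100).sum() from the diff."""
    return sum(1 for row in labels for tok in row[1:] if tok != IGNORE_INDEX)


def scaled_aux_loss(aux_loss, labels, num_items_in_batch):
    """Weight aux_loss by this micro-batch's share of the supervised tokens.
    Left unchanged when num_items_in_batch is None, as in the diff."""
    if num_items_in_batch is None:
        return aux_loss
    return aux_loss * (count_target_tokens(labels) / num_items_in_batch)


# Hypothetical data: two micro-batches accumulated into one optimizer step.
micro_batches = [
    [[1, 5, 6, -100]],                    # 2 supervised tokens after the shift
    [[1, 7, 8, 9], [1, -100, -100, 4]],   # 3 + 1 = 4 supervised tokens
]
aux_losses = [0.3, 0.6]  # hypothetical router aux loss per micro-batch

# Assumption: the total is counted on shifted labels across all micro-batches.
num_items_in_batch = sum(count_target_tokens(mb) for mb in micro_batches)

unscaled_total = sum(aux_losses)
scaled_total = sum(
    scaled_aux_loss(a, mb, num_items_in_batch)
    for a, mb in zip(aux_losses, micro_batches)
)
# The weights 2/6 and 4/6 sum to 1, so scaled_total is a token-weighted
# average of the aux losses rather than their plain sum.
```

Under these assumptions the accumulated aux term stays on the same scale regardless of how many micro-batches make up a step, which matches the `gradient_accumulation_steps` concern named in the commit title.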
