
Commit 3e6510c

Implement auxiliary-loss-free load balancing
1 parent 5b4855d commit 3e6510c

4 files changed: +74, -1 lines changed
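
Background for the change set: instead of an auxiliary balancing loss, each MoE gate keeps a per-expert correction bias that is added to the routing scores only when experts are selected. After every optimizer step the bias is nudged by a fixed step toward the under-used experts, so the load evens out without touching the training loss. A minimal standalone sketch of that update rule (illustrative tensors, not the trainer code):

    import paddle

    # Per-expert token counts accumulated since the last update (illustrative values).
    usage = paddle.to_tensor([120.0, 80.0, 100.0, 100.0])
    bias = paddle.zeros_like(usage)  # plays the role of e_score_correction_bias
    lr = 0.001                       # fixed step, the callback's `lr` argument

    # Experts below the mean load get +lr, experts above it get -lr.
    bias += paddle.sign(usage.mean() - usage) * lr
    print(bias.numpy())  # -> [-0.001, 0.001, 0.0, 0.0]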

llm/run_pretrain.py

Lines changed: 4 additions & 0 deletions

@@ -28,6 +28,7 @@
 )
 from paddlenlp.trainer import (
     FP8QuantWeightCallback,
+    MoECorrectionBiasAdjustCallback,
     PdArgumentParser,
     StepFlexToken,
     Trainer,
@@ -571,6 +572,9 @@ def main():

     callbacks = [StepFlexToken(), FP8QuantWeightCallback()]

+    if getattr(config, "topk_method", None) == "noaux_tc":
+        callbacks += [MoECorrectionBiasAdjustCallback()]
+
     trainer = PretrainingTrainer(
         model=model,
         args=training_args,

paddlenlp/trainer/trainer_callback.py

Lines changed: 62 additions & 0 deletions

@@ -27,6 +27,11 @@
 import numpy as np
 from tqdm.auto import tqdm

+import paddle
+import paddle.distributed as dist
+from paddle.distributed.fleet import fleet
+
+from paddlenlp.transformers.moe_gate import PretrainedMoEGate
 from paddlenlp.transformers.moe_utils import offload, reload
 from paddlenlp.utils.log import logger

@@ -44,6 +49,7 @@
     "EarlyStoppingCallback",
     "StepFlexToken",
     "FP8QuantWeightCallback",
+    "MoECorrectionBiasAdjustCallback",
 ]


@@ -671,3 +677,59 @@ def on_optimizer_begin(self, args, state, control, **kwargs):
         if (not g_shard_bypass_dygraph_optimizer) and hasattr(model, "fp8_quant_weight"):
             for name in self.moe_weights_name:
                 reload(optimizer._master_weights[name])
+
+
+class MoECorrectionBiasAdjustCallback(TrainerCallback):
+    """used for moe aux loss free balance"""
+
+    def __init__(self, lr=0.001, use_mp=False):
+        super().__init__()
+        self.update_lr = lr
+        self.use_mp = use_mp
+
+    def on_optimizer_end(self, args, state, control, **kwargs):
+        model = kwargs["model"]
+
+        biases = []
+        usages = []
+
+        def get_stat(layer):
+            if isinstance(layer, PretrainedMoEGate) and layer.topk_method == "noaux_tc":
+                biases.append(layer.e_score_correction_bias)
+                usages.append(layer.expert_usage)
+
+        model.apply(get_stat)
+
+        usages_tensor = paddle.stack(usages, 0)  # [num_layers, num_local_experts]
+        if not hasattr(fleet, "_hcg"):
+            dist.all_reduce(usages_tensor)
+            return
+
+        hcg = fleet.get_hybrid_communicate_group()
+        mp_group = hcg.get_model_parallel_group()
+        dp_group = hcg.get_data_parallel_group()
+        sd_group = hcg.get_sharding_parallel_group()
+
+        if self.use_mp and mp_group.nranks > 1:
+            dist.all_reduce(usages_tensor, group=mp_group)
+        if dp_group.nranks > 1:
+            dist.all_reduce(usages_tensor, group=dp_group)
+        if sd_group.nranks > 1:
+            dist.all_reduce(usages_tensor, group=sd_group)
+
+        usages_mean = usages_tensor.mean(-1, keepdim=True)
+        update = paddle.sign(usages_mean - usages_tensor) * self.update_lr
+        update_list = list(update)
+
+        print('on_optimizer_end bias:', paddle.stack(biases, 0).numpy())
+        print('on_optimizer_end usage:', usages_tensor.numpy())
+        print('on_optimizer_end update:', update.numpy())
+
+        def update_bias(layer):
+            if isinstance(layer, PretrainedMoEGate) and layer.topk_method == "noaux_tc":
+                with paddle.no_grad():
+                    if not layer.weight.stop_gradient:
+                        biases.pop(0).add_(update_list.pop(0))
+                    usages.pop(0).zero_()
+
+        model.apply(update_bias)
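
To make the callback's bookkeeping concrete, here is a hedged toy walk-through of its core math for two gate layers with four experts each, leaving out the distributed part (in real training the counts are first all-reduced across the data-parallel and sharding groups, and optionally the model-parallel group when use_mp is set); names and values are chosen for illustration:

    import paddle

    # Two layers' expert_usage counters and their correction biases (toy values).
    usages = [paddle.to_tensor([7, 1, 4, 4], dtype="int64"),
              paddle.to_tensor([3, 3, 5, 1], dtype="int64")]
    biases = [paddle.zeros([4]), paddle.zeros([4])]
    lr = 0.001

    usages_tensor = paddle.stack(usages, 0).astype("float32")  # [num_layers, num_experts]
    usages_mean = usages_tensor.mean(-1, keepdim=True)         # per-layer mean load
    update = paddle.sign(usages_mean - usages_tensor) * lr     # +lr underused, -lr overused

    for bias, upd, usage in zip(biases, list(update), usages):
        bias.add_(upd)    # in-place, mirroring biases.pop(0).add_(update_list.pop(0))
        usage.zero_()     # counters restart for the next accumulation window

    print(biases[0].numpy())  # -> [-0.001, 0.001, 0.0, 0.0]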

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 7 additions & 0 deletions

@@ -924,6 +924,11 @@ def __init__(
                 default_initializer=nn.initializer.Constant(0.0),
             )
             self.e_score_correction_bias.is_distributed = True
+            self.expert_usage = paddle.zeros(
+                shape=[num_experts],
+                dtype=paddle.int64,
+            )
+            self.expert_usage.stop_gradient = True

         if self.using_post_norm_recompute:
             assert norm_weight is not None and norm_eps is not None
@@ -969,6 +974,8 @@ def forward(self, hidden_states):
             scores, routing_map, exp_counts, l_aux, l_zloss = self.topkgating_nodrop(
                 scores
             )  # (scores, routing_map, exp_counts, l_aux, l_zloss)
+            with paddle.no_grad():
+                self.expert_usage += exp_counts
             ret = (scores, routing_map, l_aux, l_zloss)
         else:
             ret = self.topkgating(scores)  # (capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss)
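
The gate-side half of the mechanism is a non-trainable counter: the exp_counts returned by the no-drop top-k routing are accumulated into expert_usage under no_grad on every forward pass, and the callback zeroes the counter after consuming it. A small sketch of that buffer pattern (toy counts, not the model code):

    import paddle

    num_experts = 4
    expert_usage = paddle.zeros([num_experts], dtype=paddle.int64)
    expert_usage.stop_gradient = True  # bookkeeping only, never trained

    # exp_counts stands for the per-expert token count one forward pass would produce.
    exp_counts = paddle.to_tensor([5, 0, 2, 1], dtype=paddle.int64)
    with paddle.no_grad():
        expert_usage += exp_counts     # accumulates across micro-batches
    print(expert_usage.numpy())        # [5 0 2 1] until the callback resets it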

paddlenlp/transformers/moe_gate.py

Lines changed: 1 addition & 1 deletion

@@ -301,7 +301,7 @@ def _topk_noaux_tc(
         assert n_experts % n_group == 0, "n_experts must be divisible by n_groups"

         assert self.e_score_correction_bias is not None, "e_score_correction_bias is None"
-        scores_for_choice = scores.reshape([bsz_seq_len, -1]) + self.e_score_correction_bias.unsqueeze(0)
+        scores_for_choice = scores.reshape([bsz_seq_len, -1]) + self.e_score_correction_bias.detach().unsqueeze(0)
         reshape_tmp_rst = scores_for_choice.reshape([bsz_seq_len, self.n_group, -1])
         top_k = min(reshape_tmp_rst.shape[2], 2)
         group_scores = reshape_tmp_rst.topk(top_k, axis=-1)[0].sum(axis=-1)  # fmt:skip [n, n_group]
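
The one-line change above detaches the correction bias before adding it to the scores: the bias should only influence which experts are chosen, not carry gradients back through the router, since it is maintained by the callback rather than by the optimizer. A hedged sketch of that selection-versus-weighting split (illustrative shapes and names, not the gate's full code):

    import paddle

    scores = paddle.rand([3, 8])   # [tokens, experts] routing scores
    bias = paddle.zeros([8])       # stand-in for e_score_correction_bias

    # The bias enters only the copy of the scores used to pick experts...
    scores_for_choice = scores + bias.detach().unsqueeze(0)
    topk_idx = scores_for_choice.topk(2, axis=-1)[1]

    # ...while the combine weights come from the original, bias-free scores.
    gate_weights = paddle.take_along_axis(scores, topk_idx, axis=-1)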
