Skip to content

Commit 269f43f

Browse files
hjh0119Jintao-Huang
authored andcommitted
[grpo] fix log std_zero (#5813)
* fix log std0 * fix log std
1 parent 588c25c commit 269f43f

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -984,7 +984,7 @@ def log_rewards_metrics(rewards: torch.Tensor, rewards_per_func_for_metrics: tor
984984
group_rewards = rewards.view(-1, self.num_generations)
985985
rewards_mean = group_rewards.mean(-1).mean().item()
986986
rewards_std = group_rewards.std(-1).mean().item()
987-
is_std_zero = torch.isclose(rewards.std(dim=0), torch.zeros_like(rewards.std(dim=0)))
987+
is_std_zero = torch.isclose(group_rewards.std(dim=1), torch.zeros_like(group_rewards.std(dim=1)))
988988

989989
self._metrics[mode]['reward'].append(rewards_mean)
990990
self._metrics[mode]['reward_std'].append(rewards_std)

0 commit comments

Comments
 (0)