52 changes: 52 additions & 0 deletions verl/trainer/ppo/ray_trainer.py
@@ -515,6 +515,53 @@ def _get_gen_batch(self, batch: DataProto) -> DataProto:

        return gen_batch

    @staticmethod
    def _summarize_reward_extras(reward_extra_infos: dict[str, list]) -> dict[str, float]:
        metrics: dict[str, float] = {}
        if not reward_extra_infos:
            return metrics

        for key, values in reward_extra_infos.items():
            if not values:
                continue

            flattened: list = []
            for value in values:
                if isinstance(value, list | tuple):
                    flattened.extend(value)
                else:
                    flattened.append(value)

            numeric_vals: list[float] = []
            for value in flattened:
                scalar: float | None = None
                if isinstance(value, torch.Tensor):
                    if value.numel() == 1:
                        scalar = float(value.item())
                elif isinstance(value, np.ndarray):
                    if value.size == 1:
                        scalar = float(value.item())
                elif isinstance(value, np.floating | np.integer | int | float | bool):
                    scalar = float(value)
                else:
                    try:
                        scalar = float(value)
                    except (TypeError, ValueError):
                        scalar = None

                if scalar is not None:
                    numeric_vals.append(scalar)
Comment on lines +528 to +553
Contributor

Severity: high

The current implementation for processing reward_extra_infos is not fully robust. It only flattens a single level of nested list or tuple objects and silently ignores multi-element torch.Tensor or np.ndarray objects. This can lead to incomplete data and misleading metrics.

For example, if a value is np.array([1, 2]) or a nested list like [[1, 2]], its contents will not be included in the statistics.

A more robust approach would be to recursively flatten all iterable structures and then attempt to convert all resulting items to floats. This ensures all numeric data is captured, regardless of nesting or type.

            # 1. Flatten all nested lists, tuples, tensors, and arrays
            flattened: list = []
            items_to_flatten = list(values)
            while items_to_flatten:
                item = items_to_flatten.pop(0)
                if isinstance(item, (list, tuple)):
                    items_to_flatten.extend(item)
                elif isinstance(item, (torch.Tensor, np.ndarray)):
                    items_to_flatten.extend(item.flatten().tolist())
                else:
                    flattened.append(item)

            # 2. Convert flattened items to numeric values
            numeric_vals: list[float] = []
            for value in flattened:
                try:
                    numeric_vals.append(float(value))
                except (TypeError, ValueError):
                    # Ignore values that cannot be converted to float
                    pass


            if not numeric_vals:
                continue

            numeric_array = np.array(numeric_vals, dtype=float)
            metrics[f"reward_extra/{key}/mean"] = float(np.mean(numeric_array))
            metrics[f"reward_extra/{key}/max"] = float(np.max(numeric_array))
            metrics[f"reward_extra/{key}/min"] = float(np.min(numeric_array))

        return metrics
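For reference, a minimal usage sketch of the new summarizer (not part of the diff): the reward keys below are made-up examples, and the enclosing class is assumed to be RayPPOTrainer, so calling the static method through the class is only an illustration.

# Hypothetical example; keys and values are illustrative only.
from verl.trainer.ppo.ray_trainer import RayPPOTrainer  # assumed class name

reward_extra_infos = {
    "accuracy": [1.0, 0.0, 1.0],          # plain floats -> mean/max/min reported
    "format_score": [[0.5, 1.0], [0.0]],  # one level of list nesting is flattened
    "notes": ["ok", "bad"],               # non-numeric values are skipped
}
summary = RayPPOTrainer._summarize_reward_extras(reward_extra_infos)
# summary would contain reward_extra/accuracy/{mean,max,min} and
# reward_extra/format_score/{mean,max,min}; "notes" yields no metrics.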

    def _validate(self):
        data_source_lst = []
        reward_extra_infos_dict: dict[str, list] = defaultdict(list)
@@ -975,6 +1022,7 @@ def fit(self):
            for batch_dict in self.train_dataloader:
                metrics = {}
                timing_raw = {}
                reward_extra_metrics: dict[str, float] = {}

                with marked_timer("start_profile", timing_raw):
                    self._start_profiling(
Expand Down Expand Up @@ -1027,6 +1075,7 @@ def fit(self):
batch.batch["reward_baselines"] = reward_baseline_tensor

del gen_baseline_batch, gen_baseline_output

# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)
Expand Down Expand Up @@ -1054,6 +1103,7 @@ def fit(self):
future_reward = compute_reward_async.remote(data=batch, reward_fn=self.reward_fn)
else:
reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)
reward_extra_metrics.update(self._summarize_reward_extras(reward_extra_infos_dict))

# recompute old_log_probs
with marked_timer("old_log_prob", timing_raw, color="blue"):
Expand Down Expand Up @@ -1093,6 +1143,7 @@ def fit(self):
reward_extra_infos_dict: dict[str, list]
if self.config.reward_model.launch_reward_fn_async:
reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
reward_extra_metrics.update(self._summarize_reward_extras(reward_extra_infos_dict))
batch.batch["token_level_scores"] = reward_tensor

if reward_extra_infos_dict:
Expand Down Expand Up @@ -1199,6 +1250,7 @@ def fit(self):
"training/epoch": epoch,
}
)
metrics.update(reward_extra_metrics)
# collect metrics
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))