length penalty in reward function #14

Open · wants to merge 4 commits into base: main

5 changes: 5 additions & 0 deletions verl/trainer/config/generation.yaml
@@ -40,6 +40,11 @@ rollout:
disable_log_stats: True
enable_chunked_prefill: True
n: 1
length_penalty:
enabled: False # Set to True to enable length penalty during generation
alpha: 0.0 # Alpha parameter: positive favors longer sequences, negative favors shorter
min_length: 0 # Minimum sequence length before applying penalty
max_length: null # Maximum sequence length (null means no maximum)
actor:
strategy: fsdp # This is for backward-compatibility
ulysses_sequence_parallel_size: 1 # sp size
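A minimal sketch of parsing the new length_penalty block, assuming OmegaConf (which these Hydra-style configs are built on); the same block is mirrored under reward_model in the two trainer configs below, and the loading code here is illustrative only:

from omegaconf import OmegaConf

# Copy of the new rollout.length_penalty block from generation.yaml above.
yaml_str = """
rollout:
  length_penalty:
    enabled: False   # set to True to enable length penalty during generation
    alpha: 0.0       # > 0 favors longer sequences, < 0 favors shorter
    min_length: 0
    max_length: null
"""

cfg = OmegaConf.create(yaml_str)
lp = cfg.rollout.length_penalty
print(lp.enabled, lp.alpha, lp.min_length, lp.max_length)  # False 0.0 0 None
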
8 changes: 8 additions & 0 deletions verl/trainer/config/ppo_megatron_trainer.yaml
@@ -223,6 +223,11 @@ reward_model:
use_dynamic_bsz: ${critic.use_dynamic_bsz}
max_length: null
launch_reward_fn_async: False # custom reward function executed async on CPU, during log_prob
length_penalty:
enabled: False # Set to True to enable length penalty
alpha: 0.0 # Alpha parameter: positive favors longer sequences, negative favors shorter
min_length: 0 # Minimum sequence length before applying penalty
max_length: null # Maximum sequence length (null means no maximum)

custom_reward_function:
path: null
@@ -252,6 +257,9 @@ trainer:
nnodes: 1
n_gpus_per_node: 8
save_freq: -1
track_token_lengths: true
positive_reward_threshold: 0.5
token_length_tag_prefix: "token_length"
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or disable or resume_path if resume_from_path is set
resume_from_path: null
8 changes: 8 additions & 0 deletions verl/trainer/config/ppo_trainer.yaml
@@ -183,6 +183,11 @@ reward_model:
forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
reward_manager: naive
launch_reward_fn_async: False # custom reward function executed async on CPU, during log_prob
length_penalty:
enabled: False # Set to True to enable length penalty
alpha: 0.0 # Alpha parameter: positive favors longer sequences, negative favors shorter
min_length: 0 # Minimum sequence length before applying penalty
max_length: null # Maximum sequence length (null means no maximum)

custom_reward_function:
path: null
@@ -214,6 +219,9 @@ trainer:
nnodes: 1
n_gpus_per_node: 8
save_freq: -1
track_token_lengths: true
positive_reward_threshold: 0.5
token_length_tag_prefix: "token_length"
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or disable or resume_path if resume_from_path is set
resume_from_path: null
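The three new trainer keys feed the token-length logging added to ray_trainer.py below: responses whose scalar reward exceeds positive_reward_threshold count as positive, the rest as negative, and average lengths per group are logged under the token_length_tag_prefix. A small sketch of that split, with made-up numbers:

# Made-up per-response rewards and token lengths; threshold matches the default 0.5.
rewards = [1.0, 0.0, 0.7, 0.2]
lengths = [310, 512, 290, 480]
threshold = 0.5

pos = [n for r, n in zip(rewards, lengths) if r > threshold]
neg = [n for r, n in zip(rewards, lengths) if r <= threshold]

print(sum(pos) / len(pos))  # token_length/positive_avg_token_length -> 300.0
print(sum(neg) / len(neg))  # token_length/negative_avg_token_length -> 496.0
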
93 changes: 92 additions & 1 deletion verl/trainer/ppo/ray_trainer.py
@@ -825,6 +825,65 @@
global_balance_stats = log_seqlen_unbalance(seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix)
metrics.update(global_balance_stats)

def log_token_lengths(self, data, step, metrics=None):
"""Log average token lengths for positive and negative responses."""
if not self.config.trainer.get('track_token_lengths', False):
return

threshold = self.config.trainer.get('positive_reward_threshold', 0.5)
tag_prefix = self.config.trainer.get('token_length_tag_prefix', 'token_length')

try:
if 'acc' in data.batch:
rewards = data.batch['acc'].cpu().tolist()
elif 'token_level_scores' in data.batch:
rewards = data.batch['token_level_scores'].sum(-1).cpu().tolist()
else:
print("Warning: Cannot track token lengths - no reward information found")
return
except Exception as e:
print(f"Warning: Error computing rewards for token length tracking: {e}")
return

responses = data.batch['responses']
attention_mask = data.batch['attention_mask']
prompt_len = data.batch['prompts'].shape[-1]
valid_response_lengths = attention_mask[:, prompt_len:].sum(dim=-1)

positive_tokens = []
negative_tokens = []

for i, reward in enumerate(rewards):
length = valid_response_lengths[i].item()
response_tokens = responses[i][:length].tolist()

if reward > threshold:
positive_tokens.append(response_tokens)
else:
negative_tokens.append(response_tokens)

pos_avg_len = 0
neg_avg_len = 0

if positive_tokens:
pos_avg_len = sum(len(tokens) for tokens in positive_tokens) / len(positive_tokens)

if negative_tokens:
neg_avg_len = sum(len(tokens) for tokens in negative_tokens) / len(negative_tokens)

log_dict = {
f"{tag_prefix}/positive_avg_token_length": pos_avg_len,
f"{tag_prefix}/negative_avg_token_length": neg_avg_len
}

if metrics is not None:
metrics.update(log_dict)

if hasattr(self, 'tracker') and self.tracker:
self.tracker.log(log_dict, step=step)

return log_dict
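
A self-contained sanity check of the length bookkeeping above, with the DataProto batch mocked as a plain dict of tensors; the 3-token prompt and right-padded responses are assumptions made only for this sketch:

import torch

batch = {
    "acc": torch.tensor([1.0, 0.0]),
    "prompts": torch.zeros(2, 3, dtype=torch.long),
    "responses": torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]]),
    "attention_mask": torch.tensor([
        [1, 1, 1, 1, 1, 1, 0],   # 3 prompt tokens + 3 valid response tokens
        [1, 1, 1, 1, 1, 0, 0],   # 3 prompt tokens + 2 valid response tokens
    ]),
}

prompt_len = batch["prompts"].shape[-1]
valid_response_lengths = batch["attention_mask"][:, prompt_len:].sum(dim=-1)
print(valid_response_lengths.tolist())  # [3, 2] -> positive avg 3.0, negative avg 2.0 with acc = [1.0, 0.0]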

def fit(self):
"""
The training loop of PPO.
@@ -968,9 +1027,41 @@
reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
batch.batch["token_level_scores"] = reward_tensor

batch.batch["acc"] = batch.batch["token_level_scores"].sum(-1)

token_metrics = self.log_token_lengths(batch, self.global_steps, metrics)

Check failure on line 1032 in verl/trainer/ppo/ray_trainer.py
GitHub Actions / pre-commit (3.12) and pre_commit_for_ppo (3.12), Ruff (F841): verl/trainer/ppo/ray_trainer.py:1032:25: Local variable `token_metrics` is assigned to but never used

print(f"{list(reward_extra_infos_dict.keys())=}")

if reward_extra_infos_dict:
batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
for key in ["length_penalty_applied", "length_penalty_alpha",
"original_rewards_mean", "penalized_rewards_mean", "penalty_ratio"]:
if key in reward_extra_infos_dict:
metrics[f"reward/length_penalty/{key}"] = reward_extra_infos_dict[key]

for key, value in reward_extra_infos_dict.items():
if isinstance(value, (int, float)): # Scalar values only
if key not in ["length_penalty_applied", "length_penalty_alpha",
"original_rewards_mean", "penalized_rewards_mean", "penalty_ratio"]:
if "score" in key:
metrics[f"reward/scores/{key}"] = value
elif "format" in key:
metrics[f"reward/format/{key}"] = value
elif "proof" in key:
metrics[f"reward/proof/{key}"] = value
else:
metrics[f"reward/other/{key}"] = value

safe_dict = {}
for k, v in reward_extra_infos_dict.items():
if isinstance(v, list) and len(v) == len(batch):
try:
safe_dict[k] = np.array(v)
except Exception as e:
print(f"Warning: Could not convert reward extra info '{k}' to numpy array: {e}")

if safe_dict:
batch.non_tensor_batch.update(safe_dict)
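
For orientation, invented extra-info keys and where the routing above would put them: length-penalty diagnostics go under reward/length_penalty/, other scalars are routed by substring into reward/scores/, reward/format/, reward/proof/ or reward/other/, and only list-valued entries of batch length reach non_tensor_batch:

# Invented example payload and the destination of each entry.
reward_extra_infos_dict = {
    "length_penalty_alpha": -0.5,   # -> metrics["reward/length_penalty/length_penalty_alpha"]
    "format_score": 1.0,            # contains "score" -> metrics["reward/scores/format_score"]
    "proof_reward": 0.0,            # contains "proof" -> metrics["reward/proof/proof_reward"]
    "difficulty": 3.0,              # no substring match -> metrics["reward/other/difficulty"]
    "per_sample_acc": [1.0, 0.0],   # list of len(batch) -> batch.non_tensor_batch["per_sample_acc"]
}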

# compute rewards. apply_kl_penalty if available
if self.config.algorithm.use_kl_in_reward:
54 changes: 53 additions & 1 deletion verl/trainer/ppo/reward.py
@@ -15,8 +15,10 @@
import os

import ray
import torch

from verl import DataProto
from verl.utils.length_penalty import apply_length_penalty


def get_custom_reward_fn(config):
@@ -75,6 +77,10 @@ def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs):
else:
raise NotImplementedError

# Pass length penalty config to reward manager
length_penalty_config = dict(config.reward_model.get("length_penalty", {}))
reward_kwargs.update({"length_penalty_config": length_penalty_config})

compute_score = get_custom_reward_fn(config)
return reward_manager_cls(
tokenizer=tokenizer,
@@ -102,6 +108,52 @@ def compute_reward(data: DataProto, reward_fn):
print(f"Error in reward_fn: {e}")
reward_tensor = reward_fn(data)
reward_extra_infos_dict = {}

# Apply length penalty if configured and if sequence lengths are available
if hasattr(reward_fn, "length_penalty_config") and reward_fn.length_penalty_config.get("enabled", False):
sequence_lengths = None
if "response_lengths" in data.batch:
sequence_lengths = data.batch["response_lengths"]
elif "attention_mask" in data.batch and "prompts" in data.batch:
prompt_len = data.batch["prompts"].shape[-1]
attention_mask = data.batch["attention_mask"]
total_lengths = torch.sum(attention_mask, dim=1)
sequence_lengths = total_lengths - prompt_len

if sequence_lengths is not None:
alpha = reward_fn.length_penalty_config.get("alpha", 0.0)
min_length = reward_fn.length_penalty_config.get("min_length", 0)
max_length = reward_fn.length_penalty_config.get("max_length", None)

orig_rewards = reward_tensor.sum(-1)

print(f"Applying length penalty with alpha={alpha}, min_length={min_length}, max_length={max_length}")
print(f"Mean sequence length: {sequence_lengths.float().mean().item():.1f}")
print(f"Original rewards mean: {orig_rewards.mean().item():.4f}")

penalized_rewards = apply_length_penalty(
orig_rewards,
sequence_lengths,
alpha=alpha,
min_length=min_length,
max_length=max_length
)

print(f"Penalized rewards mean: {penalized_rewards.mean().item():.4f}")
print(f"Penalty ratio: {(penalized_rewards / (orig_rewards + 1e-8)).mean().item():.4f}")

# Compute scaling factor for each sequence
scaling_factors = penalized_rewards / (orig_rewards + 1e-8) # Add small epsilon to avoid division by zero

# Apply scaling factors to token-level rewards
reward_tensor = reward_tensor * scaling_factors.unsqueeze(-1)

if reward_extra_infos_dict is not None:
reward_extra_infos_dict["length_penalty_applied"] = True
reward_extra_infos_dict["length_penalty_alpha"] = alpha
reward_extra_infos_dict["original_rewards_mean"] = orig_rewards.mean().item()
reward_extra_infos_dict["penalized_rewards_mean"] = penalized_rewards.mean().item()
reward_extra_infos_dict["penalty_ratio"] = (penalized_rewards / (orig_rewards + 1e-8)).mean().item()

return reward_tensor, reward_extra_infos_dict
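
A worked example of this scaling with made-up numbers: alpha = -0.5 and the defaults min_length = 0, max_length = null give penalty = ((5 + L) / 6) ** alpha, so two responses with equal original reward but different lengths are scaled differently, and the per-sequence factor is then broadcast over the token-level rewards:

import torch

from verl.utils.length_penalty import apply_length_penalty

orig_rewards = torch.tensor([2.0, 2.0])
sequence_lengths = torch.tensor([50, 400])

penalized = apply_length_penalty(orig_rewards, sequence_lengths, alpha=-0.5)
# penalty factors are roughly [0.330, 0.122], so penalized is roughly [0.661, 0.243]:
# with a negative alpha the shorter response keeps a much larger share of its reward.

scaling_factors = penalized / (orig_rewards + 1e-8)
token_level = torch.full((2, 400), 0.005)                    # dummy token-level reward tensor
token_level = token_level * scaling_factors.unsqueeze(-1)    # same broadcast as in compute_reward above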

@@ -113,4 +165,4 @@ def compute_reward_async(data: DataProto, config, tokenizer):
This is meant to be run in a separate Ray worker.
"""
reward_fn = load_reward_manager(config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}))
return compute_reward(data, reward_fn)
return compute_reward(data, reward_fn)
114 changes: 114 additions & 0 deletions verl/utils/length_penalty.py
@@ -0,0 +1,114 @@
# Copyright 2025 Individual Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import numpy as np

def compute_length_penalty(sequence_lengths, alpha=0.0, min_length=0, max_length=None):
"""
Compute length penalty for sequence generation.

Args:
sequence_lengths: Tensor or numpy array of shape [batch_size] containing sequence lengths
alpha: Float controlling the strength and direction of the penalty:
- alpha > 0: Favor longer sequences
- alpha < 0: Favor shorter sequences
- alpha = 0: No length penalty
min_length: Minimum sequence length before applying penalty (no penalty below this)
max_length: Maximum sequence length (sequences longer than this get max penalty)

Returns:
Tensor or numpy array of shape [batch_size] containing length penalties
"""
if alpha == 0.0:
if isinstance(sequence_lengths, torch.Tensor):
return torch.ones_like(sequence_lengths, dtype=torch.float32)
else:
return np.ones_like(sequence_lengths, dtype=np.float32)

effective_lengths = sequence_lengths.copy() if isinstance(sequence_lengths, np.ndarray) else sequence_lengths.clone()

if isinstance(effective_lengths, np.ndarray):
effective_lengths = effective_lengths.astype(np.float32)
else:
effective_lengths = effective_lengths.float()

if min_length > 0:
if isinstance(effective_lengths, np.ndarray):
effective_lengths = np.maximum(effective_lengths - min_length, np.zeros_like(effective_lengths))
else:
effective_lengths = torch.maximum(effective_lengths - min_length,
torch.zeros_like(effective_lengths))

if max_length is not None:
if isinstance(effective_lengths, np.ndarray):
effective_lengths = np.minimum(effective_lengths, np.ones_like(effective_lengths) * (max_length - min_length))
else:
effective_lengths = torch.minimum(effective_lengths,
torch.ones_like(effective_lengths) * (max_length - min_length))

# Calculate penalty using the standard formula: ((5 + length) / 6) ** alpha,
# similar to the length normalization in the Google Neural Machine Translation paper
# [https://arxiv.org/pdf/1609.08144]. The same expression works for numpy arrays and torch tensors.
penalty = ((5.0 + effective_lengths) / 6.0) ** alpha

return penalty

def apply_length_penalty(rewards, sequence_lengths, **kwargs):
"""
Apply length penalty to rewards.

Args:
rewards: List, Tensor or numpy array of shape [batch_size] containing the original rewards
sequence_lengths: List, Tensor or numpy array of shape [batch_size] containing sequence lengths
**kwargs: Parameters for compute_length_penalty

Returns:
List, Tensor or numpy array of shape [batch_size] with length-penalized rewards
"""

is_tensor = isinstance(rewards, torch.Tensor)
is_list = isinstance(rewards, list)

if is_list:
rewards_np = np.array(rewards)
elif is_tensor:
rewards_np = rewards.cpu().numpy()
else:
rewards_np = rewards

if isinstance(sequence_lengths, torch.Tensor):
sequence_lengths_np = sequence_lengths.cpu().numpy()
elif isinstance(sequence_lengths, list):
sequence_lengths_np = np.array(sequence_lengths)
else:
sequence_lengths_np = sequence_lengths

length_penalties = compute_length_penalty(sequence_lengths_np, **kwargs)

penalized_rewards_np = rewards_np * length_penalties

if is_tensor:
device = rewards.device
penalized_rewards = torch.tensor(penalized_rewards_np, dtype=rewards.dtype, device=device)
elif is_list:
penalized_rewards = penalized_rewards_np.tolist()
else:
penalized_rewards = penalized_rewards_np

return penalized_rewards
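
A quick, illustrative check of the two helpers in this new module: the container type of rewards is preserved (tensor in, tensor out; list in, list out), and the sign of alpha decides which lengths are favored:

import torch

from verl.utils.length_penalty import apply_length_penalty, compute_length_penalty

print(compute_length_penalty(torch.tensor([10, 100, 1000]), alpha=0.2))
# roughly tensor([1.20, 1.77, 2.78]) -- a positive alpha boosts longer sequences

out = apply_length_penalty([1.0, 1.0], [10, 1000], alpha=-0.2)
print(out, type(out))  # roughly [0.83, 0.36], returned as a plain list -- a negative alpha favors shorter sequences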
6 changes: 4 additions & 2 deletions verl/workers/reward_manager/batch.py
@@ -17,14 +17,16 @@
import torch

from verl import DataProto

from verl.utils.length_penalty import apply_length_penalty

class BatchRewardManager:
def __init__(self, tokenizer, num_examine, compute_score, reward_fn_key="data_source", **reward_kwargs):
self.tokenizer = tokenizer
self.num_examine = num_examine
self.compute_score = compute_score
self.reward_fn_key = reward_fn_key
self.length_penalty_config = reward_kwargs.pop("length_penalty_config", {})
self.use_length_penalty = self.length_penalty_config.get("enabled", False)
self.reward_kwargs = reward_kwargs

def verify(self, data):
@@ -106,4 +108,4 @@ def __call__(self, data: DataProto, return_dict=False):
if return_dict:
return {"reward_tensor": reward_tensor, "reward_extra_info": reward_extra_info}
else:
return reward_tensor
return reward_tensor
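
A minimal construction sketch for the updated manager, showing how length_penalty_config is popped out of reward_kwargs; the scoring callable is hypothetical, and tokenizer=None works only because __init__ merely stores it (a real tokenizer is required for __call__):

from verl.workers.reward_manager.batch import BatchRewardManager

def my_batched_score(*args, **kwargs):  # hypothetical; the real batched signature may differ
    raise NotImplementedError

manager = BatchRewardManager(
    tokenizer=None,                     # placeholder, fine for this attribute-plumbing check only
    num_examine=1,
    compute_score=my_batched_score,
    length_penalty_config={"enabled": True, "alpha": -0.3, "min_length": 32},
)
print(manager.use_length_penalty)  # True
print(manager.reward_kwargs)       # {} -- length_penalty_config was popped, not forwarded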