
Modified GRPOTrainer to accumulate gradient within a single training batch #3288


Open
wants to merge 2 commits into main

Conversation


@jarrelscy jarrelscy commented Apr 13, 2025

What does this PR do?

GRPOTrainer calculates advantages and then computes the loss per completion. Currently this is all done within a single batch, which can use a lot of memory. Just as with gradient accumulation, we can call .backward() on the loss for each chunk of completions separately. This PR does so by introducing a new GRPOConfig parameter, num_generations_chunks, which num_generations must be a multiple of. loss.backward() is then called once per chunk of num_generations_chunks completions.

Example usage:

# train_grpo.py
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    logging_steps=10,
    per_device_train_batch_size=16,  # needs to be a multiple of num_generations
    num_generations=8,               # needs to be a multiple of num_generations_chunks
    num_generations_chunks=8,
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
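
For reference, here is a rough sketch of the chunked backward pass described above (the loss_fn callable and helper names are illustrative assumptions, not the actual GRPOTrainer internals): advantages are computed once over the full group, then the loss is backpropagated chunk by chunk.

import torch

def backward_in_chunks(loss_fn, model, completions, advantages, num_generations_chunks):
    # Advantages were already computed over the full group, so each chunk's
    # loss stays consistent with the whole-group normalisation.
    total = len(completions)
    assert total % num_generations_chunks == 0
    num_chunks = total // num_generations_chunks
    for start in range(0, total, num_generations_chunks):
        end = start + num_generations_chunks
        chunk_loss = loss_fn(model, completions[start:end], advantages[start:end])
        # Scale so the accumulated gradients match one backward over the full batch.
        (chunk_loss / num_chunks).backward()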

Fixes #3017

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@qgallouedec
Member

Thanks @jarrelscy

I understand the motivation. Just for clarification, if

Just like with gradient accumulation, we can call .backwards on the loss for each completion separately.

then why not use gradient accumulation instead? Is it because generation would also be done in smaller batches, which would make things slower?

@jarrelscy
Author

Hi @qgallouedec, as @JamesBowerXanda pointed out here, the quality of the loss depends on the group size. In this paper they point out that you need a large group size to approximate the expected reward normalised by the standard deviation of the reward of an output sampled from the previous policy.

In GRPO each generation is assigned a relative advantage against other generations, so if the group size is small, this can lead to erratic losses.

With gradient accumulation (per batch), each generation's advantage is still only compared against the other generations within that smaller batch, so the effective group size shrinks.
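
To make the dependence on group size concrete, the group-relative advantage is roughly computed as below (a simplified sketch; the epsilon and tensor shapes are assumptions rather than the trainer's exact code):

import torch

def group_relative_advantages(rewards: torch.Tensor, num_generations: int) -> torch.Tensor:
    # rewards: shape (num_prompts * num_generations,), arranged so each
    # consecutive block of num_generations completions shares one prompt.
    grouped = rewards.view(-1, num_generations)
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True)
    # Small groups make mean/std noisy estimates, hence erratic advantages.
    return ((grouped - mean) / (std + 1e-4)).reshape(-1)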

@qgallouedec
Member

qgallouedec commented Apr 21, 2025

FYI, now you can pass a group as large as gradient_accumulation * per_device_batch_size * num_devices thanks to #3283
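
For illustration (numbers chosen arbitrarily): with gradient_accumulation=4, per_device_batch_size=8 and num_devices=2, the group can contain up to 4 * 8 * 2 = 64 completions.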
