Feature: Implementing SFT mixing with PPO #525

Open · wants to merge 1 commit into main

62 changes: 62 additions & 0 deletions examples/mixed_ppo_sft_sentiments.py
@@ -0,0 +1,62 @@
# Generates positive movie reviews by tuning a pretrained model on IMDB dataset
# with a sentiment reward function
import json
import os
import sys
from typing import List

import torch
from datasets import load_dataset
from transformers import pipeline

import trlx
from trlx.data.default_configs import TRLConfig, default_ppo_config


def get_positive_score(scores):
    "Extract value associated with a positive sentiment from pipeline's output"
    return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]


def main(hparams={}):
    # Merge sweep config with default config if given
    config = TRLConfig.update(default_ppo_config().to_dict(), hparams)
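    # SFT mixing, added in this PR. Assumed semantics: after every `rollouts_per_sft`
    # rollouts, run `sft_sample_updates` supervised updates on the ground-truth
    # continuations supplied via `samples` below.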
    config.method.rollouts_per_sft = 256
    config.method.sft_sample_updates = 32
    config.train.seq_length = 128

    if torch.cuda.is_available():
        device = int(os.environ.get("LOCAL_RANK", 0))
    else:
        device = -1

    sentiment_fn = pipeline(
        "sentiment-analysis",
        "lvwerra/distilbert-imdb",
        top_k=2,
        truncation=True,
        batch_size=256,
        device=device,
    )

    def reward_fn(samples: List[str], **kwargs) -> List[float]:
        sentiments = list(map(get_positive_score, sentiment_fn(samples)))
        return sentiments

    # Take the first few words of each movie review as a prompt
    imdb = load_dataset("imdb", split="train+test")
    prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]
    # Each sample pairs a prompt with the remainder of its review, giving the ground-truth
    # continuation that the interleaved SFT updates are assumed to train on
    samples = [[" ".join(review.split()[:4]), " ".join(review.split()[4:])] for review in imdb["text"]]

    trlx.train(
        reward_fn=reward_fn,
        prompts=prompts,
        samples=samples,
        eval_prompts=["I don't know much about Hungarian underground"] * 256,
        config=config,
    )


if __name__ == "__main__":
    hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
    main(hparams)
10 changes: 9 additions & 1 deletion trlx/data/accelerate_base_datatypes.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Iterable
from typing import Iterable, Optional

from torchtyping import TensorType

@@ -14,10 +14,14 @@ class PromptElement:

    :param tokens: The prompt tokens. Should be a long tensor
    :type tokens: torch.Tensor

    :param gt_response_tokens: The ground truth response tokens. Should be a long tensor.
    :type gt_response_tokens: torch.Tensor
    """

    text: str
    tokens: TensorType["num_tokens"]
    gt_response_tokens: Optional[TensorType["response_length"]] = None


@dataclass
@@ -30,10 +34,14 @@ class PromptBatch:

    :param tokens: A long tensor batch of prompt tokens.
    :type tokens: torch.Tensor

    :param gt_response_tokens: A batch of ground truth response tokens. Should be a long tensor.
    :type gt_response_tokens: torch.Tensor
    """

    text: Iterable[str]
    tokens: TensorType["batch_size", "num_tokens"]
    gt_response_tokens: Optional[TensorType["batch_size", "response_length"]]


@dataclass
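For reference, a minimal illustration (not part of this diff, and assuming this branch is installed) of how the new optional field could be populated:

import torch
from trlx.data.accelerate_base_datatypes import PromptElement

# Toy token ids; real values would come from the model's tokenizer.
element = PromptElement(
    text="I don't know much about",
    tokens=torch.tensor([40, 836, 470, 760, 881, 546], dtype=torch.long),
    gt_response_tokens=torch.tensor([2106, 13], dtype=torch.long),  # ground-truth continuation
)
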
9 changes: 4 additions & 5 deletions trlx/data/ppo_types.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass

from typing import Optional
from torchtyping import TensorType


@@ -34,7 +34,6 @@ class PPORLElement:
    values: TensorType["response_size"]
    rewards: TensorType["response_size"]


@dataclass
class PPORLBatch:
    """
@@ -46,14 +45,14 @@ class PPORLBatch:
    :param response_tensors: A batch of response tensors. Should be a long tensor.
    :type response_tensors: torch.Tensor

    :param gt_response_tensors: A batch of tensors corresponding to the ground truth responses. Should be a long tensor.
    :type gt_response_tensors: torch.Tensor

    :param logprobs: A batch of log probabilities from policy
    :type logprobs: torch.Tensor

    :param values: A batch of values from value network
    :type values: torch.Tensor

    :param rewards: A batch of rewards
    :type rewards: torch.Tensor
    """

    query_tensors: TensorType["batch_size", "query_size"]
5 changes: 5 additions & 0 deletions trlx/models/modeling_ppo.py
@@ -112,6 +112,9 @@ class PPOConfig(MethodConfig):

    :param gen_experience_kwargs: if this is not None, then the experience is generated using this
    :type gen_experience_kwargs: Dict[str, Any]

    :param rollouts_per_sft: if set to a positive value, SFT gradients will be mixed into PPO
        training, with an SFT phase run after every `rollouts_per_sft` rollouts
    :type rollouts_per_sft: int

    :param sft_sample_updates: number of SFT updates to run during each SFT phase
    :type sft_sample_updates: Optional[int]
    """

    ppo_epochs: int
@@ -131,6 +134,8 @@ class PPOConfig(MethodConfig):
    cliprange_reward: float
    gen_kwargs: dict
    gen_experience_kwargs: Optional[dict] = None
    rollouts_per_sft: int = -1
    sft_sample_updates: Optional[int] = None

    def get_advantages_and_returns(
        self,
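To make the intended use of the two new config fields concrete, here is a standalone sketch of the assumed interleaving schedule; the real wiring lives in the PPO trainer and is not shown in this hunk, and `run_ppo_update`/`run_sft_update` are placeholders:

ROLLOUTS_PER_SFT = 256    # assumed: rollouts gathered between SFT phases (-1 disables mixing)
SFT_SAMPLE_UPDATES = 32   # assumed: supervised updates per SFT phase


def run_ppo_update(rollout):
    pass  # placeholder for a PPO optimization step


def run_sft_update(gt_batch):
    pass  # placeholder for a supervised step on ground-truth responses


rollouts_since_sft = 0
for rollout in range(1024):  # pretend each iteration yields one rollout
    run_ppo_update(rollout)
    rollouts_since_sft += 1
    if ROLLOUTS_PER_SFT > 0 and rollouts_since_sft >= ROLLOUTS_PER_SFT:
        for _ in range(SFT_SAMPLE_UPDATES):
            run_sft_update(rollout)
        rollouts_since_sft = 0
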
2 changes: 0 additions & 2 deletions trlx/pipeline/offline_pipeline.py
@@ -158,11 +158,9 @@ def __len__(self) -> int:
    def create_loader(self, batch_size: int, shuffle=False, sampler=None, drop_last=False) -> DataLoader:
        def collate_fn(xs):
            out = self.tokenizer.pad([{"input_ids": x["input_ids"]} for x in xs], return_tensors="pt")

            for key in xs[0]:
                if key != "input_ids" and key != "attention_mask":
                    out[key] = [x[key] for x in xs]

            return out

        # Since all data is already pre-processed, no need to have
7 changes: 7 additions & 0 deletions trlx/trainer/accelerate_base_trainer.py
@@ -141,6 +141,9 @@ def __init__(self, config, **kwargs): # noqa: C901
                else:
                    self.generate_sweep_kwarg = (k, v)

        # Allows for flexible breaking of inner train loop
        self.break_train = False

    def setup_model(self):
        """
        Returns a model derived from an instance's TRLConfig
@@ -630,6 +633,10 @@ def learn(self): # noqa: C901

                self.post_backward_callback()

                if self.break_train:
                    self.break_train = False
                    break

            self.post_epoch_callback()
        tbar.close()
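
As an illustration of how the new flag could be used (hypothetical subclass, not part of this diff): a callback sets `self.break_train`, and `learn()` clears the flag and leaves the inner train loop.

from trlx.trainer.accelerate_ppo_trainer import AcceleratePPOTrainer


class SFTMixingTrainer(AcceleratePPOTrainer):  # illustrative name only
    def post_backward_callback(self):
        super().post_backward_callback()
        # Hypothetical condition: enough rollouts processed since the last SFT phase.
        if getattr(self, "ready_for_sft", False):
            self.break_train = True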
