diff --git a/examples/mixed_ppo_sft_sentiments.py b/examples/mixed_ppo_sft_sentiments.py
new file mode 100644
index 000000000..5d2a30b87
--- /dev/null
+++ b/examples/mixed_ppo_sft_sentiments.py
@@ -0,0 +1,62 @@
+# Generates positive movie reviews by tuning a pretrained model on the IMDB dataset
+# with a sentiment reward function
+import json
+import os
+import sys
+from typing import List
+
+import torch
+from datasets import load_dataset
+from transformers import pipeline
+
+import trlx
+from trlx.data.default_configs import TRLConfig, default_ppo_config
+
+
+def get_positive_score(scores):
+    "Extract value associated with a positive sentiment from pipeline's output"
+    return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]
+
+
+def main(hparams={}):
+    # Merge sweep config with default config if given
+    config = TRLConfig.update(default_ppo_config().to_dict(), hparams)
+    config.method.rollouts_per_sft = 256
+    config.method.sft_sample_updates = 32
+    config.train.seq_length = 128
+
+    if torch.cuda.is_available():
+        device = int(os.environ.get("LOCAL_RANK", 0))
+    else:
+        device = -1
+
+    sentiment_fn = pipeline(
+        "sentiment-analysis",
+        "lvwerra/distilbert-imdb",
+        top_k=2,
+        truncation=True,
+        batch_size=256,
+        device=device,
+    )
+
+    def reward_fn(samples: List[str], **kwargs) -> List[float]:
+        sentiments = list(map(get_positive_score, sentiment_fn(samples)))
+        return sentiments
+
+    # Take a few words off of movie reviews as prompts
+    imdb = load_dataset("imdb", split="train+test")
+    prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]
+    samples = [[" ".join(review.split()[:4]), " ".join(review.split()[4:])] for review in imdb["text"]]
+
+    trlx.train(
+        reward_fn=reward_fn,
+        prompts=prompts,
+        samples=samples,
+        eval_prompts=["I don't know much about Hungarian underground"] * 256,
+        config=config,
+    )
+
+
+if __name__ == "__main__":
+    hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
+    main(hparams)
diff --git a/trlx/data/accelerate_base_datatypes.py b/trlx/data/accelerate_base_datatypes.py
index 838567e0e..8e9896ea9 100644
--- a/trlx/data/accelerate_base_datatypes.py
+++ b/trlx/data/accelerate_base_datatypes.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Iterable
+from typing import Iterable, Optional
 
 from torchtyping import TensorType
 
@@ -14,10 +14,14 @@ class PromptElement:
 
     :param tokens: The prompt tokens. Should be a long tensor
     :type tokens: torch.Tensor
+
+    :param gt_response_tokens: The ground truth response tokens. Should be a long tensor.
+    :type gt_response_tokens: torch.Tensor
     """
 
     text: str
     tokens: TensorType["num_tokens"]
+    gt_response_tokens: Optional[TensorType["response_length"]] = None
 
 
 @dataclass
@@ -30,10 +34,14 @@ class PromptBatch:
 
     :param tokens: A long tensor batch of prompt tokens.
     :type tokens: torch.Tensor
+
+    :param gt_response_tokens: The ground truth response tokens. Should be a long tensor.
+    :type gt_response_tokens: torch.Tensor
     """
 
     text: Iterable[str]
     tokens: TensorType["batch_size", "num_tokens"]
+    gt_response_tokens: Optional[TensorType["batch_size", "response_length"]]
 
 
 @dataclass
diff --git a/trlx/data/ppo_types.py b/trlx/data/ppo_types.py
index 375d7d3ab..99507b14a 100644
--- a/trlx/data/ppo_types.py
+++ b/trlx/data/ppo_types.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-
+from typing import Optional
 from torchtyping import TensorType
 
 
@@ -34,7 +34,6 @@ class PPORLElement:
     values: TensorType["response_size"]
     rewards: TensorType["response_size"]
 
-
 @dataclass
 class PPORLBatch:
     """
@@ -46,14 +45,14 @@ class PPORLBatch:
     :param response_tensors: A batch of response tensors. Should be a long tensor.
     :type response_tensors: torch.Tensor
 
+    :param gt_response_tensors: A batch of tensors corresponding to the ground truth responses. Should be a long tensor.
+    :type gt_response_tensors: torch.Tensor
+
     :param logprobs: A batch of log probabilities from policy
     :type logprobs: torch.Tensor
 
     :param values: A batch of values from value network
     :type values: torch.Tensor
-
-    :param rewards: A batch of rewards
-    :type rewards: torch.Tensor
     """
 
     query_tensors: TensorType["batch_size", "query_size"]
diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py
index 82d3ec637..eceff3188 100644
--- a/trlx/models/modeling_ppo.py
+++ b/trlx/models/modeling_ppo.py
@@ -112,6 +112,9 @@ class PPOConfig(MethodConfig):
 
     :param gen_experience_kwargs: if this is not None, then the experience is generated using this
     :type gen_experience_kwargs: Dict[str, Any]
+
+    :param mix_sft: if this is True, then SFT gradients will be mixed into PPO training
+    :type mix_sft: bool
     """
 
     ppo_epochs: int
@@ -131,6 +134,8 @@ class PPOConfig(MethodConfig):
     cliprange_reward: float
     gen_kwargs: dict
     gen_experience_kwargs: Optional[dict] = None
+    rollouts_per_sft: int = -1
+    sft_sample_updates: Optional[int] = None
 
     def get_advantages_and_returns(
         self,
diff --git a/trlx/pipeline/offline_pipeline.py b/trlx/pipeline/offline_pipeline.py
index cee900cfc..bbe4df547 100644
--- a/trlx/pipeline/offline_pipeline.py
+++ b/trlx/pipeline/offline_pipeline.py
@@ -158,11 +158,9 @@ def __len__(self) -> int:
     def create_loader(self, batch_size: int, shuffle=False, sampler=None, drop_last=False) -> DataLoader:
         def collate_fn(xs):
             out = self.tokenizer.pad([{"input_ids": x["input_ids"]} for x in xs], return_tensors="pt")
-
             for key in xs[0]:
                 if key != "input_ids" and key != "attention_mask":
                     out[key] = [x[key] for x in xs]
-
             return out
 
         # Since all data is already pre-processed, no need to have
diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index 5c82335c0..0ec1110ed 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -141,6 +141,9 @@ def __init__(self, config, **kwargs): # noqa: C901
                 else:
                     self.generate_sweep_kwarg = (k, v)
 
+        # Allows for flexible breaking of inner train loop
+        self.break_train = False
+
     def setup_model(self):
         """
         Returns a model derived from an instance's TRLConfig
@@ -630,6 +633,10 @@ def learn(self): # noqa: C901
 
                 self.post_backward_callback()
 
+                if self.break_train:
+                    self.break_train = False
+                    break
+
             self.post_epoch_callback()
 
         tbar.close()
diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py
index a3af9aa3f..38b2802f5 100644
--- a/trlx/trainer/accelerate_ppo_trainer.py
+++ b/trlx/trainer/accelerate_ppo_trainer.py
@@ -20,7 +20,7 @@
     AutoModelForSeq2SeqLMWithHydraValueHead,
     FixedKLController,
 )
-from trlx.pipeline.offline_pipeline import PromptPipeline
+from trlx.pipeline.offline_pipeline import PromptPipeline, DialogStore
 from trlx.pipeline.ppo_pipeline import PPORolloutStorage
 from trlx.trainer import register_trainer
 from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer
@@ -101,6 +101,12 @@ def __init__(self, config: TRLConfig, **kwargs):
         self.ref_mean = self.config.method.ref_mean
         self.ref_std = self.config.method.ref_std
 
+        # Set training mode conditions
+        self.num_sampled_rollouts = 0
+        self.rollouts_per_sft = self.config.method.rollouts_per_sft
+        self.mix_sft = self.config.method.rollouts_per_sft > 0
+        self.sft = False
+
     def get_arch(self, config: TRLConfig):
         """Get the model"""
         model_class = AutoModelForCausalLMWithHydraValueHead
@@ -124,72 +130,83 @@ def loss(self, batch: PPORLBatch):
         Args:
             batch: Previous batch of episodes
         """
-        # Move `batch` data to `accelerator` device
-        query_tensors = batch.query_tensors.to(self.accelerator.device)
-        response_tensors = batch.response_tensors.to(self.accelerator.device)
-        old_logprobs = batch.logprobs.to(self.accelerator.device)
-        old_values = batch.values.to(self.accelerator.device)
-        old_rewards = batch.rewards.to(self.accelerator.device)
-        response_length = old_rewards.shape[1]
-
-        advantages, returns = self.config.method.get_advantages_and_returns(old_values, old_rewards, response_length)
-
-        if self.config.model.model_arch_type == "seq2seq":
-            input_ids = query_tensors
-            decoder_input_ids = response_tensors
-            attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device)
-            decoder_attention_mask = (
-                decoder_input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device)
-            )
-            decoder_attention_mask[:, 0] = 1
-
-            # Forward pass
-            outputs = self.model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                decoder_input_ids=decoder_input_ids,
-                decoder_attention_mask=decoder_attention_mask,
-            )
+        # Case on type of loss
+        if not self.sft:
+            # Move `batch` data to `accelerator` device
+            query_tensors = batch.query_tensors.to(self.accelerator.device)
+            response_tensors = batch.response_tensors.to(self.accelerator.device)
+            old_logprobs = batch.logprobs.to(self.accelerator.device)
+            old_values = batch.values.to(self.accelerator.device)
+            old_rewards = batch.rewards.to(self.accelerator.device)
+            response_length = old_rewards.shape[1]
+
+            advantages, returns = self.config.method.get_advantages_and_returns(old_values, old_rewards, response_length)
+
+            if self.config.model.model_arch_type == "seq2seq":
+                input_ids = query_tensors
+                decoder_input_ids = response_tensors
+                attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device)
+                decoder_attention_mask = (
+                    decoder_input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device)
+                )
+                decoder_attention_mask[:, 0] = 1
+
+                # Forward pass
+                outputs = self.model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    decoder_input_ids=decoder_input_ids,
+                    decoder_attention_mask=decoder_attention_mask,
+                )
 
-            logits = outputs.logits
-            values_pred = outputs.value
-            logprobs = logprobs_of_labels(logits[:, :-1, :], decoder_input_ids[:, 1:])
-            mask = decoder_input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device)
-            start = 0
-            end = start + response_length
-            logprobs, values_pred, mask = (
-                logprobs[:, start:end],
-                values_pred[:, start:end],
-                mask[:, start + 1 : end + 1],
+                logits = outputs.logits
+                values_pred = outputs.value
+                logprobs = logprobs_of_labels(logits[:, :-1, :], decoder_input_ids[:, 1:])
+                mask = decoder_input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device)
+                start = 0
+                end = start + response_length
+                logprobs, values_pred, mask = (
+                    logprobs[:, start:end],
+                    values_pred[:, start:end],
+                    mask[:, start:end],
+                )
+            else:
+                tokens = torch.cat((query_tensors, response_tensors), dim=1)
+                attention_mask = tokens.not_equal(self.tokenizer.pad_token_id).long().to(tokens.device)
+                outputs = self.model(tokens, attention_mask, return_dict=True)
+                logits = outputs.logits
+                values_pred = outputs.value
+                values_pred = values_pred[:, :-1]
+                logprobs = logprobs_of_labels(logits[:, :-1, :], tokens[:, 1:])
+
+                start = query_tensors.shape[1] - 1
+                end = start + response_length
+                logprobs, values_pred, mask = (
+                    logprobs[:, start:end],
+                    values_pred[:, start:end],
+                    attention_mask[:, start:end],
+                )
+
+            loss, stats = self.config.method.loss(
+                logprobs=logprobs,
+                values=values_pred,
+                old_logprobs=old_logprobs,
+                old_values=old_values,
+                advantages=advantages,
+                returns=returns,
+                mask=mask,
             )
         else:
-            tokens = torch.cat((query_tensors, response_tensors), dim=1)
-            attention_mask = tokens.not_equal(self.tokenizer.pad_token_id).long().to(tokens.device)
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            outputs = self.model(tokens, attention_mask, return_dict=True, position_ids=position_ids)
-            logits = outputs.logits
-            values_pred = outputs.value
-            values_pred = values_pred[:, :-1]
-            logprobs = logprobs_of_labels(logits[:, :-1, :], tokens[:, 1:])
-
-            start = query_tensors.shape[1] - 1
-            end = start + response_length
-            logprobs, values_pred, mask = (
-                logprobs[:, start:end],
-                values_pred[:, start:end],
-                attention_mask[:, start + 1 : end + 1],
-            )
+            # Trainer in sft mode
+            if "labels" in batch:
+                labels = batch.labels.clone()
+            else:
+                labels = batch.input_ids.clone()
+                labels[~batch.attention_mask.bool()] = -100
 
-        loss, stats = self.config.method.loss(
-            logprobs=logprobs,
-            values=values_pred,
-            old_logprobs=old_logprobs,
-            old_values=old_values,
-            advantages=advantages,
-            returns=returns,
-            mask=mask,
-        )
+            # TODO(dahoas): Fix. This may break with zero3.
+            loss = self.model.base_model(input_ids=batch.input_ids, attention_mask=batch.attention_mask, labels=labels).loss
+            stats = {"sft_loss": loss.item()}
 
         return loss, stats
 
@@ -211,22 +228,44 @@ def post_epoch_callback(self):
 
         Clears the store and creates `num_rollouts` new episodes.
""" + # Log and clear rollouts if self.log_rollouts: self.store.export_history(location=self.rollout_logging_dir) self.store.clear_history() - # Collect more rollouts for training - self.make_experience(self.config.method.num_rollouts, self.iter_count) + + # Case on whether in SFT mode or RL mode + if self.mix_sft and not self.sft and self.num_sampled_rollouts % self.rollouts_per_sft == 0 and self.num_sampled_rollouts > 0: + logger.info("Mixing in SFT grads") + self.sft = True + self.cur_sft_updates = 0 + self.n_updates_per_batch = 1 + self.train_dataloader = self.sft_dataloader + else: + # Collect more rollouts for RL training + self.sft = False + self.n_updates_per_batch = self.config.method.ppo_epochs + self.train_dataloader = self.ppo_dataloader + self.make_experience(self.config.method.num_rollouts, self.iter_count) def post_backward_callback(self): - self.kl_ctl.update(self.mean_kl, n_steps=self.config.train.batch_size) + if self.sft: + self.cur_sft_updates += self.config.train.batch_size + if self.cur_sft_updates >= self.config.method.sft_sample_updates: + self.break_train = True + else: + self.kl_ctl.update(self.mean_kl, n_steps=self.config.train.batch_size) def prepare_learning(self): eval_dataloader = self.eval_pipeline.create_loader(self.config.method.chunk_size) self.eval_dataloader = self.accelerator.prepare_data_loader(eval_dataloader) - self.make_experience(self.config.method.num_rollouts) + self.ppo_dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=True) + self.train_dataloader = self.ppo_dataloader + if self.mix_sft: + sft_dataloader = self.sft_store.create_loader(self.config.train.batch_size) + self.sft_dataloader = self.accelerator.prepare_data_loader(sft_dataloader) - self.train_dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=False) + self.make_experience(self.config.method.num_rollouts) self.n_updates_per_batch = self.config.method.ppo_epochs self.total_steps = self.config.train.epochs * self.n_updates_per_batch * len(self.train_dataloader) @@ -238,6 +277,10 @@ def add_prompt_pipeline(self, pipeline: PromptPipeline): prompt_dataloader = self.accelerator.prepare_data_loader(prompt_dataloader) self.prompt_iterator = infinite_dataloader(prompt_dataloader) + def add_sft_store(self, store: DialogStore): + """Add a DialogStore as an SFT store if mixing in SFT gradients with RL gradients""" + self.sft_store = store + def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noqa: """Make experiences @@ -482,6 +525,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq stats["kl_ctl_value"] = self.kl_ctl.value self.mean_kl = stats["policy/sqrt_kl"] ** 2 self.accelerator.log(stats, step=iter_count) + self.num_sampled_rollouts += num_rollouts # Push samples and rewards to trainer's rollout storage self.push_to_store(ppo_rl_elements) diff --git a/trlx/trlx.py b/trlx/trlx.py index 7fbce94f4..9297fd316 100644 --- a/trlx/trlx.py +++ b/trlx/trlx.py @@ -8,6 +8,7 @@ default_ppo_config, default_sft_config, ) +from trlx.pipeline.offline_pipeline import DialogStore, tokenize_dialogue from trlx.utils import set_seed from trlx.utils.loading import get_pipeline, get_trainer @@ -99,6 +100,12 @@ def train( # noqa: C901 ) trainer.add_prompt_pipeline(pipeline) + # Add sft pipeline if mixing in sft gradients during RL + if config.method.rollouts_per_sft > 0: + dialogs = [tokenize_dialogue(d, trainer.tokenizer, config.train.seq_length) for d in samples] + sft_store = DialogStore(dialogs, 
trainer.tokenizer) + trainer.add_sft_store(sft_store) + if eval_prompts is None: eval_prompts = prompts[:batch_size]
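
A minimal usage sketch (not part of the diff), distilled from examples/mixed_ppo_sft_sentiments.py above: it shows how the new rollouts_per_sft and sft_sample_updates fields are driven together with ground-truth [prompt, response] pairs passed via samples. The toy prompts, the placeholder reward function, and the specific numbers here are illustrative assumptions only.

from typing import List

import trlx
from trlx.data.default_configs import TRLConfig, default_ppo_config

# Illustrative settings; a real run would mirror the example script above.
config = TRLConfig.update(default_ppo_config().to_dict(), {})
config.method.rollouts_per_sft = 256   # switch to an SFT phase every 256 sampled rollouts
config.method.sft_sample_updates = 32  # approximate number of SFT samples consumed before resuming PPO
config.train.seq_length = 128

prompts = ["The movie was"] * 64                            # illustrative prompts
samples = [["The movie was", "a pleasant surprise."]] * 64  # [prompt, ground-truth response] pairs


def reward_fn(samples: List[str], **kwargs) -> List[float]:
    # Placeholder reward for the sketch; the example above scores samples with a sentiment pipeline.
    return [float(len(s)) for s in samples]


trlx.train(reward_fn=reward_fn, prompts=prompts, samples=samples, config=config)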