diff --git a/examples/mrt_translation_t5.py b/examples/mrt_translation_t5.py new file mode 100644 index 000000000..da9fff647 --- /dev/null +++ b/examples/mrt_translation_t5.py @@ -0,0 +1,210 @@ +"""Example of using PPO to train a T5 model for translation. +Based on examples/summarize_daily_cnn/t5_summarize_daily_cnn.py""" + +import json +import os +import sys +from typing import List + +import torch +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer + +import trlx +from trlx.data.configs import ( + ModelConfig, + OptimizerConfig, + SchedulerConfig, + TokenizerConfig, + TrainConfig, + TRLConfig, +) +from trlx.models.modeling_mrt import MRTConfig + +try: + import comet + import evaluate + + if comet.__version__ != "1.1.3": + raise ImportError +except ImportError: + raise ImportError( + "To run this example, please install `evaluate`, `nltk` and `comet==1.1.3` packages by " + "running `pip install evaluate unbabel-comet==1.1.3`" + ) + + +default_config = TRLConfig( + train=TrainConfig( + seq_length=612, + epochs=100, + total_steps=100000, + batch_size=4, + checkpoint_interval=10000, + eval_interval=64, + pipeline="PromptPipeline", + trainer="AccelerateMRTTrainer", + # tracker=None + tracker="wandb", + ), + model=ModelConfig( + model_path="t5-small", + model_arch_type="seq2seq", + num_layers_unfrozen=-1, + ), + tokenizer=TokenizerConfig( + tokenizer_path="t5-small", + truncation_side="right", + ), + optimizer=OptimizerConfig( + name="adamw", + kwargs={ + "lr": 2.0e-6, + "betas": [0.9, 0.999], + "eps": 1.0e-8, + "weight_decay": 1.0e-6, + }, + ), + scheduler=SchedulerConfig( + name="cosine_annealing", + kwargs={ + "T_max": 10000, + "eta_min": 1.0e-6, + }, + ), + method=MRTConfig( + name="MRTConfig", + num_rollouts=512, + chunk_size=4, + num_candidates=16, + ce_loss_weight=0.0, + scale_reward=None, + ref_mean=None, + ref_std=None, + cliprange_reward=10, + gen_kwargs={ # for evaluation + "max_new_tokens": 100, + # TODO: what should the defaults here be + }, + gen_experience_kwargs={ # for rollouts + "max_new_tokens": 100, + "num_beams": 16, # should be same as nb_candidates + "num_return_sequences": 16, # should be same as nb_candidates + "do_sample": False, + "temperature": 1.0, + # "top_k": 50, + # "top_p": 0.95, + }, + ), +) + + +def main(hparams={}): + config = TRLConfig.update(default_config, hparams) + + # COMET is the metric we are optimizng for + comet_metric = evaluate.load("comet", "wmt20-comet-da", progress_bar=False) + bleu_metric = evaluate.load("bleu") + chrf_metric = evaluate.load("chrf") + + os.environ["TOKENIZERS_PARALLELISM"] = "false" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + + def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]) -> List[float]: + original_sents = [translation_map[prompt.strip()] for prompt in prompts] + + scores = comet_metric.compute( + predictions=[output.strip() for output in outputs], + references=[original["tgt"] for original in original_sents], + sources=[original["src"] for original in original_sents], + )["scores"] + + # TODO: This is needed since there seems to be a bug in the comet metric + # that changes torch's determinism setting. Remove this once the bug is fixed. 
+ torch.use_deterministic_algorithms(False, warn_only=True) + return scores + + def metric_fn(samples: List[str], prompts: List[str], outputs: List[str]) -> List[float]: + """Compute COMET, BLEU and CHRF for evaluation""" + original_sents = [translation_map[prompt.strip()] for prompt in prompts] + + comet_score = comet_metric.compute( + predictions=[output.strip() for output in outputs], + references=[original["tgt"] for original in original_sents], + sources=[original["src"] for original in original_sents], + )["mean_score"] + + bleu_score = bleu_metric.compute( + predictions=[output.strip() for output in outputs], + references=[original["tgt"] for original in original_sents], + )["bleu"] + + chrf_score = chrf_metric.compute( + predictions=[output.strip() for output in outputs], + references=[original["tgt"] for original in original_sents], + )["score"] + + # TODO: This is needed since there seems to be a bug in the comet metric + # that changes torch's determinism setting. Remove this once the bug is fixed. + # Same issue as in `reward_fn` + torch.use_deterministic_algorithms(False, warn_only=True) + + # For corpus-level metrics, it's better to ignore the sentence-level scores + return {"bleu": bleu_score, "chrf": chrf_score, "comet": comet_score} + + # The WMT16 is large so we can benefit with using it as a streaming dataset + train_dataset = load_dataset("wmt16", "de-en", split="train", streaming=True) + valid_dataset = load_dataset("wmt16", "de-en", split="validation", streaming=True) + + src_lang = "en" + tgt_lang = "de" + PREFIX = "translate English to German: " + + # take 20,000 samples from the training set as prompts for training + # TODO: update to 20k + original_src_dataset = [sent_pair["translation"][src_lang] for sent_pair in train_dataset.take(1200)] + tgt_dataset = [sent_pair["translation"][tgt_lang] for sent_pair in train_dataset.take(1200)] + src_dataset = [PREFIX + src_sent for src_sent in original_src_dataset] + + # take 1,000 samples from the validation set as prompts for evaluation + val_original_src_dataset = [sent_pair["translation"][src_lang] for sent_pair in valid_dataset.take(1000)] + val_tgt_dataset = [sent_pair["translation"][tgt_lang] for sent_pair in valid_dataset.take(1000)] + val_src_dataset = [PREFIX + src_sent for src_sent in val_original_src_dataset] + + # make dictionary of prompts and labels to use for reward function + tokenizer = AutoTokenizer.from_pretrained(config.model.model_path) + tokenizer.padding_side = "left" + tokenizer.truncation_side = "right" + tokenizer.sep_token = "" + max_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"] + translation_map = {} + + for i in tqdm(range(len(original_src_dataset))): + key = tokenizer.decode( + tokenizer(src_dataset[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"], + skip_special_tokens=True, + ) # get prompt like trlx's prompt + translation_map[key.strip()] = {"src": original_src_dataset[i], "tgt": tgt_dataset[i]} + + for i in tqdm(range(len(val_original_src_dataset))): + key = tokenizer.decode( + tokenizer(val_src_dataset[i], truncation=True, max_length=max_length, add_special_tokens=False)[ + "input_ids" + ], + skip_special_tokens=True, + ) # get prompt like trlx's prompt + translation_map[key.strip()] = {"src": val_original_src_dataset[i], "tgt": val_tgt_dataset[i]} + + trlx.train( + reward_fn=reward_fn, + metric_fn=metric_fn, + prompts=src_dataset, + eval_prompts=val_src_dataset, + config=config, + ) + + +if __name__ == "__main__": + hparams 
= {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) + main(hparams) diff --git a/examples/summarize_daily_cnn/debug_mrt_summarize_daily_cnn_t5.py b/examples/summarize_daily_cnn/debug_mrt_summarize_daily_cnn_t5.py new file mode 100644 index 000000000..ab9c8e1d8 --- /dev/null +++ b/examples/summarize_daily_cnn/debug_mrt_summarize_daily_cnn_t5.py @@ -0,0 +1,174 @@ +# DO NOT REVIEW, WILL BE DELETED +from typing import List + +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer + +import trlx +from trlx.data.configs import ( + ModelConfig, + OptimizerConfig, + SchedulerConfig, + TokenizerConfig, + TrainConfig, + TRLConfig, +) +from trlx.models.modeling_mrt import MRTConfig + +try: + import evaluate +except ImportError: + raise ImportError( + "To run this example, please install the `evaluate` and `nltk` packages" "by running `pip install evaluate`" + ) + +config = TRLConfig( + train=TrainConfig( + seq_length=612, + epochs=100, + total_steps=100000, + batch_size=4, + checkpoint_interval=10000, + eval_interval=100, + pipeline="PromptPipeline", + trainer="AccelerateMRTTrainer", + tracker=None + # tracker="wandb", + ), + model=ModelConfig( + model_path="google/flan-t5-small", + model_arch_type="seq2seq", + num_layers_unfrozen=2, + ), + tokenizer=TokenizerConfig( + tokenizer_path="google/flan-t5-small", # change to reasonable value + truncation_side="right", # what is this? + ), + optimizer=OptimizerConfig( + name="adamw", + kwargs={ + "lr": 1.0e-5, + "betas": [0.9, 0.999], + "eps": 1.0e-8, + "weight_decay": 1.0e-6, + }, + ), + scheduler=SchedulerConfig( + name="cosine_annealing", + kwargs={ + "T_max": 10000, + "eta_min": 1.0e-6, + }, + ), + method=MRTConfig( + name="MRTConfig", + # n_updates_per_batch=1, #### MRT + num_rollouts=512, + chunk_size=4, + ppo_epochs=1, + # init_kl_coef=0.05, + # target=6, + # horizon=10000, + # gamma=0.99, + # lam=0.95, + # cliprange=0.2, + # cliprange_value=0.2, + # vf_coef=1.0, + num_candidates=16, + ce_loss_weight=0.0, + scale_reward=None, + ref_mean=None, + ref_std=None, + cliprange_reward=10, + gen_kwargs={ # for evaluation + "max_new_tokens": 100, + # TODO: what should the defaults here be + }, + gen_experience_kwargs={ # for rollouts + "max_new_tokens": 100, + "num_beams": 16, # should be same as nb_candidates + "num_return_sequences": 16, # should be same as nb_candidates + "do_sample": False, + "temperature": 1.0, + # "top_k": 50, + # "top_p": 0.95, + }, + ), +) + +# gen_kwargs = { +# "min_length":-1, +# "top_k": config['top_k'], +# "top_p": 1.0, +# "temperature": config["temperature"], +# "do_sample": config['do_sample'], +# "num_beams": config['num_beams'], +# "max_length": config['max_length'], +# # "pad_token_id": model.eos_token_id, +# "num_return_sequences": config['candidate_size'], +# } +# eval_kwargs = { +# # "early_stopping": True, +# # "length_penalty": 2.0, +# "min_length":-1, +# "top_k": 0.0, +# # "top_p": 1.0, +# "do_sample": False, +# "num_beams": config['eval_num_beams'], +# # "no_repeat_ngram_size": 3, +# "max_length": config['max_length'], +# } + + +meteor = evaluate.load("meteor") # use meteor as the reward function + +if __name__ == "__main__": + + def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]): + original_summaries = [prompt_label[prompt.strip()] for prompt in prompts] + scores = [ + meteor.compute(predictions=[output.strip()], references=[original])["meteor"] + for (original, output) in zip(original_summaries, outputs) + ] + return scores + + dataset = 
load_dataset("cnn_dailymail", "3.0.0", cache_dir="data") + + # take 20,000 samples from the training set as prompts for training + prompts = dataset["train"]["article"][0:1200] + summaries = dataset["train"]["highlights"][0:1200] + prompts = ["Summarize: " + prompt for prompt in prompts] + + # take 1,000 samples from the validation set as prompts for evaluation + val_prompts = ["Summarize: " + prompt for prompt in dataset["validation"]["article"][0:1000]] + val_summaries = dataset["validation"]["highlights"][0:1000] + + # make dictionary of prompts and labels to use for reward function + tokenizer = AutoTokenizer.from_pretrained(config.model.model_path) + tokenizer.padding_side = "left" + tokenizer.truncation_side = "right" + tokenizer.sep_token = "" + prompt_label = {} + max_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"] + + for i in tqdm(range(len(prompts))): + key = tokenizer.decode( + tokenizer(prompts[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"], + skip_special_tokens=True, + ) # get prompt like trlx's prompt + prompt_label[key.strip()] = summaries[i] + + for i in tqdm(range(len(val_prompts))): + key = tokenizer.decode( + tokenizer(val_prompts[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"], + skip_special_tokens=True, + ) # get prompt like trlx's prompt + prompt_label[key.strip()] = val_summaries[i] + + trlx.train( + reward_fn=reward_fn, + prompts=prompts, + eval_prompts=val_prompts, + config=config, + ) diff --git a/trlx/data/mrt_types.py b/trlx/data/mrt_types.py new file mode 100644 index 000000000..e8c70b0d3 --- /dev/null +++ b/trlx/data/mrt_types.py @@ -0,0 +1,65 @@ +from dataclasses import dataclass + +from torchtyping import TensorType + + +@dataclass +class MRTRLElement: + """ + :param query_tensor: The query tensor i.e. the prompt tokens. + Should be a long tensor. + :type query_tensor: torch.Tensor + + :param response_tensor: The response tensor i.e. the output tokens. + Should be a long tensor. + :type response_tensor: torch.Tensor + + :param logprobs: The log probabilities over all tokens in the vocabulary for + each token generated from the policy network + (i.e. the autoregressive model). + Should be a float tensor of same size as tokens, + with a dimension across the vocabulary. + :type logprobs: torch.Tensor + + :param values: The values for each token generated from the value network or value head. + Should be a float tensor of same size as tokens. + :type values: torch.Tensor + + :param rewards: The rewards for each token outputted in response. + Should be a float tensor of same size as tokens. + :type rewards: torch.Tensor + """ + + query_tensor: TensorType["num_candidates", "query_size"] + response_tensor: TensorType["num_candidates", "response_size"] + logprobs: TensorType["num_candidates", "response_size", "vocab_size"] + values: TensorType["num_candidates", "response_size"] + rewards: TensorType["num_candidates", "response_size"] + + +@dataclass +class MRTRLBatch: + """ + A batched version of the MRTRLElement. See MRTRLElement for more details on individual fields. + + :param query_tensors: A batch of query tensors. Should be a long tensor. + :type query_tensors: torch.Tensor + + :param response_tensors: A batch of response tensors. Should be a long tensor. 
+ :type response_tensors: torch.Tensor + + :param logprobs: A batch of log probabilities from policy + :type logprobs: torch.Tensor + + :param values: A batch of values from value network + :type values: torch.Tensor + + :param rewards: A batch of rewards + :type rewards: torch.Tensor + """ + + query_tensors: TensorType["batch_size", "num_candidates", "query_size"] + response_tensors: TensorType["batch_size", "num_candidates", "response_size"] + logprobs: TensorType["batch_size", "num_candidates", "response_size", "vocab_size"] + values: TensorType["batch_size", "num_candidates", "response_size"] + rewards: TensorType["batch_size", "num_candidates", "response_size"] diff --git a/trlx/models/modeling_mrt.py b/trlx/models/modeling_mrt.py new file mode 100644 index 000000000..42d61a48a --- /dev/null +++ b/trlx/models/modeling_mrt.py @@ -0,0 +1,116 @@ +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn.functional as F +from torchtyping import TensorType + +from trlx.data.method_configs import MethodConfig, register_method +from trlx.utils.modeling import flatten_dict + + +@dataclass +@register_method +class MRTConfig(MethodConfig): + """ + Config for MRT method + + :param num_rollouts: Number of experiences to observe before learning + :type num_rollouts: int + + :param gamma: Discount factor + :type gamma: float + + :param gen_kwargs: Additioanl kwargs for the generation + :type gen_kwargs: Dict[str, Any] + + :param gen_experience_kwargs: if this is not None, then the experience is generated using this + :type gen_experience_kwargs: Dict[str, Any] + """ + + num_rollouts: int + chunk_size: int + num_candidates: int + ce_loss_weight: float + scale_reward: Optional[str] + ref_mean: Optional[float] + ref_std: Optional[float] + cliprange_reward: float + gen_kwargs: dict + gen_experience_kwargs: Optional[dict] = None + + def loss( + self, + logprobs: TensorType["batch_size", "response_size"], + rewards: TensorType["batch_size", "response_size"], + mask: TensorType["batch_size", "response_size"], + ): + """MRT objective function.""" + + # TODO: check if masking is correct + + n = mask.sum() + + loss = torch.tensor(0.0) + + # we make the assumption here that we only care about sequence level rewards only + rewards = rewards.sum(dim=-1) + costs = 1 - rewards + + # Reward component + if self.ce_loss_weight < 1.0: # if ce_loss_weight is 1.0, then we only use the ce loss + # We make the assumption here that rewards are scaled to [0,1] + # lengths = response_masks.sum(dim=-1).float() + lengths = mask.sum(dim=-1).float() + + # model_outputs = self.model( + # input_ids=queries, + # decoder_input_ids=responses, # response tokens are already shifted right + # attention_mask=query_masks + # return_dict=True) + # , + avg_scores = logprobs.sum(dim=-1) / lengths + + # [batch_size, candidate_size] + avg_scores = avg_scores.view(-1, self.num_candidates) + costs = costs.view(-1, self.num_candidates) + + probs = F.softmax(avg_scores, dim=1).squeeze(-1) + loss = (probs * costs).sum() + + # Cross entropy component + ce_loss = torch.tensor(0.0) + if self.ce_loss_weight > 0.0: + # TODO: for this to work we need to have some sort of reference + assert False, "ce_loss_weight should be 0.0" + # if parallel_mask is not None: + # queries = queries[parallel_mask] + # query_masks = query_masks[parallel_mask] + # refs = refs[parallel_mask] + # ref_masks = ref_masks[parallel_mask] + + # # We should compute the cross entropy with the reference response and not with the generated 
response + # model_outputs = self.model( + # input_ids=queries, + # decoder_input_ids= + # shift_tokens_right(refs, self.model.config.pad_token_id, self.model.config.decoder_start_token_id), + # attention_mask=query_masks, + # return_dict=True) + # ce_loss = F.cross_entropy( + # model_outputs.logits.reshape(-1, model_outputs.logits.size(-1)), + # refs.reshape(-1), + # ignore_index=self.model.config.pad_token_id + # ) + + combined_loss = self.ce_loss_weight * ce_loss + (1 - self.ce_loss_weight) * loss + + stats = dict( + losses=dict( + total_loss=combined_loss.item(), + ce_loss=ce_loss.item(), + mrt_loss=loss.item(), + ), + padding_percentage=n / mask.numel(), + ) + + return combined_loss, flatten_dict(stats) diff --git a/trlx/pipeline/mrt_pipeline.py b/trlx/pipeline/mrt_pipeline.py new file mode 100644 index 000000000..6c2948d17 --- /dev/null +++ b/trlx/pipeline/mrt_pipeline.py @@ -0,0 +1,82 @@ +import json +import os +import time +from typing import Iterable + +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader + +from trlx.data.mrt_types import MRTRLBatch, MRTRLElement +from trlx.pipeline import BaseRolloutStore + + +class MRTRolloutStorage(BaseRolloutStore): + """ + Rollout storage for training MRT + """ + + def __init__(self, pad_token_id): + super().__init__() + + self.pad_token_id = pad_token_id + self.history: Iterable[MRTRLElement] = [None] + + def push(self, exps: Iterable[MRTRLElement]): + self.history += exps + + def clear_history(self): + self.history = [] + + def export_history(self, location: str): + assert os.path.exists(location) + + fpath = os.path.join(location, f"epoch-{str(time.time())}.json") + + def exp_to_dict(exp): + {k: v.cpu().tolist() for k, v in exp.__dict__.items()} + + data = [exp_to_dict(exp) for exp in self.history] + with open(fpath, "w") as f: + f.write(json.dumps(data, indent=2)) + + def __getitem__(self, index: int) -> MRTRLElement: + return self.history[index] + + def __len__(self) -> int: + return len(self.history) + + def create_loader( + self, + batch_size: int, + shuffle: bool, + ) -> DataLoader: + def collate_fn(elems: Iterable[MRTRLElement]): + return MRTRLBatch( # TODO: make sure this is expected + pad_sequence( + [elem.query_tensor.transpose(0, 1) for elem in elems], + padding_value=self.pad_token_id, + ) + .transpose(0, 1) + .transpose(1, 2), + # Right pad the rest, to have a single horizontal query/response split + pad_sequence( + [elem.response_tensor.transpose(0, 1) for elem in elems], + padding_value=self.pad_token_id, + ) + .transpose(0, 1) + .transpose(1, 2), + pad_sequence( + [elem.logprobs.transpose(0, 1) for elem in elems], + padding_value=0.0, + ) + .transpose(0, 1) + .transpose(1, 2), + pad_sequence([elem.values.transpose(0, 1) for elem in elems], padding_value=0.0) + .transpose(0, 1) + .transpose(1, 2), + pad_sequence([elem.rewards.transpose(0, 1) for elem in elems], padding_value=0.0) + .transpose(0, 1) + .transpose(1, 2), + ) + + return DataLoader(self, batch_size, shuffle=shuffle, collate_fn=collate_fn) diff --git a/trlx/reference.py b/trlx/reference.py index dab6b6d97..c4f4612a7 100644 --- a/trlx/reference.py +++ b/trlx/reference.py @@ -4,9 +4,10 @@ import os import subprocess -import wandb import wandb.apis.reports as wb +import wandb + parser = argparse.ArgumentParser() parser.add_argument("branch", type=str, help="Git branch in the format `origin:branch`") parser.add_argument("--against", type=str, default="CarperAI/trlx:main", help="Reference git branch") diff --git a/trlx/sweep.py 
b/trlx/sweep.py index 615cb7361..9bfc07495 100644 --- a/trlx/sweep.py +++ b/trlx/sweep.py @@ -5,7 +5,6 @@ from datetime import datetime import ray -import wandb import wandb.apis.reports as wb import yaml from ray import tune @@ -13,6 +12,8 @@ from ray.train.huggingface.accelerate import AccelerateTrainer from ray.tune.logger import CSVLoggerCallback +import wandb + def get_param_space(config: dict): # noqa: C901 """Get the param space from the config file.""" diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index e2f150e1b..7c0fe6dab 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -490,6 +490,7 @@ def learn(self): # noqa: C901 else: results = self.evaluate() self.accelerator.log(results, step=self.iter_count) + ... tbar = logging.tqdm( initial=self.iter_count, diff --git a/trlx/trainer/accelerate_mrt_trainer.py b/trlx/trainer/accelerate_mrt_trainer.py new file mode 100644 index 000000000..c03bb4338 --- /dev/null +++ b/trlx/trainer/accelerate_mrt_trainer.py @@ -0,0 +1,502 @@ +import json +import os +import uuid +from time import time +from typing import Callable, List + +import ray +import torch +import torch.nn.functional as F +import transformers +from torch.utils.data import DataLoader +from transformers import AutoTokenizer + +import trlx.utils.logging as logging +from trlx.data.accelerate_base_datatypes import PromptBatch +from trlx.data.configs import TRLConfig +from trlx.data.mrt_types import MRTRLBatch, MRTRLElement +from trlx.models.modeling_ppo import ( # TODO: do we need to update this to MRT? + AutoModelForCausalLMWithHydraValueHead, + AutoModelForSeq2SeqLMWithHydraValueHead, +) +from trlx.pipeline.mrt_pipeline import MRTRolloutStorage +from trlx.pipeline.offline_pipeline import PromptPipeline +from trlx.trainer import register_trainer +from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer +from trlx.utils import Clock +from trlx.utils.modeling import RunningMoments, logprobs_of_labels + +logger = logging.get_logger(__name__) + + +@register_trainer +class AccelerateMRTTrainer(AccelerateRLTrainer): + """MRT Accelerate Trainer""" + + reward_fn: Callable[[List[str], List[str], List[str]], List[float]] + tokenizer: AutoTokenizer + + def __init__(self, config: TRLConfig, **kwargs): + """MRT Accelerate Trainer initialization + + Args: + config: Config + """ + super().__init__(config, **kwargs) + + # Setup rollout logging + if config.train.rollout_logging_dir is not None: + self.log_rollouts = True + self.setup_rollout_logging(config) + else: + self.log_rollouts = False + + # Setup the rollout store + # Rollouts contain the prompt & response, log probs, values and rewards - from each rollout + self.store = MRTRolloutStorage(self.tokenizer.pad_token_id) + + # Create the rollout store dataloader (for batching up rollouts) + # TODO (jon-tow): This is only used to satisfy to `accelerator.prepare` call constraint below - remove in future + rollout_loader: DataLoader = self.store.create_loader(self.config.train.batch_size, shuffle=True) + + # Prepare multi-GPU acceleration + self.model, self.opt, self.scheduler, rollout_loader = self.accelerator.prepare( + self.model, self.opt, self.scheduler, rollout_loader + ) + + self.store.clear_history() # Clear the rollout store + + # Setup a reference model when hydra heads are not used + if not hasattr(self.model, "frozen_head"): + self.ref_model = self.get_arch(self.config) + self.ref_model.to(self.accelerator.device) + 
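            # The reference model is only used for inference; eval() disables dropout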
self.ref_model.eval() + + # Create the parameters for the Hugging Face language model's generator + # method (that generates new tokens from a prompt). + # https://huggingface.co/docs/transformers/v4.25.1/en/main_classes/text_generation#transformers.GenerationMixin.generate + if config.model.model_arch_type == "seq2seq": + self.generate_kwargs = dict( + config.method.gen_kwargs, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, + ) + if config.method.gen_experience_kwargs is not None: + self.generate_experience_kwargs = dict( + config.method.gen_experience_kwargs, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, + ) + else: + self.generate_experience_kwargs = None + else: + self.generate_kwargs = dict( + config.method.gen_kwargs, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.eos_token_id, + ) + if config.method.gen_experience_kwargs is not None: + self.generate_experience_kwargs = dict( + config.method.gen_experience_kwargs, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.eos_token_id, + ) + else: + self.generate_experience_kwargs = None + + # Setup stats tracker + self.running_moments = RunningMoments() + self.ref_mean = self.config.method.ref_mean + self.ref_std = self.config.method.ref_std + + def get_arch(self, config: TRLConfig): + """Get the model""" + model_class = AutoModelForCausalLMWithHydraValueHead + if config.model.model_arch_type == "seq2seq": + model_class = AutoModelForSeq2SeqLMWithHydraValueHead + + from_fn = model_class.from_pretrained + # backward-compat: Try to create a randomly initialized architecture from a config + if issubclass(type(config.model.model_path), transformers.PretrainedConfig): + from_fn = model_class.from_config + + return from_fn( + config.model.model_path, + num_layers_unfrozen=config.model.num_layers_unfrozen, + ) + + def loss(self, batch: MRTRLBatch): + """Forward pass & loss + + Args: + batch: Previous batch of episodes + """ + # Move `batch` data to `accelerator` device + query_tensors = batch.query_tensors.to(self.accelerator.device) + response_tensors = batch.response_tensors.to(self.accelerator.device) + logprobs = batch.logprobs.to(self.accelerator.device) + rewards = batch.rewards.to(self.accelerator.device) + + # remove middle dimension + batch_size = len(query_tensors) + num_candidates = self.config.method.num_candidates + query_tensors = query_tensors.reshape(batch_size * num_candidates, -1) + response_tensors = response_tensors.reshape(batch_size * num_candidates, -1) + logprobs = logprobs.reshape(batch_size * num_candidates, -1) + rewards = rewards.reshape(batch_size * num_candidates, -1) + response_length = rewards.shape[-1] + + if self.config.model.model_arch_type == "seq2seq": + input_ids = query_tensors + decoder_input_ids = response_tensors + attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device) + decoder_attention_mask = ( + decoder_input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device) + ) + decoder_attention_mask[:, 0] = 1 + + # Forward pass + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + ) + + logits = outputs.logits + values_pred = outputs.value + logprobs = logprobs_of_labels(logits[:, :-1, :], decoder_input_ids[:, 1:]) + mask = decoder_input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device) + 
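            # For seq2seq, `decoder_input_ids` contain only the response tokens, so the slice over [start, end) begins at 0 (the causal branch below starts after the prompt instead)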
start = 0 + end = start + response_length + logprobs, values_pred, mask = ( + logprobs[:, start:end], + values_pred[:, start:end], + mask[:, start:end], + ) + else: + tokens = torch.cat((query_tensors, response_tensors), dim=1) + attention_mask = tokens.not_equal(self.tokenizer.pad_token_id).long().to(tokens.device) + outputs = self.model(tokens, attention_mask, return_dict=True) + logits = outputs.logits + values_pred = outputs.value + values_pred = values_pred[:, :-1] + logprobs = logprobs_of_labels(logits[:, :-1, :], tokens[:, 1:]) + + start = query_tensors.shape[1] - 1 + end = start + response_length + logprobs, values_pred, mask = ( + logprobs[:, start:end], + values_pred[:, start:end], + attention_mask[:, start:end], + ) + + loss, stats = self.config.method.loss( + logprobs=logprobs, + rewards=rewards, + mask=mask, + ) + + return loss, stats + + def setup_rollout_logging(self, config): + # Make rollout logging dir for this run and store config + exists = os.path.exists(config.train.rollout_logging_dir) + isdir = os.path.isdir(config.train.rollout_logging_dir) + assert exists and isdir + + self.run_id = f"run-{uuid.uuid4()}" + self.rollout_logging_dir = os.path.join(config.train.rollout_logging_dir, self.run_id) + os.mkdir(self.rollout_logging_dir) + + with open(os.path.join(self.rollout_logging_dir, "config.json"), "w") as f: + f.write(json.dumps(config.to_dict(), indent=2)) + + def post_epoch_callback(self): + """Post epoch callback + + Clears the store and creates `num_rollouts` new episodes. + """ + if self.log_rollouts: + self.store.export_history(location=self.rollout_logging_dir) + self.store.clear_history() + # Collect more rollouts for training + self.make_experience(self.config.method.num_rollouts, self.iter_count) + + def post_backward_callback(self): + pass + + def prepare_learning(self): + eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size) + self.eval_dataloader = self.accelerator.prepare_data_loader(eval_dataloader) + self.train_dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=True) + + self.n_updates_per_batch = 1 + self.total_steps = self.config.train.epochs * self.n_updates_per_batch * len(self.train_dataloader) + self.total_steps = min(self.total_steps, self.config.train.total_steps) + + def add_prompt_pipeline(self, pipeline: PromptPipeline): + """Add a prompt pipeline dataloader to a trainer instance for the `make_experience` stage""" + prompt_dataloader = pipeline.create_loader(self.config.method.chunk_size, shuffle=True) + self.prompt_dataloader = self.accelerator.prepare_data_loader(prompt_dataloader) + self.prompt_iterator = iter(self.prompt_dataloader) + + def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noqa: + """Make experiences + + Takes `chunk_size` number of prompts from `prompt_iterator`, samples + from the model and then computes the KL against a reference model. Finally it + then appends MRTRLElements to trainer's `store`. + + Args: + num_rollouts: Number of rollouts to generate + iter_count: Total number of updates run (i.e. number of updates run for all batches & epochs) + """ + logger.info("Collecting rollouts") + tbar = logging.tqdm( + total=num_rollouts, + disable=os.environ.get("RANK", 0) != "0", + desc=f"[rollout 0 / {num_rollouts}]", + # Lower progress bar by 1 if we're in WARNING mode or above to avoid hiding high priority progress + # bars (e.g. 
loss progress in trainers) + position=logging.get_verbosity() >= logging.WARNING, + # Leave progress bar if we're in INFO mode or lower to avoid spamming in suppressed verbosity levels + leave=logging.get_verbosity() < logging.WARNING, + ) + + num_candidates = self.config.method.num_candidates + + mrt_rl_elements = [] + stats = {} + clock = Clock() + + while len(mrt_rl_elements) * num_candidates < num_rollouts: + # Get next batch in prompt dataset and refresh if exhausted + # TOOD (jon-tow): Make `prompt_dataloader` a cyclic/infinite DataLoader to not require manually + # "refreshing" the contents of the `prompt_iterator` + try: + batch: PromptBatch = next(self.prompt_iterator) + except StopIteration: + self.prompt_iterator = iter(self.prompt_dataloader) + batch = next(self.prompt_iterator) + + exp_generate_time = time() + + # Generate samples from the language model (similar to using HuggingFace `generate` method) + # For MRT, this should generate num_candidates samples for each prompt in the batch + # So in total: [batch_size * num_candidates, response_len] + samples = self.generate(**batch) + device = samples.device + + # Expand queries and mask + copied_idxs = torch.tensor( + [i for i in range(batch.input_ids.shape[0]) for _ in range(num_candidates)], device=device + ) + # TODO change this part over here + batch.input_ids = torch.index_select( + batch.input_ids, 0, copied_idxs + ) # [batch_size, candidate_size, query_length] + batch.attention_mask = torch.index_select( + batch.attention_mask, 0, copied_idxs + ) # [batch_size, candidate_size, query_length] + + stats["time/exp_generate"] = time() - exp_generate_time + + prompt_tensors = batch.input_ids + + prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device) + padded_samples = self.accelerator.pad_across_processes( + samples, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False + ) + padded_prompts = self.accelerator.pad_across_processes( + prompt_tensors, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False + ) + gathered_samples = self.accelerator.gather(padded_samples) + gathered_prompts = self.accelerator.gather(padded_prompts) + gathered_prompt_sizes = self.accelerator.gather(prompt_sizes) + + if self.accelerator.is_main_process: + all_str_samples, all_str_prompts, all_str_outputs = self.decode( + gathered_prompts, gathered_samples, gathered_prompt_sizes + ) + + exp_score_time = time() + all_scores = torch.tensor( + self.reward_fn( + samples=all_str_samples, + prompts=all_str_prompts, + outputs=all_str_outputs, + ), + dtype=torch.float, + device=device, + ) + stats["time/exp_score"] = time() - exp_score_time + + all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1).unbind()) + else: + all_scores = None + + if torch.distributed.is_initialized(): + scores = torch.empty(len(samples), device=device) + torch.distributed.scatter(scores, all_scores) + else: + scores = all_scores[0].clone() # torch.tensor(all_scores[0]) + + str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples) + + # Pad the sample outputs + outputs = self.tokenizer(str_outputs).input_ids + if self.config.model.model_arch_type == "seq2seq": + # add to the start of the output + for i in range(len(outputs)): + outputs[i] = [self.tokenizer.pad_token_id] + outputs[i] + + outputs = list(map(torch.LongTensor, outputs)) + maxsize = max(map(len, outputs)) + outputs = [ + F.pad( + output, + (0, maxsize - len(output)), + value=self.tokenizer.pad_token_id, + ) + for output in outputs + ] 
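+            # Stack the right-padded candidate outputs into a single [n_samples, max_len] tensor, where n_samples = prompts in this chunk * num_candidates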
+ sample_outputs = torch.vstack(outputs).to(device) + + # store statistics of the initial rollout as reference + if self.ref_mean is None: + self.ref_mean, self.ref_std = scores.mean(), scores.std() + all_scores_mean, all_scores_std = self.running_moments.update(scores) + stats["exp_scores/mean"] = all_scores_mean + stats["exp_scores/std"] = all_scores_std + stats["exp_scores/running_mean"] = self.running_moments.mean + stats["exp_scores/running_std"] = self.running_moments.std + + if self.config.method.scale_reward == "running": + scores /= self.running_moments.std + elif self.config.method.scale_reward == "ref": + scores /= self.ref_std + + clip_reward = self.config.method.cliprange_reward + if clip_reward: + scores = torch.clip(scores, -clip_reward, clip_reward) + + # Precompute logprobs, values + if self.config.model.model_arch_type == "seq2seq": + attention_mask = batch.attention_mask.to(device) + prompt_tensors = batch.input_ids.to(device) + decoder_attention_mask = sample_outputs.not_equal(self.tokenizer.pad_token_id) + decoder_attention_mask[:, 0] = 1 + with torch.no_grad(): + outputs = self.model( + input_ids=prompt_tensors, + attention_mask=attention_mask, + decoder_input_ids=sample_outputs, + decoder_attention_mask=decoder_attention_mask, + ) + logits = outputs.logits + values = outputs.value + if hasattr(self.model, "frozen_head"): + ref_logits = self.model.forward_hydra( + input_ids=prompt_tensors, + attention_mask=attention_mask, + decoder_input_ids=sample_outputs, + decoder_attention_mask=decoder_attention_mask, + return_dict=True, + ).logits + else: + ref_logits = self.ref_model( + input_ids=prompt_tensors, + attention_mask=attention_mask, + decoder_input_ids=sample_outputs, + decoder_attention_mask=decoder_attention_mask, + return_dict=True, + ).logits + else: + all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1) + attention_mask = all_tokens.not_equal(self.tokenizer.pad_token_id).long().to(device) + with torch.no_grad(): + logits, *_, values = self.model( + all_tokens, + attention_mask=attention_mask, + ) + # TODO(dahoas): When hydra model works need to also support generation on hydra head + if hasattr(self.model, "frozen_head"): + ref_logits = self.model.forward_hydra( + all_tokens, + attention_mask=attention_mask, + return_dict=True, + ).logits + else: + ref_logits = self.ref_model( + all_tokens, + attention_mask=attention_mask, + return_dict=True, + ).logits + ref_logits = ref_logits.to(device) + + if self.config.model.model_arch_type == "seq2seq": + logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:]) + ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:]) + else: + logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:]) + ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:]) + + n_samples: int = samples.shape[0] + logprobs = logprobs.cpu() + ref_logprobs = ref_logprobs.cpu() + prompt_tensors = prompt_tensors.cpu() + sample_outputs = sample_outputs.cpu() + + # Estimate the KL divergence between the model and reference model + if self.config.model.model_arch_type == "seq2seq": + values = values.cpu()[:, :-1] + start = 0 + + # Get the number of non-padding tokens for each sample + # This assumes all padding is on the right side + padding_token: int = 0 + ends = (sample_outputs[:, start:] != padding_token).sum(1) + + # Else if not seq2seq (i.e. 
causal) + else: + values = values.cpu()[:, :-1] + start = prompt_tensors.shape[1] - 1 + ends = start + attention_mask[:, start:].sum(1) + # all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)] + # all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)] + + # kl_divergence_estimate = 1.0 * (logprobs - ref_logprobs) + # kl_divergence_estimate = [rs[start : ends[ix]] for ix, rs in enumerate(kl_divergence_estimate)] + + rollout_count = 0 + + rewards = torch.zeros_like(logprobs, dtype=torch.float32) + rewards[torch.arange(len(rewards)), ends - 1] = scores.cpu() + + for idx in range(n_samples // num_candidates): + sample_idxs = torch.arange(idx * num_candidates, (idx + 1) * num_candidates) + + mrt_rl_elements.append( + MRTRLElement( + query_tensor=prompt_tensors[sample_idxs].view(num_candidates, -1), + response_tensor=sample_outputs[sample_idxs].view(num_candidates, -1), + logprobs=logprobs[sample_idxs].view(num_candidates, -1), + values=values[sample_idxs].view(num_candidates, -1), + rewards=rewards[sample_idxs].view(num_candidates, -1), + ) + ) + + rollout_count += num_candidates + exp_time = clock.tick() + tbar.set_description(f"[rollout {num_candidates * len(mrt_rl_elements)} / {num_rollouts}]") + tbar.update(min(rollout_count, num_rollouts)) + tbar.close() + + stats["time/exp"] = exp_time + + if not ray.is_initialized(): + self.accelerator.log(stats, step=iter_count) + + # Push samples and rewards to trainer's rollout storage + self.push_to_store(mrt_rl_elements) diff --git a/trlx/utils/loading.py b/trlx/utils/loading.py index 8f1722aed..4249afdac 100644 --- a/trlx/utils/loading.py +++ b/trlx/utils/loading.py @@ -7,6 +7,7 @@ # Register load trainers via module import from trlx.trainer import _TRAINERS, register_trainer from trlx.trainer.accelerate_ilql_trainer import AccelerateILQLTrainer +from trlx.trainer.accelerate_mrt_trainer import AccelerateMRTTrainer from trlx.trainer.accelerate_ppo_trainer import AcceleratePPOTrainer from trlx.trainer.accelerate_sft_trainer import AccelerateSFTTrainer
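
For context on the `gen_experience_kwargs` settings in the examples above (where `num_beams` and `num_return_sequences` both match `num_candidates`), here is a minimal, self-contained sketch of the kind of candidate generation the MRT trainer relies on; the model name, prompt and `max_new_tokens` are placeholder choices, not values taken from this diff:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

num_candidates = 4  # mirrors MRTConfig.num_candidates
batch = tokenizer(["translate English to German: The weather is nice today."], return_tensors="pt")

# Deterministic beam search that keeps one hypothesis per beam, as in gen_experience_kwargs
candidates = model.generate(
    **batch,
    max_new_tokens=32,
    num_beams=num_candidates,
    num_return_sequences=num_candidates,
    do_sample=False,
)

# generate() returns the hypotheses flattened to [batch_size * num_candidates, seq_len];
# make_experience later regroups them into [num_candidates, seq_len] blocks per MRTRLElement.
print(candidates.shape)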
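A minimal sketch of the expected-risk objective computed in `MRTConfig.loss`, assuming sequence-level rewards already scaled to [0, 1] and token-level log-probabilities of the sampled candidates; the explicit masking and the toy shapes are illustrative simplifications, not a drop-in replacement for the method above:

import torch
import torch.nn.functional as F

batch_size, num_candidates, response_len = 2, 4, 6

logprobs = -torch.rand(batch_size * num_candidates, response_len)  # per-token log-probs (<= 0)
mask = torch.ones(batch_size * num_candidates, response_len)       # 1 for real tokens, 0 for padding
rewards = torch.rand(batch_size * num_candidates)                  # e.g. COMET / METEOR per candidate

costs = 1.0 - rewards                                  # MRT minimizes expected cost
lengths = mask.sum(dim=-1).float()
avg_scores = (logprobs * mask).sum(dim=-1) / lengths   # length-normalized sequence log-prob

avg_scores = avg_scores.view(batch_size, num_candidates)
costs = costs.view(batch_size, num_candidates)
q = F.softmax(avg_scores, dim=1)                       # renormalized distribution over the candidate set
mrt_loss = (q * costs).sum()                           # expected risk, summed over the batch
print(mrt_loss)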
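Finally, a small sketch of the padding trick used by `MRTRolloutStorage.create_loader`'s `collate_fn`: each stored element carries `[num_candidates, seq_len]` tensors whose `seq_len` differs across elements, so they are transposed before `pad_sequence` and transposed back afterwards to give `[batch, num_candidates, max_seq_len]`; the pad id and shapes below are illustrative:

import torch
from torch.nn.utils.rnn import pad_sequence

pad_token_id = 0
num_candidates = 3
# Two stored elements whose responses have different lengths (5 and 7 tokens)
elems = [
    torch.ones(num_candidates, 5, dtype=torch.long),
    torch.ones(num_candidates, 7, dtype=torch.long),
]

batched = (
    pad_sequence([e.transpose(0, 1) for e in elems], padding_value=pad_token_id)  # [max_len, batch, cand]
    .transpose(0, 1)   # [batch, max_len, cand]
    .transpose(1, 2)   # [batch, cand, max_len]
)
assert batched.shape == (2, num_candidates, 7)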