diff --git a/ding/bonus/config.py b/ding/bonus/config.py index 285eff6586..ddb73fb235 100644 --- a/ding/bonus/config.py +++ b/ding/bonus/config.py @@ -167,6 +167,13 @@ def get_instance_config(env_id: str, algorithm: str) -> EasyDict: cfg.batch_size = 320 cfg.epoch_per_collect = 10 cfg.learning_rate = 3e-4 + elif env_id == 'chat': + cfg.epoch_per_collect = 1 + cfg.batch_size = 1 + cfg.learning_rate = 5e-7 + cfg.answers_per_question = 3 + cfg.kl_penalty_weight = 0.1 + cfg.ppo_param_init = False else: raise KeyError("not supported env type: {}".format(env_id)) else: @@ -315,6 +322,16 @@ def get_instance_env(env_id: str) -> BaseEnv: ) cfg = EasyDict(cfg) return DriveEnvWrapper(MetaDrivePPOOriginEnv(cfg)) + elif env_id == 'chat': + from dizoo.chat.env import ChatEnv + return ChatEnv( + batch_size=1, + reward_model_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en/recover", + tokenizer_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en", + data_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/data/ppo_data", + maxlen_prompt=128, + maxlen_res=128, + ) else: raise KeyError("not supported env type: {}".format(env_id)) diff --git a/ding/bonus/ppof.py b/ding/bonus/ppof.py index 88d0b43e1e..141f0d3b4d 100644 --- a/ding/bonus/ppof.py +++ b/ding/bonus/ppof.py @@ -1,3 +1,4 @@ +import copy from typing import Optional, Union, List from ditk import logging from easydict import EasyDict @@ -9,7 +10,7 @@ import torch from ding.framework import task, OnlineRLContext from ding.framework.middleware import interaction_evaluator_ttorch, PPOFStepCollector, multistep_trainer, CkptSaver, \ - wandb_online_logger, offline_data_saver, termination_checker, ppof_adv_estimator + wandb_online_logger, offline_data_saver, termination_checker, ppof_adv_estimator, ChatCollector from ding.envs import BaseEnv, BaseEnvManagerV2, SubprocessEnvManagerV2 from ding.policy import PPOFPolicy, single_env_forward_wrapper_ttorch from ding.utils import set_pkg_seed @@ -62,6 +63,8 @@ class PPOF: 'Hopper-v3', 'HalfCheetah-v3', 'Walker2d-v3', + # rlhf + 'chat' ] """ Overview: @@ -170,6 +173,8 @@ def __init__( action_shape = int(action_space.n) elif isinstance(action_space, (gym.spaces.Tuple, gymnasium.spaces.Tuple)): action_shape = get_hybrid_shape(action_space) + elif action_space is None: + pass else: action_shape = action_space.shape @@ -191,7 +196,11 @@ def __init__( popart_head=True, **self.cfg.model ) - self.policy = PPOFPolicy(self.cfg, model=model) + if self.cfg.chat_data: + orig_model = copy.deepcopy(model) + else: + orig_model = None + self.policy = PPOFPolicy(self.cfg, model=model, orig_model=orig_model) if policy_state_dict is not None: self.policy.load_state_dict(policy_state_dict) self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt") @@ -246,10 +255,14 @@ def train( pass with task.start(ctx=OnlineRLContext()): - task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env)) - task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt)) - task.use(PPOFStepCollector(self.seed, self.policy, collector_env, self.cfg.n_sample)) - task.use(ppof_adv_estimator(self.policy)) + if self.policy._cfg.chat_data: + # task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env)) + # task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt)) + task.use(ChatCollector(self.seed, self.policy, collector_env, self.cfg.n_sample)) + else: + task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env)) + 
task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt)) + task.use(PPOFStepCollector(self.seed, self.policy, collector_env, self.cfg.n_sample)) task.use(multistep_trainer(self.policy, log_freq=n_iter_log_show)) task.use( wandb_online_logger( diff --git a/ding/framework/middleware/__init__.py b/ding/framework/middleware/__init__.py index b9e3c5005d..43cea6883b 100644 --- a/ding/framework/middleware/__init__.py +++ b/ding/framework/middleware/__init__.py @@ -1,5 +1,5 @@ from .functional import * -from .collector import StepCollector, EpisodeCollector, PPOFStepCollector +from .collector import StepCollector, EpisodeCollector, PPOFStepCollector, ChatCollector from .learner import OffPolicyLearner, HERLearner from .ckpt_handler import CkptSaver from .distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py index beb4894ad9..6017940396 100644 --- a/ding/framework/middleware/collector.py +++ b/ding/framework/middleware/collector.py @@ -1,3 +1,4 @@ +import copy from typing import TYPE_CHECKING from easydict import EasyDict import treetensor.torch as ttorch @@ -190,4 +191,64 @@ def __call__(self, ctx: "OnlineRLContext") -> None: break +class ChatCollector: + """ + Overview: + The collector for the chat (RLHF) pipeline. For each batch of prompts, it samples several candidate \ + responses with the actor model and scores them with the environment's reward model. Use the \ + `__call__` method to execute the whole collection process. + """ + + def __new__(cls, *args, **kwargs): + if task.router.is_active and not task.has_role(task.role.COLLECTOR): + return task.void() + return super(ChatCollector, cls).__new__(cls) + + def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll_len: int = 1) -> None: + """ + Arguments: + - seed (:obj:`int`): Random seed. + - policy (:obj:`Policy`): The policy to be collected. + - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \ + its derivatives are supported. + """ + self.env = env + self.env.seed(seed) + self.env.launch() + self.env = self.env._envs[0] + self.policy = policy + self.n_sample = n_sample + self.unroll_len = unroll_len + + def __call__(self, ctx: "OnlineRLContext") -> None: + """ + Overview: + Generate ``answers_per_question`` responses for the current batch of prompts, score them by \ + stepping the environment once, and store the padded responses, rewards and masks in ``ctx.train_data``. + Input of ctx: + - env_step (:obj:`int`): The env steps which will increase during collection.
+ """ + device = self.policy._device + + obs = ttorch.as_tensor(self.env.last_batch['text_vec']) + batch_size = obs.shape[0] + obs = obs.to(device) + + total_action = [[] for _ in range(batch_size)] # [B, answers_per_question, T] + for _ in range(self.policy._cfg.answers_per_question): + _, inference_output = self.policy._model.actor.generate(obs, **ctx.collect_kwargs) + for i in range(batch_size): + total_action[i].append(copy.deepcopy(inference_output[i])) + + mask, resp, rew = self.env.step(total_action) + ctx.env_step += 1 + ctx.env_episode += 1 + + train_data = {} + train_data['obs'] = resp # [B x answer-per-question, T] + train_data['reward'] = rew # [B x answer-per-question, ] + train_data['mask'] = mask # [B x answer-per-question, T] + + ctx.train_data = ttorch.as_tensor(train_data) + + # TODO battle collector diff --git a/ding/model/common/__init__.py b/ding/model/common/__init__.py index 4bf7d8be5a..5bf1fba4d5 100755 --- a/ding/model/common/__init__.py +++ b/ding/model/common/__init__.py @@ -2,4 +2,4 @@ QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, BranchingHead, head_cls_map, \ independent_normal_dist, AttentionPolicyHead, PopArtVHead, EnsembleHead from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder -from .utils import create_model +from .utils import create_model, top_p_logits diff --git a/ding/model/common/tests/test_utils.py b/ding/model/common/tests/test_utils.py new file mode 100644 index 0000000000..9a6f688e2d --- /dev/null +++ b/ding/model/common/tests/test_utils.py @@ -0,0 +1,15 @@ +import pytest +import torch +from ding.model.common.utils import top_p_logits + + +@pytest.mark.unittest +class TestUtils: + + def test_top_p_logits(self): + test_logit = torch.Tensor([[0., 0.91, 0.05, 0.04], [0.04, 0.46, 0.46, 0.04]]) + + gt_logit = torch.Tensor([[0., 1., 0., 0.], [0., 0.5, 0.5, 0.]]) + + pred_logit = top_p_logits(test_logit) + assert torch.sum((gt_logit - pred_logit) ** 2).item() < 1e-8 diff --git a/ding/model/common/utils.py b/ding/model/common/utils.py index f74a179962..5208cf342c 100644 --- a/ding/model/common/utils.py +++ b/ding/model/common/utils.py @@ -29,3 +29,29 @@ def create_model(cfg: EasyDict) -> torch.nn.Module: import_module(cfg.pop('import_names', [])) # here we must use the pop opeartion to ensure compatibility return MODEL_REGISTRY.build(cfg.pop("type"), **cfg) + + +def top_p_logits(logits: torch.Tensor, topp: float = 0.9, filter_value: float = 0, min_topk: int = 1): + """ + Overview: + Filter a distribution of logits using nucleus (top-p) filtering. The output is also logit tensors but some \ + values are masked. + Arguments: + - logits (:obj:`torch.Tensor`): The input logits for top-p sampling. + - topp (:obj:`float`): The top-p value, such as 0.9. + - filter_value (:obj:`float`): The value for masked logits in output, default as 0. + - min_topk (:obj:`int`): The min number of sampled logit, default as 1 (which means that at least one sample \ + will not be masked.) + Returns: + - cum_logits (:obj:`torch.Tensor`): The output logits after masking. 
+ """ + cum_logits = logits.clone() + if topp > 0: + logits_sorted, inds = torch.sort(logits, dim=-1, descending=True) + mask = (logits_sorted.cumsum(dim=-1) - logits_sorted) >= topp + mask[..., :min_topk] = False + # Remove tokens with cumulative top_p above the threshold + mask = torch.zeros_like(mask).to(torch.bool).scatter_(dim=-1, index=inds, src=mask) + cum_logits[mask] = filter_value + cum_logits.div_(cum_logits.sum(dim=-1, keepdim=True)) + return cum_logits diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py index 4a63c3dcc6..ec39522608 100755 --- a/ding/model/template/__init__.py +++ b/ding/model/template/__init__.py @@ -5,6 +5,7 @@ from .vac import VAC, DREAMERVAC from .bc import DiscreteBC, ContinuousBC from .language_transformer import LanguageTransformer +from .lm_vac import LlamaVAC # algorithm-specific from .pg import PG from .ppg import PPG diff --git a/ding/model/template/lm_vac.py b/ding/model/template/lm_vac.py new file mode 100644 index 0000000000..747b1d1d8f --- /dev/null +++ b/ding/model/template/lm_vac.py @@ -0,0 +1,202 @@ +from typing import Dict +import torch +import torch.nn as nn +try: + from transformers import LlamaTokenizer + from transformers.models.llama.modeling_llama import LlamaForCausalLM +except ImportError: + from ditk import logging + logging.warning("Not found transformer, please install it using: pip install transformers") + +from ding.model.common import top_p_logits +from ding.reward_model import LlamaRewardModel +from ding.utils import MODEL_REGISTRY + + +def get_tokenizer(path: str): + """ + Overview: + Return the pretrained tokenizer using the given path. + """ + tokenizer = LlamaTokenizer.from_pretrained(path, trust_remote_code=True) + tokenizer.bos_token = '' + tokenizer.eos_token = '' + tokenizer.pad_token = '' + tokenizer.pad_token_id = 0 + tokenizer.unk_token = tokenizer.pad_token + tokenizer.unk_token_id = tokenizer.pad_token_id + + return tokenizer + + +class Llama(LlamaForCausalLM): + + def __init__(self, config, opt, tokenizer, enable_checkpointing): + super().__init__(config) + self.opt = opt + self.tokenizer = tokenizer + self.enable_checkpointing = enable_checkpointing + + def forward(self, decoder_input, incr_state=None, is_train=True): + + attention_mask = decoder_input.ne(self.tokenizer.pad_token_id) + if incr_state is not None: + decoder_input = decoder_input[:, -1:] + + output = super().forward( + input_ids=decoder_input, + attention_mask=attention_mask, + past_key_values=incr_state, + return_dict=True, + use_cache=not is_train + ) + + logits = output.logits + new_incr_states = output.past_key_values + + return logits, new_incr_states + + @torch.no_grad() + def generate(self, batch, **kwargs): + """ + Generate response + """ + if self.enable_checkpointing: + self.gradient_checkpointing_disable() + maxlen_res = kwargs.pop('maxlen_res', self.opt.maxlen_res) + temperature = kwargs.pop('temperature', self.opt.temperature) + repetition_penalty = kwargs.pop('repetition_penalty', self.opt.repetition_penalty) + topp = kwargs.pop('topp', self.opt.topp) + + decoder_input: torch.LongTensor = batch # (bsz, ...) + assert decoder_input[:, -1].ne( + self.tokenizer.pad_token_id + ).all(), 'Last token should not be a padding token (you can use left padding instead).' 
+ + dev = decoder_input.device + bsz = decoder_input.size(0) + + scores = torch.zeros((bsz, ), device=dev, dtype=torch.float16) + done = torch.zeros((bsz, ), device=dev).to(torch.bool) + + inds = torch.arange(bsz).to(dev).unsqueeze(1).view(-1) + decoder_input = torch.index_select(decoder_input, 0, inds) + init_length = decoder_input.size(1) + + incr_state = None + for _token in range(maxlen_res): + if done.all(): + break + score, incr_state, *_ = self.forward(decoder_input, incr_state, is_train=False) + score = score.half() + + # now score is bs, len, vocab_size + score = score[:, -1, :] + + # calculate repetition penalty + if repetition_penalty > 1.: + penalty_tokens = decoder_input[:, init_length:] + penalty_scores = torch.gather(score, dim=1, index=penalty_tokens) + penalty_scores = torch.where( + penalty_scores < 0., penalty_scores * repetition_penalty, penalty_scores / repetition_penalty + ) + score = score.scatter_(dim=1, index=penalty_tokens, src=penalty_scores) + + # nucleus sampling + score = torch.softmax(score.div(temperature), dim=-1) + probs = top_p_logits(score, topp=topp, filter_value=0) + tok_ids = torch.multinomial(probs, 1)[:, 0] + hyp_ids = torch.arange(probs.size(0), device=dev) + scores = scores + probs[hyp_ids, tok_ids].log() * ~done + + tok_ids = torch.where(done, self.tokenizer.pad_token_id, tok_ids) + decoder_input = torch.cat((decoder_input, tok_ids.unsqueeze(-1)), dim=-1) + done = done | tok_ids.eq(self.tokenizer.eos_token_id) + + incr_state = self._reorder_cache(incr_state, hyp_ids) + + # get all finalized candidates for each sample + decoder_input = decoder_input[:, init_length:] + decoder_input = decoder_input.view(bsz, -1) + scores = scores.view(bsz, ) + + lengths = decoder_input.ne(self.tokenizer.pad_token_id).sum(dim=-1) + + length_penalty = torch.pow(lengths, 1.0) + scores /= length_penalty + + preds_scores = [] + for i in range(bsz): + seq: torch.LongTensor = decoder_input[i, :lengths[i, ]] + res_scores = (float(scores[i, ]), seq.tolist()) + preds_scores.append([res_scores]) + + best_preds_scores = [preds[0] for preds in preds_scores] + if self.enable_checkpointing: + self.gradient_checkpointing_enable() + return best_preds_scores, preds_scores + + +@MODEL_REGISTRY.register('llamavac') +class LlamaVAC(nn.Module): + """ + Overview: + The neural network and computation graph of Llama VAC. The actor and critic of this model are respectively \ + a Llama Pretrained Model. + Interfaces: + ``__init__``, ``forward``. + """ + mode = ['compute_actor', 'compute_critic', 'compute_actor_critic'] + + def __init__( + self, + actor_path: str, + critic_path: str, + tokenizer_path: str, + opt: Dict, + enable_checkpointing: bool = True + ) -> None: + """ + Overview: + Initialize the ``LlamaVAC`` model according to arguments. + Arguments: + - actor_path (:obj:`str`): Pretrained model path for actor. + - critic_path (:obj:`str`): Pretrained model path for critic. + - opt (:obj:`Dict`): Options for this model. 
+ """ + super(LlamaVAC, self).__init__() + tokenizer = get_tokenizer(tokenizer_path) + self.enable_checkpointing = enable_checkpointing + + self.actor = Llama.from_pretrained( + actor_path, + opt=opt, + tokenizer=tokenizer, + torch_dtype=torch.bfloat16, + enable_checkpointing=enable_checkpointing + ) + + self.critic = LlamaRewardModel.from_pretrained(critic_path, tokenizer=tokenizer, torch_dtype=torch.bfloat16) + + if enable_checkpointing: + self.actor.gradient_checkpointing_enable() + self.critic.gradient_checkpointing_enable() + + def forward(self, x: torch.Tensor, mode: str) -> Dict: + assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) + return getattr(self, mode)(x) + + def compute_actor(self, x): + policy_output = self.actor(decoder_input=x) + policy_logit, *_ = policy_output + return {"logit": policy_logit} + + def compute_critic(self, x): + values = self.critic(decoder_input=x, only_last=False) + return {"value": values} + + def compute_actor_critic(self, x): + policy_output = self.actor(decoder_input=x) + policy_logit, *_ = policy_output + values = self.critic(decoder_input=x, only_last=False) + return {"logit": policy_logit, "value": values} diff --git a/ding/policy/ppof.py b/ding/policy/ppof.py index 81e605384c..3dbfb05581 100644 --- a/ding/policy/ppof.py +++ b/ding/policy/ppof.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Any, Tuple, Union, Callable, Optional +from typing import List, Dict, Any, Callable, Optional from collections import namedtuple from easydict import EasyDict import copy @@ -7,10 +7,13 @@ import torch import treetensor.torch as ttorch from torch.optim import AdamW +from torch.cuda.amp import GradScaler +from torch import autocast from ding.rl_utils import ppo_data, ppo_error, ppo_policy_error, ppo_policy_data, gae, gae_data, ppo_error_continuous, \ get_gae, ppo_policy_error_continuous, ArgmaxSampler, MultinomialSampler, ReparameterizationSampler, MuSampler, \ HybridStochasticSampler, HybridDeterminsticSampler, value_transform, value_inv_transform, symlog, inv_symlog +from ding.rl_utils.gae import episodic_gae_data, episodic_gae from ding.utils import POLICY_REGISTRY, RunningMeanStd @@ -37,6 +40,7 @@ class PPOFPolicy: value_norm='baseline', ppo_param_init=True, grad_norm=0.5, + chat_data=True, # collect n_sample=128, unroll_len=1, @@ -58,8 +62,15 @@ def default_model(cls: type) -> Callable: from .model import PPOFModel return PPOFModel - def __init__(self, cfg: "EasyDict", model: torch.nn.Module, enable_mode: List[str] = None) -> None: + def __init__( + self, + cfg: "EasyDict", + model: torch.nn.Module, + enable_mode: List[str] = None, + orig_model: torch.nn.Module = None + ) -> None: self._cfg = cfg + self._orig_model = orig_model if model is None: self._model = self.default_model() else: @@ -151,69 +162,90 @@ def _model_param_init(self): def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: return_infos = [] self._model.train() - bs = self._cfg.batch_size - data = data[:self._cfg.n_sample // bs * bs] # rounding + if not self._cfg.chat_data: + bs = self._cfg.batch_size + data = data[:self._cfg.n_sample // bs * bs] # rounding # outer training loop for epoch in range(self._cfg.epoch_per_collect): # recompute adv with torch.no_grad(): - # get the value dictionary - # In popart, the dictionary has two keys: 'pred' and 'unnormalized_pred' - value = self._model.compute_critic(data.obs) - next_value = self._model.compute_critic(data.next_obs) - reward = data.reward - - assert self._cfg.value_norm in ['popart', 
'value_rescale', 'symlog', 'baseline'],\ - 'Not supported value normalization! Value normalization supported: \ - popart, value rescale, symlog, baseline' - - if self._cfg.value_norm == 'popart': - unnormalized_value = value['unnormalized_pred'] - unnormalized_next_value = value['unnormalized_pred'] - - mu = self._model.critic_head.popart.mu - sigma = self._model.critic_head.popart.sigma - reward = (reward - mu) / sigma - - value = value['pred'] - next_value = next_value['pred'] - elif self._cfg.value_norm == 'value_rescale': - value = value_inv_transform(value['pred']) - next_value = value_inv_transform(next_value['pred']) - elif self._cfg.value_norm == 'symlog': - value = inv_symlog(value['pred']) - next_value = inv_symlog(next_value['pred']) - elif self._cfg.value_norm == 'baseline': - value = value['pred'] * self._running_mean_std.std - next_value = next_value['pred'] * self._running_mean_std.std - - traj_flag = data.get('traj_flag', None) # traj_flag indicates termination of trajectory - adv_data = gae_data(value, next_value, reward, data.done, traj_flag) - data.adv = gae(adv_data, self._cfg.discount_factor, self._cfg.gae_lambda) - - unnormalized_returns = value + data.adv # In popart, this return is normalized - - if self._cfg.value_norm == 'popart': - self._model.critic_head.popart.update_parameters((data.reward).unsqueeze(1)) - elif self._cfg.value_norm == 'value_rescale': - value = value_transform(value) - unnormalized_returns = value_transform(unnormalized_returns) - elif self._cfg.value_norm == 'symlog': - value = symlog(value) - unnormalized_returns = symlog(unnormalized_returns) - elif self._cfg.value_norm == 'baseline': - value /= self._running_mean_std.std - unnormalized_returns /= self._running_mean_std.std - self._running_mean_std.update(unnormalized_returns.cpu().numpy()) - data.value = value - data.return_ = unnormalized_returns + if self._cfg.chat_data: + # [B, T] + value = self._model.compute_critic(data.obs)['value'] + self._model.cpu() + self._orig_model.cuda() + data.orig_logit = self._orig_model.compute_actor(data.obs)['logit'] + self._orig_model.cpu() + self._model.cuda() + data.value = value + reward = data.reward + + traj_flag = data.get('traj_flag', None) # traj_flag indicates termination of trajectory + done = data.get('done', None) + adv_data = episodic_gae_data(value, data.mask, reward, done, traj_flag) + data.adv = episodic_gae(adv_data, self._cfg.discount_factor, self._cfg.gae_lambda) + + unnormalized_returns = data.value + data.adv + data.return_ = unnormalized_returns + else: + # get the value dictionary + # In popart, the dictionary has two keys: 'pred' and 'unnormalized_pred' + value = self._model.compute_critic(data.obs) + next_value = self._model.compute_critic(data.next_obs) + reward = data.reward + + assert self._cfg.value_norm in ['popart', 'value_rescale', 'symlog', 'baseline'], \ + 'Not supported value normalization! 
Value normalization supported: \ + popart, value rescale, symlog, baseline' + + if self._cfg.value_norm == 'popart': + unnormalized_value = value['unnormalized_pred'] + unnormalized_next_value = value['unnormalized_pred'] + + mu = self._model.critic_head.popart.mu + sigma = self._model.critic_head.popart.sigma + reward = (reward - mu) / sigma + + value = value['pred'] + next_value = next_value['pred'] + elif self._cfg.value_norm == 'value_rescale': + value = value_inv_transform(value['pred']) + next_value = value_inv_transform(next_value['pred']) + elif self._cfg.value_norm == 'symlog': + value = inv_symlog(value['pred']) + next_value = inv_symlog(next_value['pred']) + elif self._cfg.value_norm == 'baseline': + value = value['pred'] * self._running_mean_std.std + next_value = next_value['pred'] * self._running_mean_std.std + + traj_flag = data.get('traj_flag', None) # traj_flag indicates termination of trajectory + adv_data = gae_data(value, next_value, reward, data.done, traj_flag) + data.adv = gae(adv_data, self._cfg.discount_factor, self._cfg.gae_lambda) + + unnormalized_returns = value + data.adv # In popart, this return is normalized + + if self._cfg.value_norm == 'popart': + self._model.critic_head.popart.update_parameters((data.reward).unsqueeze(1)) + elif self._cfg.value_norm == 'value_rescale': + value = value_transform(value) + unnormalized_returns = value_transform(unnormalized_returns) + elif self._cfg.value_norm == 'symlog': + value = symlog(value) + unnormalized_returns = symlog(unnormalized_returns) + elif self._cfg.value_norm == 'baseline': + value /= self._running_mean_std.std + unnormalized_returns /= self._running_mean_std.std + self._running_mean_std.update(unnormalized_returns.cpu().numpy()) + data.value = value + data.return_ = unnormalized_returns # inner training loop split_data = ttorch.split(data, self._cfg.batch_size) random.shuffle(list(split_data)) for batch in split_data: - output = self._model.compute_actor_critic(batch.obs) + if not self._cfg.chat_data: + output = self._model.compute_actor_critic(batch.obs) adv = batch.adv if self._cfg.adv_norm: # Normalize advantage in a train_batch @@ -226,10 +258,27 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: ) ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._cfg.clip_ratio) elif self._action_space == 'discrete': - ppo_batch = ppo_data( - output.logit, batch.logit, batch.action, output.value, batch.value, adv, batch.return_, None - ) - ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio) + if not self._cfg.chat_data: + ppo_batch = ppo_data( + output.logit, batch.logit, batch.action, output.value, batch.value, adv, batch.return_, mask + ) + ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio) + else: + with autocast(device_type='cuda', dtype=torch.float16): + output = self._model.compute_actor_critic(batch.obs) + mask = batch.mask + ppo_batch = ppo_data( + output['logit'], batch.orig_logit, batch.obs, output['value'], batch.value, adv, + batch.return_, None + ) + ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio) + kl_loss = ( + torch.nn.functional.kl_div( + torch.softmax(output["logit"], dim=-1), + torch.softmax(batch.orig_logit, dim=-1), + reduction='none' + ) * mask.unsqueeze(-1) + ).mean() elif self._action_space == 'hybrid': # discrete part (discrete policy loss and entropy loss) ppo_discrete_batch = ppo_policy_data( @@ -253,13 +302,20 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: max(ppo_continuous_info.approx_kl, 
ppo_discrete_info.approx_kl), max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac) ) - wv, we = self._cfg.value_weight, self._cfg.entropy_weight - total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss - - self._optimizer.zero_grad() - total_loss.backward() - torch.nn.utils.clip_grad_norm_(self._model.parameters(), self._cfg.grad_norm) - self._optimizer.step() + if not self._cfg.chat_data: + wv, we = self._cfg.value_weight, self._cfg.entropy_weight + total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + self._optimizer.zero_grad() + total_loss.backward() + torch.nn.utils.clip_grad_norm_(self._model.parameters(), self._cfg.grad_norm) + self._optimizer.step() + else: + wv, we, wk = self._cfg.value_weight, self._cfg.entropy_weight, self._cfg.kl_penalty_weight + total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + wk * kl_loss + output = ttorch.as_tensor(output) + self._optimizer.zero_grad() + total_loss.backward() + self._optimizer.step() return_info = { 'cur_lr': self._optimizer.defaults['lr'], @@ -267,6 +323,7 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: 'policy_loss': ppo_loss.policy_loss.item(), 'value_loss': ppo_loss.value_loss.item(), 'entropy_loss': ppo_loss.entropy_loss.item(), + 'kl_loss': kl_loss.item(), 'adv_max': adv.max().item(), 'adv_mean': adv.mean().item(), 'value_mean': output.value.mean().item(), diff --git a/ding/reward_model/__init__.py b/ding/reward_model/__init__.py index 4538102861..5b197af25d 100644 --- a/ding/reward_model/__init__.py +++ b/ding/reward_model/__init__.py @@ -13,3 +13,5 @@ from .guided_cost_reward_model import GuidedCostRewardModel from .ngu_reward_model import RndNGURewardModel, EpisodicNGURewardModel from .icm_reward_model import ICMRewardModel +# RLHF +from .language_reward_model import LlamaRewardModel diff --git a/ding/reward_model/language_reward_model.py b/ding/reward_model/language_reward_model.py new file mode 100644 index 0000000000..ff2558099a --- /dev/null +++ b/ding/reward_model/language_reward_model.py @@ -0,0 +1,28 @@ +import torch +try: + from transformers.models.llama.modeling_llama import LlamaForCausalLM +except ImportError: + from ditk import logging + logging.warning("Not found transformer, please install it using: pip install transformers") + + +class LlamaRewardModel(LlamaForCausalLM): + + def __init__(self, config, tokenizer): + super().__init__(config) + self.tokenizer = tokenizer + self.reward_head = torch.nn.Linear(config.hidden_size, 1, bias=False) + + def forward(self, decoder_input, only_last=True): + attention_mask = decoder_input.ne(self.tokenizer.pad_token_id) + with torch.no_grad(): + output = self.model.forward( + input_ids=decoder_input, attention_mask=attention_mask, return_dict=True, use_cache=False + ) + + if only_last: + logits = self.reward_head(output.last_hidden_state[:, -1, :]).squeeze(-1) + else: + logits = self.reward_head(output.last_hidden_state).squeeze(-1) + + return logits diff --git a/ding/rl_utils/gae.py b/ding/rl_utils/gae.py index 800fcae354..16b8313750 100644 --- a/ding/rl_utils/gae.py +++ b/ding/rl_utils/gae.py @@ -3,6 +3,7 @@ from ding.hpc_rl import hpc_wrapper gae_data = namedtuple('gae_data', ['value', 'next_value', 'reward', 'done', 'traj_flag']) +episodic_gae_data = namedtuple('episodic_gae_data', ['value', 'mask', 'reward', 'done', 'traj_flag']) def shape_fn_gae(args, kwargs): @@ -68,3 +69,24 @@ def gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 
0.97) -> torch.F gae_item = delta[t] + factor[t] * gae_item adv[t] = gae_item return adv + + +def episodic_gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97): + value, mask, reward, done, traj_flag = data + if done is None: + done = torch.zeros_like(value) + if traj_flag is None: + traj_flag = done + advs = [] + bsz = value.shape[0] + for i in range(bsz): + val, mas, rew, don, traj = value[i], mask[i], reward[i], done[i], traj_flag[i] + next_val = torch.zeros_like(val) + next_val[:-1] = val[1:] + reward = torch.zeros_like(val) + reward[-1] = rew + gd = gae_data( + val.unsqueeze(-1), next_val.unsqueeze(-1), reward.unsqueeze(-1), don.unsqueeze(-1), traj.unsqueeze(-1) + ) + advs.append(gae(gd, gamma, lambda_).squeeze(-1)) + return torch.stack(advs, dim=0) diff --git a/dizoo/chat/__init__.py b/dizoo/chat/__init__.py new file mode 100644 index 0000000000..eb1eb48abb --- /dev/null +++ b/dizoo/chat/__init__.py @@ -0,0 +1 @@ +from .env import ChatEnv diff --git a/dizoo/chat/entry.py b/dizoo/chat/entry.py new file mode 100644 index 0000000000..5da6ac620d --- /dev/null +++ b/dizoo/chat/entry.py @@ -0,0 +1,34 @@ +from easydict import EasyDict + +from ding.bonus.ppof import PPOF +from ding.model.template import LlamaVAC + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--actor_path', type=str) + parser.add_argument('--critic_path', type=str) + parser.add_argument('--tokenizer_path', type=str) + args = parser.parse_args() + + opt = EasyDict({ + "maxlen_res": 512, + "temperature": 1, + "repetition_penalty": 1, + "topp": 0.8 + }) + + model = LlamaVAC( + actor_path=args.actor_path, + critic_path=args.critic_path, + tokenizer_path=args.tokenizer_path, + opt=opt + ) + + policy = PPOF( + env_id="chat", + exp_name="rlhf-ppo", + model=model + ) + policy.train(collector_env_num=1, evaluator_env_num=1, debug=True) diff --git a/dizoo/chat/env.py b/dizoo/chat/env.py new file mode 100644 index 0000000000..01d0f183cb --- /dev/null +++ b/dizoo/chat/env.py @@ -0,0 +1,77 @@ +import torch + +from ding.envs import BaseEnv +from ding.reward_model import LlamaRewardModel +from .utils import OnlyPromptDataset, concat_context_and_response, get_tokenizer, pad_sequences + + +class ChatEnv(BaseEnv): + def __init__( + self, + batch_size: int, + reward_model_path: str, + tokenizer_path: str, + data_path: str, + maxlen_prompt: int, + maxlen_res: int, + ): + self.batch_size = batch_size + self.tokenizer = get_tokenizer(tokenizer_path) + self.rm = LlamaRewardModel.from_pretrained(reward_model_path, tokenizer=self.tokenizer) + self.action_space = None + self.observation_space = None + self.reward_space = None + + self._init_flag = False + self._seed = None + + self.dataset = OnlyPromptDataset( + data_path=data_path, + tokenizer=self.tokenizer, + batch_size=batch_size, + maxlen_prompt=maxlen_prompt, + maxlen_res=maxlen_res, + mode='train', + ) + self.generator = self.dataset.final_generator() + self.last_batch = None + + def close(self) -> None: + self._init_flag = False + + def reset(self): + self.last_batch = next(self.generator) + if self.last_batch is None: + self.generator = self.dataset.final_generator() + self.last_batch = next(self.generator) + self._init_flag = True + return self.last_batch + + def __repr__(self) -> str: + return "DI-engine Chat Env" + + def seed(self, seed): + self._seed = seed + + def clone(self, caller): + # It should not create a new copy, since the language model is initialized. 
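+ # Returning the same instance lets the env-manager wrappers share the already-loaded reward model and dataset instead of duplicating them.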
+ return self + + def step(self, action): + """ + Concatenate the last batch of prompts with the generated responses, left-pad the concatenated \ + sequences to the same length, and score them with the reward model. The next batch of prompts is \ + then loaded for the following step. Return the response masks, the padded token-id sequences and \ + the rewards. + """ + output_mask, output_vec = concat_context_and_response(self.tokenizer, self.last_batch['text_vec'].tolist(), action) + output_vec = pad_sequences(output_vec, self.tokenizer.pad_token_id, padding='left') + rm_input = torch.tensor(output_vec, dtype=torch.long) + output_mask = pad_sequences(output_mask, self.tokenizer.pad_token_id, padding='left') + with torch.no_grad(): + rew = self.rm(rm_input) + + self.last_batch = next(self.generator) + if self.last_batch is None: + self.generator = self.dataset.final_generator() + self.last_batch = next(self.generator) + + return output_mask, output_vec, rew diff --git a/dizoo/chat/utils.py b/dizoo/chat/utils.py new file mode 100644 index 0000000000..be05e76805 --- /dev/null +++ b/dizoo/chat/utils.py @@ -0,0 +1,244 @@ +import json +import os +from typing import List, Dict, Any, Tuple +import warnings + +from transformers.models.llama.tokenization_llama import LlamaTokenizer +from torch.utils.data.dataset import IterableDataset +import torch +import random + + +# Prefix of human sentence and assistant sentence. +HUMAN_PROMPT = "Human:" +ASSISTANT_PROMPT = "Assistant:" + + +def strip_pad_token_id(tokenizer: LlamaTokenizer, seq: List[int]): + """ + Overview: + Remove ``pad_token_id`` tokens from a sequence. + """ + return [tok for tok in seq if tok != tokenizer.pad_token_id] + + +def concat_context_and_response( + tokenizer: LlamaTokenizer, + context: List[List[int]], + responses: List[List[Tuple[float, List[int]]]] +): + """ + Overview: + Given the batched input prompts and responses, concatenate them together. + """ + assert len(context) == len(responses), f'Sizes do not match: {len(context)} and {len(responses)}' + + total_context, total_response = [], [] + total_context_mask, total_response_mask = [], [] + for _context, _response in zip(context, responses): + # Each ``_context`` is a single input prompt. + _context = strip_pad_token_id(tokenizer, _context) + for resp in _response: + # Each ``resp`` is a single response.
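+ # ``resp`` is one element of the ``preds_scores`` returned by ``Llama.generate``, i.e. ``[(score, token_ids)]``; keep only the token ids.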
+ resp = resp[0][1] + resp = strip_pad_token_id(tokenizer, resp) + if resp[-1] != tokenizer.eos_token_id: + warnings.warn( + f'Generated response is too long: {tokenizer.decode(_context + resp, skip_special_tokens=False)}') + + total_context.append(_context.copy()) + total_context_mask.append([0] * len(_context)) + total_response.append(resp) + total_response_mask.append([1] * len(resp)) + + total_gene_samples_vec = [c + r for c, r in zip(total_context, total_response)] + total_gene_samples_mask = [c + r for c, r in zip(total_context_mask, total_response_mask)] + return total_gene_samples_mask, total_gene_samples_vec + + +def pad_sequences( + seqs: List[List[int]], + pad_value: int, + padding: str = 'right'): + """ + Overview: + Padding sequence to the same length + """ + max_len = max(len(seq) for seq in seqs) + if padding == 'right': + padded_seqs = [seq + [pad_value] * (max_len - len(seq)) for seq in seqs] + elif padding == 'left': + padded_seqs = [[pad_value] * (max_len - len(seq)) + seq for seq in seqs] + else: + raise ValueError + return padded_seqs + + +def get_special_prompt(i: int): + return HUMAN_PROMPT if i % 2 == 0 else ASSISTANT_PROMPT + + +def get_model_prompt(context: List[str], eos_token=""): + human_prompt, assistant_prompt = HUMAN_PROMPT, ASSISTANT_PROMPT + if context[-1].startswith(human_prompt): + end_prompt = assistant_prompt + elif context[-1].startswith(assistant_prompt): + end_prompt = human_prompt + else: + raise ValueError + + context = eos_token.join(context) + return f'{context}{eos_token}{end_prompt}' + + +def get_tokenizer(path: str): + """ + Overview: + Return the pretrained tokenizer using the given path. + """ + tokenizer = LlamaTokenizer.from_pretrained(path, trust_remote_code=True) + tokenizer.bos_token = '' + tokenizer.eos_token = '' + tokenizer.pad_token = '' + tokenizer.pad_token_id = 0 + tokenizer.unk_token = tokenizer.pad_token + tokenizer.unk_token_id = tokenizer.pad_token_id + + return tokenizer + + +class OnlyPromptDataset(IterableDataset): + """ + Overview: + Dataset that only contains the prompts of the raw data (no answer). + """ + def __init__( + self, + data_path: os.PathLike, + tokenizer, + batch_size: int, + maxlen_prompt: int, + maxlen_res: int, + mode: str = 'train', + ) -> None: + super().__init__() + self.mode = mode + self.tokenizer = tokenizer + self.maxlen_prompt = maxlen_prompt + self.maxlen_res = maxlen_res + self.batch_size = batch_size + + # Load data. + self.data = [] + files = sorted([file for file in os.listdir(data_path) if file.endswith(f'{mode}.json')]) + for file in files: + file_path = os.path.join(data_path, file) + tmp_data = [] + try: + tmp_data = self.load_data(file_path) + except Exception as e: + pass + self.data.extend(tmp_data) + + # Set the length of this dataset. + self.size = len(self.data) + + def __len__(self): + return self.size + + def load_data(self, file_path: str): + """ + Overview: + Load raw data from given file_path. + """ + with open(file_path, 'r') as f: + data: List[List[str]] = json.load(f) + + output: List[List[str]] = [sample for sample in data if all(sample)] + del data + + return output + + def final_generator(self): + data_generator = self.batch_generator() + for batch_samples in data_generator: + batch = self.batchify(batch_samples) + yield batch + + def __iter__(self): + return self.final_generator() + + def format(self, sample: List[str]) -> Dict[str, Any]: + """ + Overview: + Convert one data sample in to string. 
+ """ + context = sample + context = [get_special_prompt(i + (len(context) + 1) % 2) + s for i, s in enumerate(context)] + context_vec = self.tokenizer.encode(get_model_prompt(context, self.tokenizer.eos_token), + add_special_tokens=True) + + # truncate to max_len + while len(context_vec) > self.maxlen_prompt - self.maxlen_res and len(context) > 1: + context = context[1:] + context_vec = self.tokenizer.encode(get_model_prompt(context, self.tokenizer.eos_token), + add_special_tokens=True) + + output = { + 'text': self.tokenizer.decode(context_vec, skip_special_tokens=False), + 'text_vec': context_vec + } + + return output + + def batchify(self, batch_samples: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Overview: + Batchify a list of ids by padding their shape to be the same. + """ + batch_text_vec = torch.tensor(pad_sequences( + [sample['text_vec'] for sample in batch_samples], pad_value=self.tokenizer.pad_token_id, padding='left' + ), dtype=torch.long) + return { + 'text_vec': batch_text_vec, + 'text': [sample['text'] for sample in batch_samples] + } + + def sample_generator(self): + """ + Overview: + Generate a single data sample. + """ + random.seed(None) + if self.mode == 'train': + random.shuffle(self.data) + + for sample in self.data: + yield self.format(sample) + + def _batch_generator(self): + """ + Overview: + Generate a batch of samples. + """ + batch = [] + # Generate a sample. + for sample in self.sample_generator(): + sample_len = len(sample['text_vec']) + if sample_len > self.maxlen_prompt: + continue + + batch.append(sample) + if len(batch) >= self.batch_size: + yield batch[:self.batch_size] + batch = batch[self.batch_size:] + if batch: + yield batch + + def batch_generator(self): + while True: + for batch in self._batch_generator(): + if len(batch) == self.batch_size: + yield batch + if self.mode != 'train': + break diff --git a/launch_ppof.py b/launch_ppof.py new file mode 100644 index 0000000000..53825dcba2 --- /dev/null +++ b/launch_ppof.py @@ -0,0 +1,50 @@ +from easydict import EasyDict +from transformers import LlamaTokenizer + +from ding.bonus.ppof import PPOF +from ding.model.template.vac import LlamaVAC + + +def get_tokenizer(path: str): + """ + Overview: + Return the pretrained tokenizer using the given path. + """ + tokenizer = LlamaTokenizer.from_pretrained(path, trust_remote_code=True) + tokenizer.bos_token = '' + tokenizer.eos_token = '' + tokenizer.pad_token = '' + tokenizer.pad_token_id = 0 + tokenizer.unk_token = tokenizer.pad_token + tokenizer.unk_token_id = tokenizer.pad_token_id + + return tokenizer + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--actor_path', type=str) + parser.add_argument('--critic_path', type=str) + args = parser.parse_args() + + opt = EasyDict({ + "maxlen_res": 512, + "temperature": 1, + "repetition_penalty": 1, + "topp": 0.8 + }) + + model = LlamaVAC( + actor_path=args.actor_path, + critic_path=args.critic_path, + tokenizer=get_tokenizer("/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-sft-model-7B-en"), + opt=opt + ) + + policy = PPOF( + env_id="chat", + exp_name="rlhf-ppo", + model=model + ) + policy.train(collector_env_num=1, evaluator_env_num=1, debug=True)