diff --git a/ding/bonus/config.py b/ding/bonus/config.py
index 285eff6586..ddb73fb235 100644
--- a/ding/bonus/config.py
+++ b/ding/bonus/config.py
@@ -167,6 +167,13 @@ def get_instance_config(env_id: str, algorithm: str) -> EasyDict:
cfg.batch_size = 320
cfg.epoch_per_collect = 10
cfg.learning_rate = 3e-4
+ elif env_id == 'chat':
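+ # RLHF fine-tuning on chat data: single-sample batches, a very small learning rate, several sampled
+ # answers per prompt, and a KL penalty towards the frozen reference (SFT) model.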
+ cfg.epoch_per_collect = 1
+ cfg.batch_size = 1
+ cfg.learning_rate = 5e-7
+ cfg.answers_per_question = 3
+ cfg.kl_penalty_weight = 0.1
+ cfg.ppo_param_init = False
+ cfg.chat_data = True
else:
raise KeyError("not supported env type: {}".format(env_id))
else:
@@ -315,6 +322,16 @@ def get_instance_env(env_id: str) -> BaseEnv:
)
cfg = EasyDict(cfg)
return DriveEnvWrapper(MetaDrivePPOOriginEnv(cfg))
+ elif env_id == 'chat':
+ from dizoo.chat.env import ChatEnv
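+ # NOTE: the reward model / tokenizer / data paths below come from the original MOSS-RLHF experiment and
+ # are machine-specific; adapt them to your own setup.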
+ return ChatEnv(
+ batch_size=1,
+ reward_model_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en/recover",
+ tokenizer_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en",
+ data_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/data/ppo_data",
+ maxlen_prompt=128,
+ maxlen_res=128,
+ )
else:
raise KeyError("not supported env type: {}".format(env_id))
diff --git a/ding/bonus/ppof.py b/ding/bonus/ppof.py
index 88d0b43e1e..141f0d3b4d 100644
--- a/ding/bonus/ppof.py
+++ b/ding/bonus/ppof.py
@@ -1,3 +1,4 @@
+import copy
from typing import Optional, Union, List
from ditk import logging
from easydict import EasyDict
@@ -9,7 +10,7 @@
import torch
from ding.framework import task, OnlineRLContext
from ding.framework.middleware import interaction_evaluator_ttorch, PPOFStepCollector, multistep_trainer, CkptSaver, \
- wandb_online_logger, offline_data_saver, termination_checker, ppof_adv_estimator
+ wandb_online_logger, offline_data_saver, termination_checker, ppof_adv_estimator, ChatCollector
from ding.envs import BaseEnv, BaseEnvManagerV2, SubprocessEnvManagerV2
from ding.policy import PPOFPolicy, single_env_forward_wrapper_ttorch
from ding.utils import set_pkg_seed
@@ -62,6 +63,8 @@ class PPOF:
'Hopper-v3',
'HalfCheetah-v3',
'Walker2d-v3',
+ # rlhf
+ 'chat'
]
"""
Overview:
@@ -170,6 +173,8 @@ def __init__(
action_shape = int(action_space.n)
elif isinstance(action_space, (gym.spaces.Tuple, gymnasium.spaces.Tuple)):
action_shape = get_hybrid_shape(action_space)
+ elif action_space is None:
+ # The chat env exposes no explicit action space; actions are token sequences produced by the model.
+ pass
else:
action_shape = action_space.shape
@@ -191,7 +196,11 @@ def __init__(
popart_head=True,
**self.cfg.model
)
- self.policy = PPOFPolicy(self.cfg, model=model)
+ if self.cfg.chat_data:
+ # Keep a frozen copy of the initial (SFT) model as the reference policy for the KL penalty.
+ orig_model = copy.deepcopy(model)
+ else:
+ orig_model = None
+ self.policy = PPOFPolicy(self.cfg, model=model, orig_model=orig_model)
if policy_state_dict is not None:
self.policy.load_state_dict(policy_state_dict)
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
@@ -246,10 +255,14 @@ def train(
pass
with task.start(ctx=OnlineRLContext()):
- task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
- task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
- task.use(PPOFStepCollector(self.seed, self.policy, collector_env, self.cfg.n_sample))
- task.use(ppof_adv_estimator(self.policy))
+ if self.policy._cfg.chat_data:
+ # task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
+ # task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
+ task.use(ChatCollector(self.seed, self.policy, collector_env, self.cfg.n_sample))
+ else:
+ task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
+ task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
+ task.use(PPOFStepCollector(self.seed, self.policy, collector_env, self.cfg.n_sample))
task.use(multistep_trainer(self.policy, log_freq=n_iter_log_show))
task.use(
wandb_online_logger(
diff --git a/ding/framework/middleware/__init__.py b/ding/framework/middleware/__init__.py
index b9e3c5005d..43cea6883b 100644
--- a/ding/framework/middleware/__init__.py
+++ b/ding/framework/middleware/__init__.py
@@ -1,5 +1,5 @@
from .functional import *
-from .collector import StepCollector, EpisodeCollector, PPOFStepCollector
+from .collector import StepCollector, EpisodeCollector, PPOFStepCollector, ChatCollector
from .learner import OffPolicyLearner, HERLearner
from .ckpt_handler import CkptSaver
from .distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger
diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py
index beb4894ad9..6017940396 100644
--- a/ding/framework/middleware/collector.py
+++ b/ding/framework/middleware/collector.py
@@ -1,3 +1,4 @@
+import copy
from typing import TYPE_CHECKING
from easydict import EasyDict
import treetensor.torch as ttorch
@@ -190,4 +191,64 @@ def __call__(self, ctx: "OnlineRLContext") -> None:
break
+class ChatCollector:
+ """
+ Overview:
+ The collector for chat (RLHF) data. For each batch of prompts, the actor generates \
+ ``answers_per_question`` responses per prompt, the environment scores them with the reward model, \
+ and the rewarded responses are packed into ``ctx.train_data``. Use the `__call__` method to execute \
+ the whole collection process.
+ """
+
+ def __new__(cls, *args, **kwargs):
+ if task.router.is_active and not task.has_role(task.role.COLLECTOR):
+ return task.void()
+ return super(ChatCollector, cls).__new__(cls)
+
+ def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll_len: int = 1) -> None:
+ """
+ Arguments:
+ - seed (:obj:`int`): Random seed.
+ - policy (:obj:`Policy`): The policy to be collected.
+ - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
+ its derivatives are supported.
+ - n_sample (:obj:`int`): The number of samples to collect per call.
+ - unroll_len (:obj:`int`): The unroll length of the collected samples, defaults to 1.
+ """
+ self.env = env
+ self.env.seed(seed)
+ self.env.launch()
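+ # The chat env manager wraps a single ChatEnv; unwrap it and interact with the env directly.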
+ self.env = self.env._envs[0]
+ self.policy = policy
+ self.n_sample = n_sample
+ self.unroll_len = unroll_len
+
+ def __call__(self, ctx: "OnlineRLContext") -> None:
+ """
+ Overview:
+ An encapsulation of inference and rollout middleware. Collect one batch of prompts, generate \
+ ``answers_per_question`` responses for each prompt, query the reward model through the env, and \
+ store the rewarded responses in ``ctx.train_data``.
+ Input of ctx:
+ - env_step (:obj:`int`): The env steps which will increase during collection.
+ """
+ device = self.policy._device
+
+ obs = ttorch.as_tensor(self.env.last_batch['text_vec'])
+ batch_size = obs.shape[0]
+ obs = obs.to(device)
+
+ total_action = [[] for _ in range(batch_size)] # [B, answers_per_question, T]
+ for _ in range(self.policy._cfg.answers_per_question):
+ _, inference_output = self.policy._model.actor.generate(obs, **ctx.collect_kwargs)
+ for i in range(batch_size):
+ total_action[i].append(copy.deepcopy(inference_output[i]))
+
+ mask, resp, rew = self.env.step(total_action)
+ ctx.env_step += 1
+ ctx.env_episode += 1
+
+ train_data = {}
+ train_data['obs'] = resp # [B x answers_per_question, T]
+ train_data['reward'] = rew # [B x answers_per_question, ]
+ train_data['mask'] = mask # [B x answers_per_question, T]
+
+ ctx.train_data = ttorch.as_tensor(train_data)
+
+
# TODO battle collector
diff --git a/ding/model/common/__init__.py b/ding/model/common/__init__.py
index 4bf7d8be5a..5bf1fba4d5 100755
--- a/ding/model/common/__init__.py
+++ b/ding/model/common/__init__.py
@@ -2,4 +2,4 @@
QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, BranchingHead, head_cls_map, \
independent_normal_dist, AttentionPolicyHead, PopArtVHead, EnsembleHead
from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder
-from .utils import create_model
+from .utils import create_model, top_p_logits
diff --git a/ding/model/common/tests/test_utils.py b/ding/model/common/tests/test_utils.py
new file mode 100644
index 0000000000..9a6f688e2d
--- /dev/null
+++ b/ding/model/common/tests/test_utils.py
@@ -0,0 +1,15 @@
+import pytest
+import torch
+from ding.model.common.utils import top_p_logits
+
+
+@pytest.mark.unittest
+class TestUtils:
+
+ def test_top_p_logits(self):
+ test_logit = torch.Tensor([[0., 0.91, 0.05, 0.04], [0.04, 0.46, 0.46, 0.04]])
+
+ gt_logit = torch.Tensor([[0., 1., 0., 0.], [0., 0.5, 0.5, 0.]])
+
+ pred_logit = top_p_logits(test_logit)
+ assert torch.sum((gt_logit - pred_logit) ** 2).item() < 1e-8
diff --git a/ding/model/common/utils.py b/ding/model/common/utils.py
index f74a179962..5208cf342c 100644
--- a/ding/model/common/utils.py
+++ b/ding/model/common/utils.py
@@ -29,3 +29,29 @@ def create_model(cfg: EasyDict) -> torch.nn.Module:
import_module(cfg.pop('import_names', []))
# here we must use the pop operation to ensure compatibility
return MODEL_REGISTRY.build(cfg.pop("type"), **cfg)
+
+
+def top_p_logits(logits: torch.Tensor, topp: float = 0.9, filter_value: float = 0, min_topk: int = 1):
+ """
+ Overview:
+ Filter a probability distribution using nucleus (top-p) filtering: the smallest set of entries whose \
+ cumulative probability reaches ``topp`` is kept, the rest are set to ``filter_value``, and the result \
+ is renormalized so that it sums to 1 along the last dimension.
+ Arguments:
+ - logits (:obj:`torch.Tensor`): The input distribution (e.g. softmax probabilities) for top-p sampling.
+ - topp (:obj:`float`): The top-p value, such as 0.9.
+ - filter_value (:obj:`float`): The value assigned to masked entries, default as 0.
+ - min_topk (:obj:`int`): The minimum number of entries to keep, default as 1 (at least the largest \
+ entry is never masked).
+ Returns:
+ - cum_logits (:obj:`torch.Tensor`): The renormalized distribution after masking.
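+ Examples:
+ >>> # A minimal sketch; the input is expected to be a normalized distribution (e.g. a softmax output).
+ >>> probs = torch.Tensor([[0., 0.91, 0.05, 0.04]])
+ >>> top_p_logits(probs, topp=0.9)
+ tensor([[0., 1., 0., 0.]])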
+ """
+ cum_logits = logits.clone()
+ if topp > 0:
+ logits_sorted, inds = torch.sort(logits, dim=-1, descending=True)
+ mask = (logits_sorted.cumsum(dim=-1) - logits_sorted) >= topp
+ mask[..., :min_topk] = False
+ # Remove tokens with cumulative top_p above the threshold
+ mask = torch.zeros_like(mask).to(torch.bool).scatter_(dim=-1, index=inds, src=mask)
+ cum_logits[mask] = filter_value
+ cum_logits.div_(cum_logits.sum(dim=-1, keepdim=True))
+ return cum_logits
diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py
index 4a63c3dcc6..ec39522608 100755
--- a/ding/model/template/__init__.py
+++ b/ding/model/template/__init__.py
@@ -5,6 +5,7 @@
from .vac import VAC, DREAMERVAC
from .bc import DiscreteBC, ContinuousBC
from .language_transformer import LanguageTransformer
+from .lm_vac import LlamaVAC
# algorithm-specific
from .pg import PG
from .ppg import PPG
diff --git a/ding/model/template/lm_vac.py b/ding/model/template/lm_vac.py
new file mode 100644
index 0000000000..747b1d1d8f
--- /dev/null
+++ b/ding/model/template/lm_vac.py
@@ -0,0 +1,202 @@
+from typing import Dict
+import torch
+import torch.nn as nn
+try:
+ from transformers import LlamaTokenizer
+ from transformers.models.llama.modeling_llama import LlamaForCausalLM
+except ImportError:
+ from ditk import logging
+ logging.warning("transformers is not installed. Please install it with: pip install transformers")
+
+from ding.model.common import top_p_logits
+from ding.reward_model import LlamaRewardModel
+from ding.utils import MODEL_REGISTRY
+
+
+def get_tokenizer(path: str):
+ """
+ Overview:
+ Return the pretrained tokenizer using the given path.
+ """
+ tokenizer = LlamaTokenizer.from_pretrained(path, trust_remote_code=True)
+ tokenizer.bos_token = ''
+ tokenizer.eos_token = ''
+ tokenizer.pad_token = ''
+ tokenizer.pad_token_id = 0
+ tokenizer.unk_token = tokenizer.pad_token
+ tokenizer.unk_token_id = tokenizer.pad_token_id
+
+ return tokenizer
+
+
+class Llama(LlamaForCausalLM):
+
+ def __init__(self, config, opt, tokenizer, enable_checkpointing):
+ super().__init__(config)
+ self.opt = opt
+ self.tokenizer = tokenizer
+ self.enable_checkpointing = enable_checkpointing
+
+ def forward(self, decoder_input, incr_state=None, is_train=True):
+
+ attention_mask = decoder_input.ne(self.tokenizer.pad_token_id)
+ if incr_state is not None:
+ decoder_input = decoder_input[:, -1:]
+
+ output = super().forward(
+ input_ids=decoder_input,
+ attention_mask=attention_mask,
+ past_key_values=incr_state,
+ return_dict=True,
+ use_cache=not is_train
+ )
+
+ logits = output.logits
+ new_incr_states = output.past_key_values
+
+ return logits, new_incr_states
+
+ @torch.no_grad()
+ def generate(self, batch, **kwargs):
+ """
+ Generate responses for a batch of left-padded prompts with nucleus sampling, and return the generated \
+ token ids together with their length-normalized log-probability scores.
+ """
+ if self.enable_checkpointing:
+ self.gradient_checkpointing_disable()
+ maxlen_res = kwargs.pop('maxlen_res', self.opt.maxlen_res)
+ temperature = kwargs.pop('temperature', self.opt.temperature)
+ repetition_penalty = kwargs.pop('repetition_penalty', self.opt.repetition_penalty)
+ topp = kwargs.pop('topp', self.opt.topp)
+
+ decoder_input: torch.LongTensor = batch # (bsz, ...)
+ assert decoder_input[:, -1].ne(
+ self.tokenizer.pad_token_id
+ ).all(), 'Last token should not be a padding token (you can use left padding instead).'
+
+ dev = decoder_input.device
+ bsz = decoder_input.size(0)
+
+ scores = torch.zeros((bsz, ), device=dev, dtype=torch.float16)
+ done = torch.zeros((bsz, ), device=dev).to(torch.bool)
+
+ inds = torch.arange(bsz).to(dev).unsqueeze(1).view(-1)
+ decoder_input = torch.index_select(decoder_input, 0, inds)
+ init_length = decoder_input.size(1)
+
+ incr_state = None
+ for _token in range(maxlen_res):
+ if done.all():
+ break
+ score, incr_state, *_ = self.forward(decoder_input, incr_state, is_train=False)
+ score = score.half()
+
+ # now score is bs, len, vocab_size
+ score = score[:, -1, :]
+
+ # calculate repetition penalty
+ if repetition_penalty > 1.:
+ penalty_tokens = decoder_input[:, init_length:]
+ penalty_scores = torch.gather(score, dim=1, index=penalty_tokens)
+ penalty_scores = torch.where(
+ penalty_scores < 0., penalty_scores * repetition_penalty, penalty_scores / repetition_penalty
+ )
+ score = score.scatter_(dim=1, index=penalty_tokens, src=penalty_scores)
+
+ # nucleus sampling
+ score = torch.softmax(score.div(temperature), dim=-1)
+ probs = top_p_logits(score, topp=topp, filter_value=0)
+ tok_ids = torch.multinomial(probs, 1)[:, 0]
+ hyp_ids = torch.arange(probs.size(0), device=dev)
+ scores = scores + probs[hyp_ids, tok_ids].log() * ~done
+
+ tok_ids = torch.where(done, self.tokenizer.pad_token_id, tok_ids)
+ decoder_input = torch.cat((decoder_input, tok_ids.unsqueeze(-1)), dim=-1)
+ done = done | tok_ids.eq(self.tokenizer.eos_token_id)
+
+ incr_state = self._reorder_cache(incr_state, hyp_ids)
+
+ # get all finalized candidates for each sample
+ decoder_input = decoder_input[:, init_length:]
+ decoder_input = decoder_input.view(bsz, -1)
+ scores = scores.view(bsz, )
+
+ lengths = decoder_input.ne(self.tokenizer.pad_token_id).sum(dim=-1)
+
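+ # Length-normalize the accumulated log-probabilities (exponent 1.0, i.e. divide by the response length).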
+ length_penalty = torch.pow(lengths, 1.0)
+ scores /= length_penalty
+
+ preds_scores = []
+ for i in range(bsz):
+ seq: torch.LongTensor = decoder_input[i, :lengths[i, ]]
+ res_scores = (float(scores[i, ]), seq.tolist())
+ preds_scores.append([res_scores])
+
+ best_preds_scores = [preds[0] for preds in preds_scores]
+ if self.enable_checkpointing:
+ self.gradient_checkpointing_enable()
+ return best_preds_scores, preds_scores
+
+
+@MODEL_REGISTRY.register('llamavac')
+class LlamaVAC(nn.Module):
+ """
+ Overview:
+ The neural network and computation graph of LlamaVAC (Llama value-actor-critic). Both the actor and the \
+ critic are initialized from pretrained Llama models: the actor is a causal language model and the \
+ critic is a Llama-based value model with a scalar head.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+ mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']
+
+ def __init__(
+ self,
+ actor_path: str,
+ critic_path: str,
+ tokenizer_path: str,
+ opt: Dict,
+ enable_checkpointing: bool = True
+ ) -> None:
+ """
+ Overview:
+ Initialize the ``LlamaVAC`` model according to arguments.
+ Arguments:
+ - actor_path (:obj:`str`): Pretrained model path for actor.
+ - critic_path (:obj:`str`): Pretrained model path for critic.
+ - tokenizer_path (:obj:`str`): Path of the pretrained tokenizer shared by the actor and the critic.
+ - opt (:obj:`Dict`): Generation options for this model (e.g. ``maxlen_res``, ``temperature``, ``topp``).
+ - enable_checkpointing (:obj:`bool`): Whether to enable gradient checkpointing to save GPU memory.
+ """
+ super(LlamaVAC, self).__init__()
+ tokenizer = get_tokenizer(tokenizer_path)
+ self.enable_checkpointing = enable_checkpointing
+
+ self.actor = Llama.from_pretrained(
+ actor_path,
+ opt=opt,
+ tokenizer=tokenizer,
+ torch_dtype=torch.bfloat16,
+ enable_checkpointing=enable_checkpointing
+ )
+
+ self.critic = LlamaRewardModel.from_pretrained(critic_path, tokenizer=tokenizer, torch_dtype=torch.bfloat16)
+
+ if enable_checkpointing:
+ self.actor.gradient_checkpointing_enable()
+ self.critic.gradient_checkpointing_enable()
+
+ def forward(self, x: torch.Tensor, mode: str) -> Dict:
+ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
+ return getattr(self, mode)(x)
+
+ def compute_actor(self, x):
+ policy_output = self.actor(decoder_input=x)
+ policy_logit, *_ = policy_output
+ return {"logit": policy_logit}
+
+ def compute_critic(self, x):
+ values = self.critic(decoder_input=x, only_last=False)
+ return {"value": values}
+
+ def compute_actor_critic(self, x):
+ policy_output = self.actor(decoder_input=x)
+ policy_logit, *_ = policy_output
+ values = self.critic(decoder_input=x, only_last=False)
+ return {"logit": policy_logit, "value": values}
diff --git a/ding/policy/ppof.py b/ding/policy/ppof.py
index 81e605384c..3dbfb05581 100644
--- a/ding/policy/ppof.py
+++ b/ding/policy/ppof.py
@@ -1,4 +1,4 @@
-from typing import List, Dict, Any, Tuple, Union, Callable, Optional
+from typing import List, Dict, Any, Callable, Optional
from collections import namedtuple
from easydict import EasyDict
import copy
@@ -7,10 +7,13 @@
import torch
import treetensor.torch as ttorch
from torch.optim import AdamW
+from torch.cuda.amp import GradScaler
+from torch import autocast
from ding.rl_utils import ppo_data, ppo_error, ppo_policy_error, ppo_policy_data, gae, gae_data, ppo_error_continuous, \
get_gae, ppo_policy_error_continuous, ArgmaxSampler, MultinomialSampler, ReparameterizationSampler, MuSampler, \
HybridStochasticSampler, HybridDeterminsticSampler, value_transform, value_inv_transform, symlog, inv_symlog
+from ding.rl_utils.gae import episodic_gae_data, episodic_gae
from ding.utils import POLICY_REGISTRY, RunningMeanStd
@@ -37,6 +40,7 @@ class PPOFPolicy:
value_norm='baseline',
ppo_param_init=True,
grad_norm=0.5,
+ chat_data=False,  # set to True (e.g. by the 'chat' env config) to enable the RLHF training branch
# collect
n_sample=128,
unroll_len=1,
@@ -58,8 +62,15 @@ def default_model(cls: type) -> Callable:
from .model import PPOFModel
return PPOFModel
- def __init__(self, cfg: "EasyDict", model: torch.nn.Module, enable_mode: List[str] = None) -> None:
+ def __init__(
+ self,
+ cfg: "EasyDict",
+ model: torch.nn.Module,
+ enable_mode: List[str] = None,
+ orig_model: torch.nn.Module = None
+ ) -> None:
self._cfg = cfg
+ self._orig_model = orig_model
if model is None:
self._model = self.default_model()
else:
@@ -151,69 +162,90 @@ def _model_param_init(self):
def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
return_infos = []
self._model.train()
- bs = self._cfg.batch_size
- data = data[:self._cfg.n_sample // bs * bs] # rounding
+ if not self._cfg.chat_data:
+ bs = self._cfg.batch_size
+ data = data[:self._cfg.n_sample // bs * bs] # rounding
# outer training loop
for epoch in range(self._cfg.epoch_per_collect):
# recompute adv
with torch.no_grad():
- # get the value dictionary
- # In popart, the dictionary has two keys: 'pred' and 'unnormalized_pred'
- value = self._model.compute_critic(data.obs)
- next_value = self._model.compute_critic(data.next_obs)
- reward = data.reward
-
- assert self._cfg.value_norm in ['popart', 'value_rescale', 'symlog', 'baseline'],\
- 'Not supported value normalization! Value normalization supported: \
- popart, value rescale, symlog, baseline'
-
- if self._cfg.value_norm == 'popart':
- unnormalized_value = value['unnormalized_pred']
- unnormalized_next_value = value['unnormalized_pred']
-
- mu = self._model.critic_head.popart.mu
- sigma = self._model.critic_head.popart.sigma
- reward = (reward - mu) / sigma
-
- value = value['pred']
- next_value = next_value['pred']
- elif self._cfg.value_norm == 'value_rescale':
- value = value_inv_transform(value['pred'])
- next_value = value_inv_transform(next_value['pred'])
- elif self._cfg.value_norm == 'symlog':
- value = inv_symlog(value['pred'])
- next_value = inv_symlog(next_value['pred'])
- elif self._cfg.value_norm == 'baseline':
- value = value['pred'] * self._running_mean_std.std
- next_value = next_value['pred'] * self._running_mean_std.std
-
- traj_flag = data.get('traj_flag', None) # traj_flag indicates termination of trajectory
- adv_data = gae_data(value, next_value, reward, data.done, traj_flag)
- data.adv = gae(adv_data, self._cfg.discount_factor, self._cfg.gae_lambda)
-
- unnormalized_returns = value + data.adv # In popart, this return is normalized
-
- if self._cfg.value_norm == 'popart':
- self._model.critic_head.popart.update_parameters((data.reward).unsqueeze(1))
- elif self._cfg.value_norm == 'value_rescale':
- value = value_transform(value)
- unnormalized_returns = value_transform(unnormalized_returns)
- elif self._cfg.value_norm == 'symlog':
- value = symlog(value)
- unnormalized_returns = symlog(unnormalized_returns)
- elif self._cfg.value_norm == 'baseline':
- value /= self._running_mean_std.std
- unnormalized_returns /= self._running_mean_std.std
- self._running_mean_std.update(unnormalized_returns.cpu().numpy())
- data.value = value
- data.return_ = unnormalized_returns
+ if self._cfg.chat_data:
+ # [B, T]
+ value = self._model.compute_critic(data.obs)['value']
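+ # Swap the trained model and the frozen reference model between CPU and GPU so that only one
+ # large model occupies GPU memory while computing the reference logits for the KL penalty.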
+ self._model.cpu()
+ self._orig_model.cuda()
+ data.orig_logit = self._orig_model.compute_actor(data.obs)['logit']
+ self._orig_model.cpu()
+ self._model.cuda()
+ data.value = value
+ reward = data.reward
+
+ traj_flag = data.get('traj_flag', None) # traj_flag indicates termination of trajectory
+ done = data.get('done', None)
+ adv_data = episodic_gae_data(value, data.mask, reward, done, traj_flag)
+ data.adv = episodic_gae(adv_data, self._cfg.discount_factor, self._cfg.gae_lambda)
+
+ unnormalized_returns = data.value + data.adv
+ data.return_ = unnormalized_returns
+ else:
+ # get the value dictionary
+ # In popart, the dictionary has two keys: 'pred' and 'unnormalized_pred'
+ value = self._model.compute_critic(data.obs)
+ next_value = self._model.compute_critic(data.next_obs)
+ reward = data.reward
+
+ assert self._cfg.value_norm in ['popart', 'value_rescale', 'symlog', 'baseline'], \
+ 'Not supported value normalization! Value normalization supported: \
+ popart, value rescale, symlog, baseline'
+
+ if self._cfg.value_norm == 'popart':
+ unnormalized_value = value['unnormalized_pred']
+ unnormalized_next_value = value['unnormalized_pred']
+
+ mu = self._model.critic_head.popart.mu
+ sigma = self._model.critic_head.popart.sigma
+ reward = (reward - mu) / sigma
+
+ value = value['pred']
+ next_value = next_value['pred']
+ elif self._cfg.value_norm == 'value_rescale':
+ value = value_inv_transform(value['pred'])
+ next_value = value_inv_transform(next_value['pred'])
+ elif self._cfg.value_norm == 'symlog':
+ value = inv_symlog(value['pred'])
+ next_value = inv_symlog(next_value['pred'])
+ elif self._cfg.value_norm == 'baseline':
+ value = value['pred'] * self._running_mean_std.std
+ next_value = next_value['pred'] * self._running_mean_std.std
+
+ traj_flag = data.get('traj_flag', None) # traj_flag indicates termination of trajectory
+ adv_data = gae_data(value, next_value, reward, data.done, traj_flag)
+ data.adv = gae(adv_data, self._cfg.discount_factor, self._cfg.gae_lambda)
+
+ unnormalized_returns = value + data.adv # In popart, this return is normalized
+
+ if self._cfg.value_norm == 'popart':
+ self._model.critic_head.popart.update_parameters((data.reward).unsqueeze(1))
+ elif self._cfg.value_norm == 'value_rescale':
+ value = value_transform(value)
+ unnormalized_returns = value_transform(unnormalized_returns)
+ elif self._cfg.value_norm == 'symlog':
+ value = symlog(value)
+ unnormalized_returns = symlog(unnormalized_returns)
+ elif self._cfg.value_norm == 'baseline':
+ value /= self._running_mean_std.std
+ unnormalized_returns /= self._running_mean_std.std
+ self._running_mean_std.update(unnormalized_returns.cpu().numpy())
+ data.value = value
+ data.return_ = unnormalized_returns
# inner training loop
split_data = ttorch.split(data, self._cfg.batch_size)
random.shuffle(list(split_data))
for batch in split_data:
- output = self._model.compute_actor_critic(batch.obs)
+ if not self._cfg.chat_data:
+ output = self._model.compute_actor_critic(batch.obs)
adv = batch.adv
if self._cfg.adv_norm:
# Normalize advantage in a train_batch
@@ -226,10 +258,27 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
)
ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._cfg.clip_ratio)
elif self._action_space == 'discrete':
- ppo_batch = ppo_data(
- output.logit, batch.logit, batch.action, output.value, batch.value, adv, batch.return_, None
- )
- ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio)
+ if not self._cfg.chat_data:
+ ppo_batch = ppo_data(
+ output.logit, batch.logit, batch.action, output.value, batch.value, adv, batch.return_, None
+ )
+ ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio)
+ else:
+ with autocast(device_type='cuda', dtype=torch.float16):
+ output = self._model.compute_actor_critic(batch.obs)
+ mask = batch.mask
+ ppo_batch = ppo_data(
+ output['logit'], batch.orig_logit, batch.obs, output['value'], batch.value, adv,
+ batch.return_, None
+ )
+ ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio)
+ # note: F.kl_div expects log-probabilities as its first argument
+ kl_loss = (
+ torch.nn.functional.kl_div(
+ torch.log_softmax(output["logit"], dim=-1),
+ torch.softmax(batch.orig_logit, dim=-1),
+ reduction='none'
+ ) * mask.unsqueeze(-1)
+ ).mean()
elif self._action_space == 'hybrid':
# discrete part (discrete policy loss and entropy loss)
ppo_discrete_batch = ppo_policy_data(
@@ -253,13 +302,20 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
max(ppo_continuous_info.approx_kl, ppo_discrete_info.approx_kl),
max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac)
)
- wv, we = self._cfg.value_weight, self._cfg.entropy_weight
- total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss
-
- self._optimizer.zero_grad()
- total_loss.backward()
- torch.nn.utils.clip_grad_norm_(self._model.parameters(), self._cfg.grad_norm)
- self._optimizer.step()
+ if not self._cfg.chat_data:
+ wv, we = self._cfg.value_weight, self._cfg.entropy_weight
+ total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss
+ self._optimizer.zero_grad()
+ total_loss.backward()
+ torch.nn.utils.clip_grad_norm_(self._model.parameters(), self._cfg.grad_norm)
+ self._optimizer.step()
+ else:
+ wv, we, wk = self._cfg.value_weight, self._cfg.entropy_weight, self._cfg.kl_penalty_weight
+ total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + wk * kl_loss
+ output = ttorch.as_tensor(output)
+ self._optimizer.zero_grad()
+ total_loss.backward()
+ self._optimizer.step()
return_info = {
'cur_lr': self._optimizer.defaults['lr'],
@@ -267,6 +323,7 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
'policy_loss': ppo_loss.policy_loss.item(),
'value_loss': ppo_loss.value_loss.item(),
'entropy_loss': ppo_loss.entropy_loss.item(),
+ 'kl_loss': kl_loss.item() if self._cfg.chat_data else 0.,
'adv_max': adv.max().item(),
'adv_mean': adv.mean().item(),
'value_mean': output.value.mean().item(),
diff --git a/ding/reward_model/__init__.py b/ding/reward_model/__init__.py
index 4538102861..5b197af25d 100644
--- a/ding/reward_model/__init__.py
+++ b/ding/reward_model/__init__.py
@@ -13,3 +13,5 @@
from .guided_cost_reward_model import GuidedCostRewardModel
from .ngu_reward_model import RndNGURewardModel, EpisodicNGURewardModel
from .icm_reward_model import ICMRewardModel
+# RLHF
+from .language_reward_model import LlamaRewardModel
diff --git a/ding/reward_model/language_reward_model.py b/ding/reward_model/language_reward_model.py
new file mode 100644
index 0000000000..ff2558099a
--- /dev/null
+++ b/ding/reward_model/language_reward_model.py
@@ -0,0 +1,28 @@
+import torch
+try:
+ from transformers.models.llama.modeling_llama import LlamaForCausalLM
+except ImportError:
+ from ditk import logging
+ logging.warning("transformers is not installed. Please install it with: pip install transformers")
+
+
+class LlamaRewardModel(LlamaForCausalLM):
+
+ def __init__(self, config, tokenizer):
+ super().__init__(config)
+ self.tokenizer = tokenizer
+ self.reward_head = torch.nn.Linear(config.hidden_size, 1, bias=False)
+
+ def forward(self, decoder_input, only_last=True):
+ attention_mask = decoder_input.ne(self.tokenizer.pad_token_id)
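+ # The transformer backbone runs under no_grad; only the scalar reward head receives gradients.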
+ with torch.no_grad():
+ output = self.model.forward(
+ input_ids=decoder_input, attention_mask=attention_mask, return_dict=True, use_cache=False
+ )
+
+ if only_last:
+ logits = self.reward_head(output.last_hidden_state[:, -1, :]).squeeze(-1)
+ else:
+ logits = self.reward_head(output.last_hidden_state).squeeze(-1)
+
+ return logits
diff --git a/ding/rl_utils/gae.py b/ding/rl_utils/gae.py
index 800fcae354..16b8313750 100644
--- a/ding/rl_utils/gae.py
+++ b/ding/rl_utils/gae.py
@@ -3,6 +3,7 @@
from ding.hpc_rl import hpc_wrapper
gae_data = namedtuple('gae_data', ['value', 'next_value', 'reward', 'done', 'traj_flag'])
+episodic_gae_data = namedtuple('episodic_gae_data', ['value', 'mask', 'reward', 'done', 'traj_flag'])
def shape_fn_gae(args, kwargs):
@@ -68,3 +69,24 @@ def gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97) -> torch.F
gae_item = delta[t] + factor[t] * gae_item
adv[t] = gae_item
return adv
+
+
+def episodic_gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97):
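+ """
+ Overview:
+ Compute GAE independently for every sample in a batch of token sequences. Each sample carries a single \
+ scalar reward (e.g. a reward-model score), which is assigned to its last token; the advantages are then \
+ computed per sample with the standard ``gae`` function and stacked back into a ``[B, T]`` tensor.
+ """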
+ value, mask, reward, done, traj_flag = data
+ if done is None:
+ done = torch.zeros_like(value)
+ if traj_flag is None:
+ traj_flag = done
+ advs = []
+ bsz = value.shape[0]
+ for i in range(bsz):
+ val, mas, rew, don, traj = value[i], mask[i], reward[i], done[i], traj_flag[i]
+ next_val = torch.zeros_like(val)
+ next_val[:-1] = val[1:]
+ # Build the per-step reward sequence: the scalar reward is assigned to the last token.
+ # Use a new name so that the batched ``reward`` tensor is not shadowed inside the loop.
+ rew_seq = torch.zeros_like(val)
+ rew_seq[-1] = rew
+ gd = gae_data(
+ val.unsqueeze(-1), next_val.unsqueeze(-1), rew_seq.unsqueeze(-1), don.unsqueeze(-1), traj.unsqueeze(-1)
+ )
+ advs.append(gae(gd, gamma, lambda_).squeeze(-1))
+ return torch.stack(advs, dim=0)
diff --git a/dizoo/chat/__init__.py b/dizoo/chat/__init__.py
new file mode 100644
index 0000000000..eb1eb48abb
--- /dev/null
+++ b/dizoo/chat/__init__.py
@@ -0,0 +1 @@
+from .env import ChatEnv
diff --git a/dizoo/chat/entry.py b/dizoo/chat/entry.py
new file mode 100644
index 0000000000..5da6ac620d
--- /dev/null
+++ b/dizoo/chat/entry.py
@@ -0,0 +1,34 @@
+from easydict import EasyDict
+
+from ding.bonus.ppof import PPOF
+from ding.model.template import LlamaVAC
+
+
+if __name__ == '__main__':
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--actor_path', type=str)
+ parser.add_argument('--critic_path', type=str)
+ parser.add_argument('--tokenizer_path', type=str)
+ args = parser.parse_args()
+
+ opt = EasyDict({
+ "maxlen_res": 512,
+ "temperature": 1,
+ "repetition_penalty": 1,
+ "topp": 0.8
+ })
+
+ model = LlamaVAC(
+ actor_path=args.actor_path,
+ critic_path=args.critic_path,
+ tokenizer_path=args.tokenizer_path,
+ opt=opt
+ )
+
+ policy = PPOF(
+ env_id="chat",
+ exp_name="rlhf-ppo",
+ model=model
+ )
+ policy.train(collector_env_num=1, evaluator_env_num=1, debug=True)
diff --git a/dizoo/chat/env.py b/dizoo/chat/env.py
new file mode 100644
index 0000000000..01d0f183cb
--- /dev/null
+++ b/dizoo/chat/env.py
@@ -0,0 +1,77 @@
+import torch
+
+from ding.envs import BaseEnv
+from ding.reward_model import LlamaRewardModel
+from .utils import OnlyPromptDataset, concat_context_and_response, get_tokenizer, pad_sequences
+
+
+class ChatEnv(BaseEnv):
+ def __init__(
+ self,
+ batch_size: int,
+ reward_model_path: str,
+ tokenizer_path: str,
+ data_path: str,
+ maxlen_prompt: int,
+ maxlen_res: int,
+ ):
+ self.batch_size = batch_size
+ self.tokenizer = get_tokenizer(tokenizer_path)
+ self.rm = LlamaRewardModel.from_pretrained(reward_model_path, tokenizer=self.tokenizer)
+ self.action_space = None
+ self.observation_space = None
+ self.reward_space = None
+
+ self._init_flag = False
+ self._seed = None
+
+ self.dataset = OnlyPromptDataset(
+ data_path=data_path,
+ tokenizer=self.tokenizer,
+ batch_size=batch_size,
+ maxlen_prompt=maxlen_prompt,
+ maxlen_res=maxlen_res,
+ mode='train',
+ )
+ self.generator = self.dataset.final_generator()
+ self.last_batch = None
+
+ def close(self) -> None:
+ self._init_flag = False
+
+ def reset(self):
+ self.last_batch = next(self.generator, None)
+ if self.last_batch is None:
+ self.generator = self.dataset.final_generator()
+ self.last_batch = next(self.generator)
+ self._init_flag = True
+ return self.last_batch
+
+ def __repr__(self) -> str:
+ return "DI-engine Chat Env"
+
+ def seed(self, seed):
+ self._seed = seed
+
+ def clone(self, caller):
+ # Return the env itself instead of a copy: the reward model is already loaded and duplicating it
+ # would keep another full language model in memory.
+ return self
+
+ def step(self, action):
+ """
+ Concatenate the input prompts with the generated responses, score the concatenated sequences with \
+ the reward model, and then load the next batch of prompts. The returned prompts are tokenized and \
+ left-padded to the same length.
+ """
+ output_mask, output_vec = concat_context_and_response(self.tokenizer, self.last_batch['text_vec'].tolist(), action)
+ output_vec = pad_sequences(output_vec, self.tokenizer.pad_token_id, padding='left')
+ rm_input = torch.tensor(output_vec, dtype=torch.long)
+ output_mask = pad_sequences(output_mask, self.tokenizer.pad_token_id, padding='left')
+ with torch.no_grad():
+ rew = self.rm(rm_input)
+
+ self.last_batch = next(self.generator, None)
+ if self.last_batch is None:
+ self.generator = self.dataset.final_generator()
+ self.last_batch = next(self.generator)
+
+ return output_mask, output_vec, rew
diff --git a/dizoo/chat/utils.py b/dizoo/chat/utils.py
new file mode 100644
index 0000000000..be05e76805
--- /dev/null
+++ b/dizoo/chat/utils.py
@@ -0,0 +1,244 @@
+import json
+import os
+from typing import List, Dict, Any, Tuple
+import warnings
+
+from transformers.models.llama.tokenization_llama import LlamaTokenizer
+from torch.utils.data.dataset import IterableDataset
+import torch
+import random
+
+
+# Prefix of human sentence and assistant sentence.
+HUMAN_PROMPT = "Human:"
+ASSISTANT_PROMPT = "Assistant:"
+
+
+def strip_pad_token_id(tokenizer: LlamaTokenizer, seq: List[int]):
+ """
+ Overview:
+ Remove ``pad_token_id`` in a sequence.
+ """
+ return [tok for tok in seq if tok != tokenizer.pad_token_id]
+
+
+def concat_context_and_response(
+ tokenizer: LlamaTokenizer,
+ context: List[List[int]],
+ responses: List[List[Tuple[float, List[int]]]]
+):
+ """
+ Overview:
+ Concatenate each input prompt with each of its generated responses. The returned mask is 0 for \
+ prompt tokens and 1 for response tokens, aligned with the concatenated token id sequences.
+ """
+ assert len(context) == len(responses), f'Size not match: {len(context)} and {len(responses)}'
+
+ total_context, total_response = [], []
+ total_context_mask, total_response_mask = [], []
+ for _context, _response in zip(context, responses):
+ # Each ``_context`` is a single input prompt.
+ _context = strip_pad_token_id(tokenizer, _context)
+ for resp in _response:
+ # Each ``resp`` is a single response.
+ resp = resp[0][1]
+ resp = strip_pad_token_id(tokenizer, resp)
+ if resp[-1] != tokenizer.eos_token_id:
+ warnings.warn(
+ f'Generated response does not end with EOS and was probably truncated: '
+ f'{tokenizer.decode(_context + resp, skip_special_tokens=False)}')
+
+ total_context.append(_context.copy())
+ total_context_mask.append([0] * len(_context))
+ total_response.append(resp)
+ total_response_mask.append([1] * len(resp))
+
+ total_gene_samples_vec = [c + r for c, r in zip(total_context, total_response)]
+ total_gene_samples_mask = [c + r for c, r in zip(total_context_mask, total_response_mask)]
+ return total_gene_samples_mask, total_gene_samples_vec
+
+
+def pad_sequences(
+ seqs: List[List[int]],
+ pad_value: int,
+ padding: str = 'right'):
+ """
+ Overview:
+ Pad a list of token id sequences to the same length with ``pad_value``, on the right or on the left.
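+ Examples:
+ >>> pad_sequences([[1, 2, 3], [4]], pad_value=0, padding='left')
+ [[1, 2, 3], [0, 0, 4]]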
+ """
+ max_len = max(len(seq) for seq in seqs)
+ if padding == 'right':
+ padded_seqs = [seq + [pad_value] * (max_len - len(seq)) for seq in seqs]
+ elif padding == 'left':
+ padded_seqs = [[pad_value] * (max_len - len(seq)) + seq for seq in seqs]
+ else:
+ raise ValueError(f"Unknown padding side: {padding}, expected 'right' or 'left'.")
+ return padded_seqs
+
+
+def get_special_prompt(i: int):
+ return HUMAN_PROMPT if i % 2 == 0 else ASSISTANT_PROMPT
+
+
+def get_model_prompt(context: List[str], eos_token=""):
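+ """
+ Overview:
+ Join a multi-turn context into a single prompt string and append the prefix of the next speaker, \
+ e.g. ``["Human: Hi"]`` -> ``"Human: HiAssistant:"`` when ``eos_token`` is empty.
+ """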
+ human_prompt, assistant_prompt = HUMAN_PROMPT, ASSISTANT_PROMPT
+ if context[-1].startswith(human_prompt):
+ end_prompt = assistant_prompt
+ elif context[-1].startswith(assistant_prompt):
+ end_prompt = human_prompt
+ else:
+ raise ValueError(f'The last turn should start with "{human_prompt}" or "{assistant_prompt}", but got: {context[-1]}')
+
+ context = eos_token.join(context)
+ return f'{context}{eos_token}{end_prompt}'
+
+
+def get_tokenizer(path: str):
+ """
+ Overview:
+ Return the pretrained tokenizer using the given path.
+ """
+ tokenizer = LlamaTokenizer.from_pretrained(path, trust_remote_code=True)
+ tokenizer.bos_token = ''
+ tokenizer.eos_token = ''
+ tokenizer.pad_token = ''
+ tokenizer.pad_token_id = 0
+ tokenizer.unk_token = tokenizer.pad_token
+ tokenizer.unk_token_id = tokenizer.pad_token_id
+
+ return tokenizer
+
+
+class OnlyPromptDataset(IterableDataset):
+ """
+ Overview:
+ Dataset that only contains the prompts of the raw data (without answers).
+ """
+ def __init__(
+ self,
+ data_path: os.PathLike,
+ tokenizer,
+ batch_size: int,
+ maxlen_prompt: int,
+ maxlen_res: int,
+ mode: str = 'train',
+ ) -> None:
+ super().__init__()
+ self.mode = mode
+ self.tokenizer = tokenizer
+ self.maxlen_prompt = maxlen_prompt
+ self.maxlen_res = maxlen_res
+ self.batch_size = batch_size
+
+ # Load data.
+ self.data = []
+ files = sorted([file for file in os.listdir(data_path) if file.endswith(f'{mode}.json')])
+ for file in files:
+ file_path = os.path.join(data_path, file)
+ tmp_data = []
+ try:
+ tmp_data = self.load_data(file_path)
+ except Exception:
+ warnings.warn(f'Failed to load data file: {file_path}')
+ self.data.extend(tmp_data)
+
+ # Set the length of this dataset.
+ self.size = len(self.data)
+
+ def __len__(self):
+ return self.size
+
+ def load_data(self, file_path: str):
+ """
+ Overview:
+ Load raw data from given file_path.
+ """
+ with open(file_path, 'r') as f:
+ data: List[List[str]] = json.load(f)
+
+ output: List[List[str]] = [sample for sample in data if all(sample)]
+ del data
+
+ return output
+
+ def final_generator(self):
+ data_generator = self.batch_generator()
+ for batch_samples in data_generator:
+ batch = self.batchify(batch_samples)
+ yield batch
+
+ def __iter__(self):
+ return self.final_generator()
+
+ def format(self, sample: List[str]) -> Dict[str, Any]:
+ """
+ Overview:
+ Convert one multi-turn data sample into a tokenized prompt that ends with the assistant prefix.
+ """
+ context = sample
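+ # Prefix the turns alternately with "Human:" / "Assistant:" so that the last turn is always a human turn;
+ # ``get_model_prompt`` then appends the assistant prefix for the model to complete.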
+ context = [get_special_prompt(i + (len(context) + 1) % 2) + s for i, s in enumerate(context)]
+ context_vec = self.tokenizer.encode(get_model_prompt(context, self.tokenizer.eos_token),
+ add_special_tokens=True)
+
+ # Drop the oldest turns until the prompt leaves at least ``maxlen_res`` tokens of room for the response.
+ while len(context_vec) > self.maxlen_prompt - self.maxlen_res and len(context) > 1:
+ context = context[1:]
+ context_vec = self.tokenizer.encode(get_model_prompt(context, self.tokenizer.eos_token),
+ add_special_tokens=True)
+
+ output = {
+ 'text': self.tokenizer.decode(context_vec, skip_special_tokens=False),
+ 'text_vec': context_vec
+ }
+
+ return output
+
+ def batchify(self, batch_samples: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """
+ Overview:
+ Batchify a list of samples by left-padding their token ids to the same length.
+ """
+ batch_text_vec = torch.tensor(pad_sequences(
+ [sample['text_vec'] for sample in batch_samples], pad_value=self.tokenizer.pad_token_id, padding='left'
+ ), dtype=torch.long)
+ return {
+ 'text_vec': batch_text_vec,
+ 'text': [sample['text'] for sample in batch_samples]
+ }
+
+ def sample_generator(self):
+ """
+ Overview:
+ Yield formatted data samples one by one (shuffled first in train mode).
+ """
+ random.seed(None)
+ if self.mode == 'train':
+ random.shuffle(self.data)
+
+ for sample in self.data:
+ yield self.format(sample)
+
+ def _batch_generator(self):
+ """
+ Overview:
+ Generate a batch of samples.
+ """
+ batch = []
+ # Pack formatted samples into batches, skipping prompts longer than ``maxlen_prompt``.
+ for sample in self.sample_generator():
+ sample_len = len(sample['text_vec'])
+ if sample_len > self.maxlen_prompt:
+ continue
+
+ batch.append(sample)
+ if len(batch) >= self.batch_size:
+ yield batch[:self.batch_size]
+ batch = batch[self.batch_size:]
+ if batch:
+ yield batch
+
+ def batch_generator(self):
+ while True:
+ for batch in self._batch_generator():
+ if len(batch) == self.batch_size:
+ yield batch
+ if self.mode != 'train':
+ break
diff --git a/launch_ppof.py b/launch_ppof.py
new file mode 100644
index 0000000000..53825dcba2
--- /dev/null
+++ b/launch_ppof.py
@@ -0,0 +1,50 @@
+from easydict import EasyDict
+from transformers import LlamaTokenizer
+
+from ding.bonus.ppof import PPOF
+from ding.model.template import LlamaVAC
+
+
+def get_tokenizer(path: str):
+ """
+ Overview:
+ Return the pretrained tokenizer using the given path.
+ """
+ tokenizer = LlamaTokenizer.from_pretrained(path, trust_remote_code=True)
+ tokenizer.bos_token = ''
+ tokenizer.eos_token = ''
+ tokenizer.pad_token = ''
+ tokenizer.pad_token_id = 0
+ tokenizer.unk_token = tokenizer.pad_token
+ tokenizer.unk_token_id = tokenizer.pad_token_id
+
+ return tokenizer
+
+
+if __name__ == '__main__':
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--actor_path', type=str)
+ parser.add_argument('--critic_path', type=str)
+ args = parser.parse_args()
+
+ opt = EasyDict({
+ "maxlen_res": 512,
+ "temperature": 1,
+ "repetition_penalty": 1,
+ "topp": 0.8
+ })
+
+ model = LlamaVAC(
+ actor_path=args.actor_path,
+ critic_path=args.critic_path,
+ tokenizer_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-sft-model-7B-en",
+ opt=opt
+ )
+
+ policy = PPOF(
+ env_id="chat",
+ exp_name="rlhf-ppo",
+ model=model
+ )
+ policy.train(collector_env_num=1, evaluator_env_num=1, debug=True)