
Commit dcd45d5

feat: support adding tokens to the tokenizer.
* Resize the model by default.
* Only add normal tokens, not special tokens: the decode phase of PPO has to skip certain special tokens (such as the EOS token), so any added special tokens would be dropped from the decoded text.
1 parent 74ea532 · commit dcd45d5
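For context, a minimal standalone sketch (not part of this commit; it assumes a stock Hugging Face AutoTokenizer, and "gpt2" plus the token strings are chosen only for illustration) of why only normal tokens are added: PPO decodes rollouts with skip_special_tokens=True to drop tokens such as EOS, so anything registered as a special token would be stripped from the decoded text, while tokens added via add_tokens survive decoding.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_tokens(["<normal>"])  # normal added token: kept when decoding
tokenizer.add_special_tokens({"additional_special_tokens": ["<special>"]})  # special token: skipped below

ids = tokenizer("<normal> <special> hello", add_special_tokens=False)["input_ids"]
# "<special>" is dropped by skip_special_tokens=True; "<normal>" and "hello" remain
print(tokenizer.decode(ids, skip_special_tokens=True))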

File tree

4 files changed: +30 −14 lines


trlx/models/modeling_ppo.py

Lines changed: 13 additions & 0 deletions

@@ -411,6 +411,7 @@ def __init__(
 
         # The branch is defined by the last `num_layers_unfrozen` layers of the pretrained model
         decoder_blocks = deepcopy(hf_get_decoder_blocks(base_model))
+        self.embed_tokens = base_model.get_input_embeddings()
         self.decoder_blocks = nn.ModuleList(list(decoder_blocks)[-num_layers_unfrozen:])
         self.final_norm = deepcopy(hf_get_decoder_final_norm(base_model))
         self.lm_head = deepcopy(hf_get_lm_head(base_model))
@@ -425,6 +426,18 @@ def __init__(
         for parameter in self.parameters():
             parameter.requires_grad_(False)
 
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
 
 class GPTModelBranch(ModelBranch):
     def forward(  # noqa: max-complexity
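The accessors above matter because the trainer (see trlx/trainer/accelerate_base_trainer.py below) calls resize_token_embeddings on the frozen branch, and that generic routine finds and swaps the embedding matrices through exactly these hooks. A rough sketch of the mechanism it relies on (simplified, assuming the branch behaves like a transformers PreTrainedModel; resize_embeddings_sketch is an illustrative stand-in, not the real implementation):

import torch.nn as nn

def resize_embeddings_sketch(model, new_num_tokens: int):
    # Grow the input embedding matrix, copying over the existing rows.
    old_in = model.get_input_embeddings()
    new_in = nn.Embedding(new_num_tokens, old_in.embedding_dim)
    rows = min(old_in.num_embeddings, new_num_tokens)
    new_in.weight.data[:rows] = old_in.weight.data[:rows]
    model.set_input_embeddings(new_in)

    # Do the same for the output projection (lm_head), if the model exposes one.
    old_out = model.get_output_embeddings()
    if old_out is not None:
        new_out = nn.Linear(old_out.in_features, new_num_tokens, bias=False)
        new_out.weight.data[:rows] = old_out.weight.data[:rows]
        model.set_output_embeddings(new_out)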

trlx/trainer/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -41,7 +41,7 @@ def __init__(
         logit_mask=None,
         stop_sequences=None,
         train_mode=False,
-        additional_special_tokens=None,
+        additional_tokens=None,
     ):
         self.store: BaseRolloutStore = None
         self.config = config
@@ -50,7 +50,7 @@ def __init__(
         self.train_mode = train_mode
         self.logit_mask = logit_mask
         self.stop_sequences = stop_sequences
-        self.additional_special_tokens = additional_special_tokens
+        self.additional_tokens = additional_tokens
 
     def push_to_store(self, data):
         self.store.push(data)

trlx/trainer/accelerate_base_trainer.py

Lines changed: 9 additions & 6 deletions

@@ -70,19 +70,22 @@ def __init__(self, config, **kwargs):  # noqa: C901
         self.scheduler = self.setup_scheduler()
 
         self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.tokenizer_path)
+        self.tokenizer.add_tokens(self.additional_tokens)
+        # resize the model by-default
+        self.model.base_model.resize_token_embeddings(len(self.tokenizer))
+        if hasattr(self.model, "frozen_head"):
+            self.model.frozen_head.resize_token_embeddings(len(self.tokenizer))
+        else:
+            # resize a reference model when hydra heads are not used
+            self.ref_model.resize_token_embeddings(len(self.tokenizer))
+
         self.tokenizer.padding_side = config.tokenizer.padding_side
         self.tokenizer.truncation_side = config.tokenizer.truncation_side
         self.tokenizer.sep_token = "<sep>"
         if config.model.model_arch_type != "seq2seq":
             self.tokenizer.pad_token = self.tokenizer.eos_token
             self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
 
-        if self.additional_special_tokens is not None and type(self.additional_special_tokens) is list:
-            self.tokenizer.add_special_tokens(
-                {"additional_special_tokens": self.additional_special_tokens}
-            )
-            self.model.base_model.resize_token_embeddings(len(self.tokenizer))
-
         script_name = os.path.basename(sys.argv[0]).rsplit(".", 1)[0]
         if not isinstance(config.model.model_path, str):
             model_name = str(config.model.model_path).split()[0]
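Taken in isolation, the new tokenizer/resize flow in the trainer amounts to roughly the following (a standalone sketch with plain transformers objects; "gpt2", the token strings, and the ref_model naming are illustrative stand-ins for the trainer's actual base model, frozen head, and reference model):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
ref_model = AutoModelForCausalLM.from_pretrained("gpt2")

tokenizer.add_tokens(["<token_a>", "<token_b>"])   # silently ignores tokens already in the vocab
model.resize_token_embeddings(len(tokenizer))      # base model is resized by default
ref_model.resize_token_embeddings(len(tokenizer))  # reference model, mirroring the no-hydra-heads branch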

trlx/trlx.py

Lines changed: 6 additions & 6 deletions

@@ -1,6 +1,6 @@
 import os
 import warnings
-from typing import Callable, Dict, Iterable, List, Optional, Tuple
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 from trlx.data.configs import TRLConfig
 from trlx.data.default_configs import (
@@ -23,7 +23,7 @@ def train(  # noqa: C901
     metric_fn: Optional[Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]] = None,
     config: Optional[TRLConfig] = None,
     stop_sequences: Optional[List[str]] = [],
-    additional_special_tokens: Optional[List[str]] = None,
+    additional_tokens: Optional[Union[str, List[str]]] = None,
 ):
     """
     Dispatches online, offline reinforcement training or supervised finetuning
@@ -55,9 +55,9 @@ def train(  # noqa: C901
         stop_sequences (Optional[List[str]]):
             String sequences to trim generations (both for generating of experience and evaluation) up to its
             encounter in them. Generations will not contain them and also will also be right-stripped
-        additional_special_tokens (Optional[List[str]]):
-            A list of additional special tokens. Add them to the tokenizer to ensure they won't be split by
-            the tokenization process.
+        additional_tokens (Optional[Union[str, List[str]]]):
+            A list of additional tokens. Each token is added only if it does not already exist
+            in the vocabulary, and each new token is assigned a new id.
     """
     if config is None:
         warnings.warn(
@@ -85,7 +85,7 @@ def train(  # noqa: C901
         reward_fn=reward_fn,
         metric_fn=metric_fn,
         stop_sequences=stop_sequences,
-        additional_special_tokens=additional_special_tokens,
+        additional_tokens=additional_tokens,
         **config.train.trainer_kwargs,
     )
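End to end, the renamed argument would be used like this (a hypothetical call, not from the commit; the reward function, prompts, and token strings are placeholders):

import trlx

def reward_fn(samples, prompts, outputs, **kwargs):
    # Placeholder reward: longer completions score higher.
    return [float(len(output)) for output in outputs]

trainer = trlx.train(
    "gpt2",
    reward_fn=reward_fn,
    prompts=["Translate: <src> hello <tgt>"],
    additional_tokens=["<src>", "<tgt>"],  # added as normal tokens; embeddings are resized automatically
)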
