From 8036df5b55591e0212308462ff848d1fb3a7110c Mon Sep 17 00:00:00 2001 From: songt Date: Tue, 24 Oct 2023 11:31:35 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96SFT=E7=9A=84preprocess?= =?UTF-8?q?=E5=87=BD=E6=95=B0=EF=BC=8C=E4=BD=BF=E5=BE=97label=20mask?= =?UTF-8?q?=E7=AD=96=E7=95=A5=E6=9B=B4=E6=B8=85=E6=99=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- finetune.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/finetune.py b/finetune.py index 969aba5..5e70714 100644 --- a/finetune.py +++ b/finetune.py @@ -123,11 +123,13 @@ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: st def preprocess( - sources, + sources: List[List[Dict]], tokenizer: transformers.PreTrainedTokenizer, max_len: int, system_message: str = "You are a helpful assistant." ) -> Dict: + # 对于每轮输入的6部分<|im_start|> + role + \n + message + <|im_end|> + \n, + # 在label中被mask的有:1.role + \n,2.system和user的message。 roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"} im_start = tokenizer.im_start_id @@ -144,20 +146,27 @@ def preprocess( source = source[1:] input_id, target = [], [] + # system input = im_start + sys_role + \n + sys_msg + im_end + \n system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens input_id += system + # system target = im_start + mask of (role + \n + msg) + im_end + \n target += [im_start] + [IGNORE_TOKEN_ID] * (len(system)-3) + [im_end] + nl_tokens assert len(input_id) == len(target) for j, sentence in enumerate(source): role = roles[sentence["from"]] - _input_id = tokenizer(role).input_ids + nl_tokens + \ - tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens + role_input_id = tokenizer(role).input_ids + value_input_id = tokenizer(sentence["value"]).input_ids + # input = im_start + role + \n + msg + im_end + \n + _input_id = role_input_id + nl_tokens + value_input_id + [im_end] + nl_tokens + input_id += _input_id if role == '<|im_start|>user': + # user target = im_start + mask of (role + \n + msg) + im_end + \n _target = [im_start] + [IGNORE_TOKEN_ID] * (len(_input_id)-3) + [im_end] + nl_tokens elif role == '<|im_start|>assistant': - _target = [im_start] + [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids) + \ - _input_id[len(tokenizer(role).input_ids)+1:-2] + [im_end] + nl_tokens + # assistant target = im_start + mask of (role + \n) + msg + im_end + \n + _target = [im_start] + [IGNORE_TOKEN_ID] * len(role_input_id) + \ + value_input_id + [im_end] + nl_tokens else: raise NotImplementedError target += _target