diff --git a/finetune.py b/finetune.py index 969aba5..5e70714 100644 --- a/finetune.py +++ b/finetune.py @@ -123,11 +123,13 @@ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: st def preprocess( - sources, + sources: List[List[Dict]], tokenizer: transformers.PreTrainedTokenizer, max_len: int, system_message: str = "You are a helpful assistant." ) -> Dict: + # 对于每轮输入的6部分<|im_start|> + role + \n + message + <|im_end|> + \n, + # 在label中被mask的有:1.role + \n,2.system和user的message。 roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"} im_start = tokenizer.im_start_id @@ -144,20 +146,27 @@ def preprocess( source = source[1:] input_id, target = [], [] + # system input = im_start + sys_role + \n + sys_msg + im_end + \n system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens input_id += system + # system target = im_start + mask of (role + \n + msg) + im_end + \n target += [im_start] + [IGNORE_TOKEN_ID] * (len(system)-3) + [im_end] + nl_tokens assert len(input_id) == len(target) for j, sentence in enumerate(source): role = roles[sentence["from"]] - _input_id = tokenizer(role).input_ids + nl_tokens + \ - tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens + role_input_id = tokenizer(role).input_ids + value_input_id = tokenizer(sentence["value"]).input_ids + # input = im_start + role + \n + msg + im_end + \n + _input_id = role_input_id + nl_tokens + value_input_id + [im_end] + nl_tokens + input_id += _input_id if role == '<|im_start|>user': + # user target = im_start + mask of (role + \n + msg) + im_end + \n _target = [im_start] + [IGNORE_TOKEN_ID] * (len(_input_id)-3) + [im_end] + nl_tokens elif role == '<|im_start|>assistant': - _target = [im_start] + [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids) + \ - _input_id[len(tokenizer(role).input_ids)+1:-2] + [im_end] + nl_tokens + # assistant target = im_start + mask of (role + \n) + msg + im_end + \n + _target = [im_start] + [IGNORE_TOKEN_ID] * len(role_input_id) + \ + value_input_id + [im_end] + nl_tokens else: raise NotImplementedError target += _target