diff --git a/generate/utils.py b/generate/utils.py
index e05ba98..19d3272 100644
--- a/generate/utils.py
+++ b/generate/utils.py
@@ -370,7 +370,6 @@ def clean_output(self, output: str, prompt: str) -> str:
 
 
 class InstructConfig(InferenceConfig):
-
     def __init__(self, prompted : bool = False, instruction_tag : str = "### Instruction", response_tag : str = "### Response"):
         super().__init__(prompted=prompted)
         self.instruction_tag = instruction_tag
@@ -401,6 +400,63 @@ def format_prompt(self, prompt : str) -> str:
     def clean_output(self, output: str, prompt: str) -> str:
         return clean_instruct_output(output, prompt, self.response_tag)
 
+class QwenConfig(InferenceConfig):  # completion-style prompts for base (non-instruct) Qwen models
+    def __init__(self, prompted : bool = False):
+        super().__init__(prompted=prompted)
+
+    def get_dtype(self):
+        return torch.float16
+
+    def init_padding(self, tokenizer):
+        tokenizer.pad_token_id = tokenizer.eos_token_id  # for batching
+        tokenizer.padding_side = "left"  # for decoder-only models
+
+    def get_pad_token_id(self, tokenizer) -> int:
+        return tokenizer.eos_token_id
+
+    def get_eos_token_id(self, tokenizer) -> int:
+        return None
+
+    def trust_remote_code(self) -> bool:
+        return False
+
+    def format_prompt(self, prompt : str) -> str:
+        if self.prompted:
+            return f"// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
+        return prompt.strip()
+
+    def clean_output(self, output: str, prompt: str) -> str:
+        return clean_output(output, prompt)
+
+class ChatMLConfig(InferenceConfig):  # ChatML prompt format used by Qwen instruct/chat models
+    def __init__(self, prompted : bool = False):
+        super().__init__(prompted=prompted)
+
+    def get_dtype(self):
+        return torch.bfloat16
+
+    def init_padding(self, tokenizer):
+        tokenizer.pad_token_id = tokenizer.eos_token_id  # for batching
+        tokenizer.padding_side = "left"  # for decoder-only models
+
+    def get_pad_token_id(self, tokenizer) -> int:
+        return tokenizer.pad_token_id
+
+    def get_eos_token_id(self, tokenizer) -> int:
+        return tokenizer.eos_token_id
+
+    def trust_remote_code(self) -> bool:
+        return False
+
+    def format_prompt(self, prompt : str) -> str:
+        function_name = get_function_name(prompt, "cuda" if "__global__" in prompt else "serial")
+        prompt = f"Complete the following c++ function.\n```c++{prompt.strip()}```\nWrite only the function {function_name} and no other code. Enclose your solution in ```c++ and ```."
+        prompt = f"<|im_start|>system\nYou are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
+        return prompt
+
+    def clean_output(self, output: str, prompt: str) -> str:
+        return clean_instruct_output(output, prompt, "<|im_start|>assistant\n")
+
 def get_inference_config(model_name : str, **kwargs) -> InferenceConfig:
     if model_name == "bigcode/starcoderbase":
         return StarCoderConfig(**kwargs)
@@ -422,6 +478,12 @@ def get_inference_config(model_name : str, **kwargs) -> InferenceConfig:
         return InstructConfig(instruction_tag='Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:', response_tag='### Response:', **kwargs)
     elif model_name.startswith('hpcgroup/rlpf'):
         return InstructConfig(instruction_tag='### Instruction', response_tag='### Response', **kwargs)
+    elif model_name.startswith('Qwen/Qwen2.5') and 'Instruct' in model_name:
+        return ChatMLConfig(**kwargs)
+    elif model_name.startswith('Qwen/Qwen3'):
+        return ChatMLConfig(**kwargs)
+    elif model_name.startswith('Qwen/Qwen2.5'):
+        return QwenConfig(**kwargs)
     else:
         raise ValueError(f"Unknown model name: {model_name}")
 