generate/utils.py: 64 changes (63 additions, 1 deletion)
@@ -370,7 +370,6 @@ def clean_output(self, output: str, prompt: str) -> str:


class InstructConfig(InferenceConfig):

    def __init__(self, prompted : bool = False, instruction_tag : str = "### Instruction", response_tag : str = "### Response"):
        super().__init__(prompted=prompted)
        self.instruction_tag = instruction_tag
@@ -401,6 +400,63 @@ def format_prompt(self, prompt : str) -> str:
    def clean_output(self, output: str, prompt: str) -> str:
        return clean_instruct_output(output, prompt, self.response_tag)

class QwenConfig(InferenceConfig):
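    # Base (non-instruct) Qwen2.5 models: fp16 weights and plain completion-style
    # prompts, optionally prefixed with a filename/comment header in prompted mode.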
    def __init__(self, prompted : bool = False):
        super().__init__(prompted=prompted)

    def get_dtype(self):
        return torch.float16

    def init_padding(self, tokenizer):
        tokenizer.pad_token_id = tokenizer.eos_token_id # for batching
        tokenizer.padding_side = "left" # for decoder-only models

    def get_pad_token_id(self, tokenizer) -> int:
        return tokenizer.eos_token_id

    def get_eos_token_id(self, tokenizer) -> int:
        return None

    def trust_remote_code(self) -> bool:
        return False

    def format_prompt(self, prompt : str) -> str:
        if self.prompted:
            return f"// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
        return prompt.strip()

    def clean_output(self, output: str, prompt: str) -> str:
        return clean_output(output, prompt)

class ChatMLConfig(InferenceConfig):
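    # Chat models that use the ChatML template (here Qwen2.5-*-Instruct and Qwen3):
    # bf16 weights, with the prompt wrapped as
    #   <|im_start|>system ... <|im_end|>
    #   <|im_start|>user ... <|im_end|>
    #   <|im_start|>assistant
    # so generation continues after the assistant tag.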
    def __init__(self, prompted : bool = False):
        super().__init__(prompted=prompted)

    def get_dtype(self):
        return torch.bfloat16

    def init_padding(self, tokenizer):
        tokenizer.pad_token_id = tokenizer.eos_token_id # for batching
        tokenizer.padding_side = "left" # for decoder-only models

    def get_pad_token_id(self, tokenizer) -> int:
        return tokenizer.pad_token_id

    def get_eos_token_id(self, tokenizer) -> int:
        return tokenizer.eos_token_id

    def trust_remote_code(self) -> bool:
        return False

    def format_prompt(self, prompt : str) -> str:
        function_name = get_function_name(prompt, "cuda" if "__global__" in prompt else "serial")
        prompt = f"Complete the following c++ function.\n```c++{prompt.strip()}```\nWrite only the function {function_name} and no other code. Enclose your solution in ```c++ and ```."
        prompt = f"<|im_start|>system\nYou are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
        return prompt

    def clean_output(self, output: str, prompt: str) -> str:
        return clean_instruct_output(output, prompt, "<|im_start|>assistant\n")

def get_inference_config(model_name : str, **kwargs) -> InferenceConfig:
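    # Select the prompt-formatting/inference config for a given Hugging Face model name.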
    if model_name == "bigcode/starcoderbase":
        return StarCoderConfig(**kwargs)
@@ -422,6 +478,12 @@ def get_inference_config(model_name : str, **kwargs) -> InferenceConfig:
        return InstructConfig(instruction_tag='Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:', response_tag='### Response:', **kwargs)
    elif model_name.startswith('hpcgroup/rlpf'):
        return InstructConfig(instruction_tag='### Instruction', response_tag='### Response', **kwargs)
    elif model_name.startswith('Qwen/Qwen2.5') and 'Instruct' in model_name:
        return ChatMLConfig(**kwargs)
    elif model_name.startswith('Qwen/Qwen3'):
        return ChatMLConfig(**kwargs)
    elif model_name.startswith('Qwen/Qwen2.5'):
        return QwenConfig(**kwargs)
    else:
        raise ValueError(f"Unknown model name: {model_name}")

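The sketch below is a rough usage note rather than code from this PR: it shows how the new configs are meant to plug into a Hugging Face generation loop, with the checkpoint name, import path, prompt, and generation arguments chosen purely for illustration.

# Rough usage sketch, not part of this PR. Only get_inference_config and the
# config methods come from generate/utils.py; everything else is illustrative.
from transformers import AutoModelForCausalLM, AutoTokenizer
from generate.utils import get_inference_config  # assumed import path

model_name = "Qwen/Qwen2.5-Coder-7B-Instruct"    # dispatches to ChatMLConfig
config = get_inference_config(model_name)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=config.trust_remote_code())
config.init_padding(tokenizer)                   # left padding, eos as pad token

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=config.get_dtype(),
    trust_remote_code=config.trust_remote_code(),
)

prompt = "int add(int a, int b) {"               # illustrative serial prompt
inputs = tokenizer(config.format_prompt(prompt), return_tensors="pt", padding=True)
outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    pad_token_id=config.get_pad_token_id(tokenizer),
    eos_token_id=config.get_eos_token_id(tokenizer),
)
text = tokenizer.decode(outputs[0])              # keep tags so clean_output can find the assistant turn
print(config.clean_output(text, prompt))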