diff --git a/common/config_models.py b/common/config_models.py
index a31bd3ef..8833106a 100644
--- a/common/config_models.py
+++ b/common/config_models.py
@@ -271,6 +271,21 @@ class ModelConfig(BaseConfigModel):
             "Enables vision support if the model supports it. (default: False)"
         ),
     )
+    reasoning: bool = Field(
+        False,
+        description=(
+            "Enable the reasoning parser (default: False).\n"
+            "Split the response message into reasoning_content and content fields."
+        ),
+    )
+    reasoning_start_token: str = Field(
+        "<think>",
+        description=("Start token for the reasoning parser (default: <think>)."),
+    )
+    reasoning_end_token: str = Field(
+        "</think>",
+        description=("End token for the reasoning parser (default: </think>)."),
+    )
 
     _metadata: Metadata = PrivateAttr(Metadata())
     model_config = ConfigDict(protected_namespaces=())
diff --git a/config_sample.yml b/config_sample.yml
index b6f362de..8c42f358 100644
--- a/config_sample.yml
+++ b/config_sample.yml
@@ -139,6 +139,16 @@ model:
   # Enables vision support if the model supports it. (default: False)
   vision: false
 
+  # Enable the reasoning parser (default: False).
+  # Do NOT enable this if the model is not a reasoning model (e.g. the deepseek-r1 series).
+  reasoning: false
+
+  # The start token for reasoning content (default: "<think>")
+  reasoning_start_token: "<think>"
+
+  # The end token for reasoning content (default: "</think>")
+  reasoning_end_token: "</think>"
+
   # Options for draft models (speculative decoding)
   # This will use more VRAM!
   draft_model:
diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py
index 86a22477..f6d74d80 100644
--- a/endpoints/OAI/types/chat_completion.py
+++ b/endpoints/OAI/types/chat_completion.py
@@ -31,6 +31,7 @@ class ChatCompletionMessagePart(BaseModel):
 class ChatCompletionMessage(BaseModel):
     role: str = "user"
     content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None
+    reasoning_content: Optional[str] = None
     tool_calls: Optional[List[ToolCall]] = None
     tool_calls_json: SkipJsonSchema[Optional[str]] = None
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 2326bc2a..56c1d3b6 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -17,6 +17,7 @@
     handle_request_error,
     request_disconnect_loop,
 )
+from common.tabby_config import config
 from common.utils import unwrap
 from endpoints.OAI.types.chat_completion import (
     ChatCompletionLogprobs,
@@ -33,6 +34,25 @@
 from endpoints.OAI.utils.tools import ToolCallProcessor
 
 
+def _extract_think_content(text: str) -> tuple[Optional[str], Optional[str]]:
+    """Extract content between the reasoning tokens and the remaining content.
+    Only available in non-streaming mode."""
+    if (
+        config.model.reasoning_start_token not in text
+        and config.model.reasoning_end_token not in text
+    ):
+        return None, text
+    elif config.model.reasoning_start_token in text:
+        start_reasoning = text.split(config.model.reasoning_start_token)[1]
+        reasoning_content = start_reasoning.split(config.model.reasoning_end_token)[0]
+        content = start_reasoning.split(config.model.reasoning_end_token)[1]
+        return reasoning_content.strip(), content.strip()
+    else:
+        reasoning_content = text.split(config.model.reasoning_end_token)[0]
+        content = text.split(config.model.reasoning_end_token)[1]
+        return reasoning_content.strip(), content.strip()
+
+
 def _create_response(
     request_id: str, generations: List[dict], model_name: Optional[str]
 ):
@@ -43,9 +63,16 @@
 
     choices = []
     for index, generation in enumerate(generations):
-        message = ChatCompletionMessage(
-            role="assistant", content=unwrap(generation.get("text"), "")
-        )
+        if config.model.reasoning:
+            raw_content = unwrap(generation.get("text"), "")
+            reasoning_content, content = _extract_think_content(raw_content)
+            message = ChatCompletionMessage(
+                role="assistant", reasoning_content=reasoning_content, content=content
+            )
+        else:
+            message = ChatCompletionMessage(
+                role="assistant", content=unwrap(generation.get("text"), "")
+            )
 
         tool_calls = generation["tool_calls"]
         if tool_calls:
@@ -110,6 +137,7 @@
     generation: Optional[dict] = None,
     model_name: Optional[str] = None,
     is_usage_chunk: bool = False,
+    is_reasoning_chunk: bool = False,
 ):
     """Create a chat completion stream chunk from the provided text."""
 
@@ -144,8 +172,14 @@
         choices.append(choice)
 
     else:
-        message = ChatCompletionMessage(
-            role="assistant", content=unwrap(generation.get("text"), "")
+        message = (
+            ChatCompletionMessage(
+                role="assistant", reasoning_content=unwrap(generation.get("text"), "")
+            )
+            if is_reasoning_chunk
+            else ChatCompletionMessage(
+                role="assistant", content=unwrap(generation.get("text"), "")
+            )
         )
 
         logprob_response = None
@@ -337,6 +371,8 @@
         # We need to keep track of the text generated so we can resume the tool calls
         current_generation_text = ""
 
+        is_reasoning_chunk = config.model.reasoning
+
         # Consumer loop
         while True:
             if disconnect_task.done():
@@ -364,8 +400,28 @@
                 if isinstance(generation, Exception):
                     raise generation
 
+                if (
+                    unwrap(generation.get("text"), "") == config.model.reasoning_start_token
+                    and config.model.reasoning
+                ):
+                    # Update reasoning chunk flag
+                    is_reasoning_chunk = True
+                    # And skip this token
+                    continue
+                if (
+                    unwrap(generation.get("text"), "") == config.model.reasoning_end_token
+                    and config.model.reasoning
+                ):
+                    # Update reasoning chunk flag
+                    is_reasoning_chunk = False
+                    # And skip this token
+                    continue
+
                 response = _create_stream_chunk(
-                    request.state.id, generation, model_path.name
+                    request.state.id,
+                    generation,
+                    model_path.name,
+                    is_reasoning_chunk=is_reasoning_chunk,
                 )
 
                 yield response.model_dump_json()
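
For reviewers, here is a minimal standalone sketch of the splitting rule that _extract_think_content applies in non-streaming mode. It hard-codes the assumed <think>/</think> defaults instead of reading config.model, and uses str.partition so a missing end token cannot raise an IndexError; it mirrors the three branches in the diff but is illustration only, not the shipped code.

from typing import Optional, Tuple

# Assumed defaults; TabbyAPI reads these from reasoning_start_token / reasoning_end_token
REASONING_START = "<think>"
REASONING_END = "</think>"


def split_reasoning(text: str) -> Tuple[Optional[str], Optional[str]]:
    """Return (reasoning_content, content), mirroring the parser's three branches."""
    if REASONING_START not in text and REASONING_END not in text:
        # No reasoning markers at all: pass the message through untouched
        return None, text
    if REASONING_START in text:
        # Full "<think>...</think>answer" form
        after_start = text.split(REASONING_START, 1)[1]
        reasoning, _, answer = after_start.partition(REASONING_END)
        return reasoning.strip(), answer.strip()
    # Only the end token is present, e.g. when the prompt template pre-fills the start token
    reasoning, _, answer = text.partition(REASONING_END)
    return reasoning.strip(), answer.strip()


print(split_reasoning("<think>2 + 2 = 4</think>The answer is 4."))  # ('2 + 2 = 4', 'The answer is 4.')
print(split_reasoning("2 + 2 = 4</think>The answer is 4."))         # ('2 + 2 = 4', 'The answer is 4.')
print(split_reasoning("No reasoning here."))                        # (None, 'No reasoning here.')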
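
With reasoning: true and the defaults above, a non-streaming /v1/chat/completions call returns both fields on the message. A rough client sketch, assuming a local TabbyAPI instance on its default port and a placeholder API key; the comment values are illustrative, not captured output:

import requests

resp = requests.post(
    "http://127.0.0.1:5000/v1/chat/completions",
    headers={"Authorization": "Bearer <api-key>"},  # placeholder key
    json={"messages": [{"role": "user", "content": "What is 2 + 2?"}]},
)
message = resp.json()["choices"][0]["message"]
print(message["reasoning_content"])  # chain-of-thought text, reasoning tokens stripped
print(message["content"])            # the final answer only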
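
In streaming mode there is no post-hoc split: while is_reasoning_chunk is set, each delta carries its text in reasoning_content rather than content, and the start/end tokens themselves are skipped instead of being emitted. A rough sketch of consuming that stream with plain SSE parsing; the data:/[DONE] framing is an assumption about the transport, not something this diff changes:

import json

import requests

with requests.post(
    "http://127.0.0.1:5000/v1/chat/completions",
    headers={"Authorization": "Bearer <api-key>"},  # placeholder key
    json={"messages": [{"role": "user", "content": "What is 2 + 2?"}], "stream": True},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":
            break
        chunk = json.loads(payload)
        if not chunk.get("choices"):
            continue  # skip chunks without choices
        delta = chunk["choices"][0].get("delta", {})
        # Reasoning tokens stream first via reasoning_content, then the answer via content
        if delta.get("reasoning_content"):
            print(delta["reasoning_content"], end="", flush=True)
        elif delta.get("content"):
            print(delta["content"], end="", flush=True)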