diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index 3829305b..22537c93 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -214,26 +214,33 @@ async def get_topic_summary(
     )


-@router.post("/query", responses=query_response)
-@authorize(Action.QUERY)
-async def query_endpoint_handler(  # pylint: disable=R0914
+async def query_endpoint_handler_base(  # pylint: disable=R0914
     request: Request,
     query_request: QueryRequest,
     auth: Annotated[AuthTuple, Depends(get_auth_dependency())],
-    mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
+    mcp_headers: dict[str, dict[str, str]],
+    retrieve_response_func: Any,
+    get_topic_summary_func: Any,
 ) -> QueryResponse:
     """
-    Handle request to the /query endpoint.
+    Handle query endpoints (shared by Agent API and Responses API).

-    Processes a POST request to the /query endpoint, forwarding the
-    user's query to a selected Llama Stack LLM or agent and
-    returning the generated response.
+    Processes a POST request to a query endpoint, forwarding the
+    user's query to a selected Llama Stack LLM and returning the generated response.
     Validates configuration and authentication, selects the appropriate model
     and provider, retrieves the LLM response, updates metrics,
     and optionally stores a transcript of the interaction. Handles
     connection errors to the Llama Stack service by returning an
     HTTP 500 error.

+    Args:
+        request: The FastAPI request object
+        query_request: The query request containing the user's question
+        auth: Authentication tuple from dependency
+        mcp_headers: MCP headers from dependency
+        retrieve_response_func: The retrieve_response function to use (Agent or Responses API)
+        get_topic_summary_func: The get_topic_summary function to use (Agent or Responses API)
+
     Returns:
         QueryResponse: Contains the conversation ID and the LLM-generated response.
     """
@@ -288,7 +295,7 @@ async def query_endpoint_handler(  # pylint: disable=R0914
             ),
         )
         summary, conversation_id, referenced_documents, token_usage = (
-            await retrieve_response(
+            await retrieve_response_func(
                 client,
                 llama_stack_model_id,
                 query_request,
@@ -305,8 +312,8 @@ async def query_endpoint_handler(  # pylint: disable=R0914
                 session.query(UserConversation).filter_by(id=conversation_id).first()
             )
             if not existing_conversation:
-                topic_summary = await get_topic_summary(
-                    query_request.query, client, model_id
+                topic_summary = await get_topic_summary_func(
+                    query_request.query, client, llama_stack_model_id
                 )
             # Convert RAG chunks to dictionary format once for reuse
             logger.info("Processing RAG chunks...")
@@ -416,6 +423,33 @@ async def query_endpoint_handler(  # pylint: disable=R0914
         ) from e


+@router.post("/query", responses=query_response)
+@authorize(Action.QUERY)
+async def query_endpoint_handler(
+    request: Request,
+    query_request: QueryRequest,
+    auth: Annotated[AuthTuple, Depends(get_auth_dependency())],
+    mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
+) -> QueryResponse:
+    """
+    Handle request to the /query endpoint using the Agent API.
+
+    This is a wrapper around query_endpoint_handler_base that provides
+    the Agent API-specific retrieve_response and get_topic_summary functions.
+
+    Returns:
+        QueryResponse: Contains the conversation ID and the LLM-generated response.
+ """ + return await query_endpoint_handler_base( + request=request, + query_request=query_request, + auth=auth, + mcp_headers=mcp_headers, + retrieve_response_func=retrieve_response, + get_topic_summary_func=get_topic_summary, + ) + + def select_model_and_provider_id( models: ModelListResponse, model_id: str | None, provider_id: str | None ) -> tuple[str, str, str]: diff --git a/src/app/endpoints/query_v2.py b/src/app/endpoints/query_v2.py new file mode 100644 index 00000000..0365d9fd --- /dev/null +++ b/src/app/endpoints/query_v2.py @@ -0,0 +1,630 @@ +"""Handler for REST API call to provide answer to query using Response API.""" + +import logging +from typing import Annotated, Any, cast + +from llama_stack_client import AsyncLlamaStackClient # type: ignore +from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseObject, +) + +from fastapi import APIRouter, Request, Depends + +from app.endpoints.query import ( + query_endpoint_handler_base, + validate_attachments_metadata, +) +from constants import DEFAULT_RAG_TOOL +from authentication import get_auth_dependency +from authentication.interface import AuthTuple +from authorization.middleware import authorize +from configuration import configuration +import metrics +from models.config import Action +from models.requests import QueryRequest +from models.responses import ( + ForbiddenResponse, + QueryResponse, + ReferencedDocument, + UnauthorizedResponse, +) +from utils.endpoints import ( + get_system_prompt, + get_topic_summary_system_prompt, +) +from utils.mcp_headers import mcp_headers_dependency +from utils.token_counter import TokenCounter +from utils.types import TurnSummary, ToolCallSummary + +logger = logging.getLogger("app.endpoints.handlers") +router = APIRouter(tags=["query_v2"]) + +query_v2_response: dict[int | str, dict[str, Any]] = { + 200: { + "conversation_id": "123e4567-e89b-12d3-a456-426614174000", + "response": "LLM answer", + "referenced_documents": [ + { + "doc_url": "https://docs.openshift.com/" + "container-platform/4.15/operators/olm/index.html", + "doc_title": "Operator Lifecycle Manager (OLM)", + } + ], + }, + 400: { + "description": "Missing or invalid credentials provided by client", + "model": UnauthorizedResponse, + }, + 403: { + "description": "User is not authorized", + "model": ForbiddenResponse, + }, + 500: { + "detail": { + "response": "Unable to connect to Llama Stack", + "cause": "Connection error.", + } + }, +} + + +def _extract_text_from_response_output_item(output_item: Any) -> str: + """Extract assistant message text from a Responses API output item.""" + if getattr(output_item, "type", None) != "message": + return "" + if getattr(output_item, "role", None) != "assistant": + return "" + + content = getattr(output_item, "content", None) + if isinstance(content, str): + return content + + text_fragments: list[str] = [] + if isinstance(content, list): + for part in content: + if isinstance(part, str): + text_fragments.append(part) + continue + text_value = getattr(part, "text", None) + if text_value: + text_fragments.append(text_value) + continue + refusal = getattr(part, "refusal", None) + if refusal: + text_fragments.append(refusal) + continue + if isinstance(part, dict): + dict_text = part.get("text") or part.get("refusal") + if dict_text: + text_fragments.append(str(dict_text)) + + return "".join(text_fragments) + + +def _build_tool_call_summary( # pylint: disable=too-many-return-statements,too-many-branches + output_item: Any, +) -> ToolCallSummary | None: + """Translate applicable 
Responses API tool outputs into ``ToolCallSummary`` records. + + The OpenAI ``response.output`` array may contain any ``OpenAIResponseOutput`` variant: + ``message``, ``function_call``, ``file_search_call``, ``web_search_call``, ``mcp_call``, + ``mcp_list_tools``, or ``mcp_approval_request``. The OpenAI Spec supports more types + but as llamastack does not support them yet they are not considered here. + """ + item_type = getattr(output_item, "type", None) + + if item_type == "function_call": + parsed_arguments = getattr(output_item, "arguments", "") + status = getattr(output_item, "status", None) + if status: + if isinstance(parsed_arguments, dict): + args: Any = {**parsed_arguments, "status": status} + else: + args = {"arguments": parsed_arguments, "status": status} + else: + args = parsed_arguments + + call_id = getattr(output_item, "id", None) or getattr( + output_item, "call_id", None + ) + return ToolCallSummary( + id=str(call_id), + name=getattr(output_item, "name", "function_call"), + args=args, + response=None, + ) + + if item_type == "file_search_call": + args = { + "queries": list(getattr(output_item, "queries", [])), + "status": getattr(output_item, "status", None), + } + results = getattr(output_item, "results", None) + response_payload: Any | None = None + if results is not None: + # Store only the essential result metadata to avoid large payloads + response_payload = { + "results": [ + { + "file_id": ( + getattr(result, "file_id", None) + if not isinstance(result, dict) + else result.get("file_id") + ), + "filename": ( + getattr(result, "filename", None) + if not isinstance(result, dict) + else result.get("filename") + ), + "score": ( + getattr(result, "score", None) + if not isinstance(result, dict) + else result.get("score") + ), + } + for result in results + ] + } + return ToolCallSummary( + id=str(getattr(output_item, "id")), + name=DEFAULT_RAG_TOOL, + args=args, + response=response_payload, + ) + + if item_type == "web_search_call": + args = {"status": getattr(output_item, "status", None)} + return ToolCallSummary( + id=str(getattr(output_item, "id")), + name="web_search", + args=args, + response=None, + ) + + if item_type == "mcp_call": + parsed_arguments = getattr(output_item, "arguments", "") + args = {"arguments": parsed_arguments} + server_label = getattr(output_item, "server_label", None) + if server_label: + args["server_label"] = server_label + error = getattr(output_item, "error", None) + if error: + args["error"] = error + + return ToolCallSummary( + id=str(getattr(output_item, "id")), + name=getattr(output_item, "name", "mcp_call"), + args=args, + response=getattr(output_item, "output", None), + ) + + if item_type == "mcp_list_tools": + tool_names: list[str] = [] + for tool in getattr(output_item, "tools", []): + if hasattr(tool, "name"): + tool_names.append(str(getattr(tool, "name"))) + elif isinstance(tool, dict) and tool.get("name"): + tool_names.append(str(tool.get("name"))) + args = { + "server_label": getattr(output_item, "server_label", None), + "tools": tool_names, + } + return ToolCallSummary( + id=str(getattr(output_item, "id")), + name="mcp_list_tools", + args=args, + response=None, + ) + + if item_type == "mcp_approval_request": + parsed_arguments = getattr(output_item, "arguments", "") + args = {"arguments": parsed_arguments} + server_label = getattr(output_item, "server_label", None) + if server_label: + args["server_label"] = server_label + return ToolCallSummary( + id=str(getattr(output_item, "id")), + name=getattr(output_item, "name", 
"mcp_approval_request"), + args=args, + response=None, + ) + + return None + + +async def get_topic_summary( # pylint: disable=too-many-nested-blocks + question: str, client: AsyncLlamaStackClient, model_id: str +) -> str: + """ + Get a topic summary for a question using Responses API. + + This is the Responses API version of get_topic_summary, which uses + client.responses.create() instead of the Agent API. + + Args: + question: The question to generate a topic summary for + client: The AsyncLlamaStackClient to use for the request + model_id: The llama stack model ID (full format: provider/model) + + Returns: + str: The topic summary for the question + """ + topic_summary_system_prompt = get_topic_summary_system_prompt(configuration) + + try: + # Use Responses API to generate topic summary + response = await client.responses.create( + input=question, + model=model_id, + instructions=topic_summary_system_prompt, + stream=False, + store=False, # Don't store topic summary requests + ) + response = cast(OpenAIResponseObject, response) + + # Extract text from response output + summary_text = "".join( + _extract_text_from_response_output_item(output_item) + for output_item in response.output + ) + + return summary_text.strip() if summary_text else "" + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning("Failed to generate topic summary: %s", e) + return "" # Return empty string on failure + + +@router.post("/query", responses=query_v2_response) +@authorize(Action.QUERY) +async def query_endpoint_handler_v2( + request: Request, + query_request: QueryRequest, + auth: Annotated[AuthTuple, Depends(get_auth_dependency())], + mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), +) -> QueryResponse: + """ + Handle request to the /query endpoint using Responses API. + + This is a wrapper around query_endpoint_handler_base that provides + the Responses API specific retrieve_response and get_topic_summary functions. + + Returns: + QueryResponse: Contains the conversation ID and the LLM-generated response. + """ + return await query_endpoint_handler_base( + request=request, + query_request=query_request, + auth=auth, + mcp_headers=mcp_headers, + retrieve_response_func=retrieve_response, + get_topic_summary_func=get_topic_summary, + ) + + +async def retrieve_response( # pylint: disable=too-many-locals,too-many-branches,too-many-arguments + client: AsyncLlamaStackClient, + model_id: str, + query_request: QueryRequest, + token: str, + mcp_headers: dict[str, dict[str, str]] | None = None, + *, + provider_id: str = "", +) -> tuple[TurnSummary, str, list[ReferencedDocument], TokenCounter]: + """ + Retrieve response from LLMs and agents. + + Retrieves a response from the Llama Stack LLM or agent for a + given query, handling shield configuration, tool usage, and + attachment validation. + + This function configures system prompts and toolgroups + (including RAG and MCP integration) as needed based on + the query request and system configuration. It + validates attachments, manages conversation and session + context, and processes MCP headers for multi-component + processing. Corresponding metrics are updated. + + Parameters: + client (AsyncLlamaStackClient): The AsyncLlamaStackClient to use for the request. + model_id (str): The identifier of the LLM model to use. + query_request (QueryRequest): The user's query and associated metadata. + token (str): The authentication token for authorization. 
+        mcp_headers (dict[str, dict[str, str]], optional): Per-server Model Context
+            Protocol (MCP) headers, keyed by server URL.
+        provider_id (str): The identifier of the LLM provider to use.
+
+    Returns:
+        tuple[TurnSummary, str, list[ReferencedDocument], TokenCounter]: A tuple containing
+            a summary of the LLM response content, the conversation ID, the list of parsed
+            referenced documents, and token usage information.
+    """
+    # TODO(ltomasbo): implement shields support once available in Responses API
+    logger.info("Shields are not yet supported in Responses API. Disabling safety")
+
+    # use system prompt from request or default one
+    system_prompt = get_system_prompt(query_request, configuration)
+    logger.debug("Using system prompt: %s", system_prompt)
+
+    # TODO(lucasagomes): redact attachments content before sending to LLM
+    # if attachments are provided, validate them
+    if query_request.attachments:
+        validate_attachments_metadata(query_request.attachments)
+
+    # Prepare tools for responses API
+    toolgroups: list[dict[str, Any]] | None = None
+    if not query_request.no_tools:
+        toolgroups = []
+        # Get vector stores for RAG tools
+        vector_store_ids = [
+            vector_store.id
+            for vector_store in (await client.vector_stores.list()).data
+        ]
+
+        # Add RAG tools if vector stores are available
+        rag_tools = get_rag_tools(vector_store_ids)
+        if rag_tools:
+            toolgroups.extend(rag_tools)
+
+        # Add MCP server tools
+        mcp_tools = get_mcp_tools(configuration.mcp_servers, token, mcp_headers)
+        if mcp_tools:
+            toolgroups.extend(mcp_tools)
+            logger.debug(
+                "Configured %d MCP tools: %s",
+                len(mcp_tools),
+                [tool.get("server_label", "unknown") for tool in mcp_tools],
+            )
+    # Convert empty list to None for consistency with existing behavior
+    if not toolgroups:
+        toolgroups = None
+
+    # Prepare input for Responses API
+    # Convert attachments to text and concatenate with query
+    input_text = query_request.query
+    if query_request.attachments:
+        for attachment in query_request.attachments:
+            # Append attachment content with type label
+            input_text += (
+                f"\n\n[Attachment: {attachment.attachment_type}]\n{attachment.content}"
+            )
+
+    # Create OpenAI response using responses API
+    create_kwargs: dict[str, Any] = {
+        "input": input_text,
+        "model": model_id,
+        "instructions": system_prompt,
+        "tools": cast(Any, toolgroups),
+        "stream": False,
+        "store": True,
+    }
+    if query_request.conversation_id:
+        create_kwargs["previous_response_id"] = query_request.conversation_id
+
+    response = await client.responses.create(**create_kwargs)
+    response = cast(OpenAIResponseObject, response)
+
+    logger.debug(
+        "Received response with ID: %s, output items: %d",
+        response.id,
+        len(response.output),
+    )
+
+    # Return the response ID - client can use it for chaining if desired
+    conversation_id = response.id
+
+    # Process OpenAI response format
+    llm_response = ""
+    tool_calls: list[ToolCallSummary] = []
+
+    for output_item in response.output:
+        message_text = _extract_text_from_response_output_item(output_item)
+        if message_text:
+            llm_response += message_text
+
+        tool_summary = _build_tool_call_summary(output_item)
+        if tool_summary:
+            tool_calls.append(tool_summary)
+
+    logger.info(
+        "Response processing complete - Tool calls: %d, Response length: %d chars",
+        len(tool_calls),
+        len(llm_response),
+    )
+
+    summary = TurnSummary(
+        llm_response=llm_response,
+        tool_calls=tool_calls,
+    )
+
+    # Extract referenced documents and token usage from Responses API response
+    referenced_documents = parse_referenced_documents_from_responses_api(response)
+    model_label = model_id.split("/", 1)[1] if "/" in model_id else model_id
+    token_usage = extract_token_usage_from_responses_api(
+        response, model_label, provider_id, system_prompt
+    )
+
+    if not summary.llm_response:
+        logger.warning(
+            "Response lacks content (conversation_id=%s)",
+            conversation_id,
+        )
+    return (summary, conversation_id, referenced_documents, token_usage)
+
+
+def parse_referenced_documents_from_responses_api(
+    response: OpenAIResponseObject,  # pylint: disable=unused-argument
+) -> list[ReferencedDocument]:
+    """
+    Parse referenced documents from OpenAI Responses API response.
+
+    Args:
+        response: The OpenAI Responses API response object
+
+    Returns:
+        list[ReferencedDocument]: List of referenced documents with doc_url and doc_title
+    """
+    # TODO(ltomasbo): need to parse source documents from Responses API response.
+    # The Responses API has a different structure than Agent API for referenced documents.
+    # Need to extract from:
+    # - OpenAIResponseOutputMessageFileSearchToolCall.results
+    # - OpenAIResponseAnnotationCitation in message content
+    # - OpenAIResponseAnnotationFileCitation in message content
+    return []
+
+
+def extract_token_usage_from_responses_api(
+    response: OpenAIResponseObject,
+    model: str,
+    provider: str,
+    system_prompt: str = "",  # pylint: disable=unused-argument
+) -> TokenCounter:
+    """
+    Extract token usage from OpenAI Responses API response and update metrics.
+
+    This function extracts token usage information from the Responses API response
+    object and updates Prometheus metrics. If usage information is not available,
+    it returns zero values without estimation.
+
+    Note: When Llama Stack internally uses chat_completions, the usage field may be
+    empty or a dict. This is expected and will be populated in future Llama Stack versions.
+
+    Args:
+        response: The OpenAI Responses API response object
+        model: The model identifier for metrics labeling
+        provider: The provider identifier for metrics labeling
+        system_prompt: The system prompt used (unused, kept for compatibility)
+
+    Returns:
+        TokenCounter: Token usage information with input_tokens and output_tokens
+    """
+    token_counter = TokenCounter()
+    token_counter.llm_calls = 1
+
+    # Extract usage from the response if available
+    # Note: usage attribute exists at runtime but may not be in type definitions
+    usage = getattr(response, "usage", None)
+    if usage:
+        try:
+            # Handle both dict and object cases due to llama_stack inconsistency:
+            # - When llama_stack converts to chat_completions internally, usage is a dict
+            # - When using proper Responses API, usage should be an object
+            # TODO: Remove dict handling once llama_stack standardizes on object type  # pylint: disable=fixme
+            if isinstance(usage, dict):
+                input_tokens = usage.get("input_tokens", 0)
+                output_tokens = usage.get("output_tokens", 0)
+            else:
+                # Object with attributes (expected final behavior)
+                input_tokens = getattr(usage, "input_tokens", 0)
+                output_tokens = getattr(usage, "output_tokens", 0)
+            # Only set if we got valid values
+            if input_tokens or output_tokens:
+                token_counter.input_tokens = input_tokens or 0
+                token_counter.output_tokens = output_tokens or 0
+
+                logger.debug(
+                    "Extracted token usage from Responses API: input=%d, output=%d",
+                    token_counter.input_tokens,
+                    token_counter.output_tokens,
+                )
+
+                # Update Prometheus metrics only when we have actual usage data
+                try:
+                    metrics.llm_token_sent_total.labels(provider, model).inc(
+                        token_counter.input_tokens
+                    )
+                    metrics.llm_token_received_total.labels(provider, model).inc(
+                        token_counter.output_tokens
+                    )
+                except (AttributeError, TypeError, ValueError) as e:
+                    logger.warning("Failed to update token metrics: %s", e)
+                _increment_llm_call_metric(provider, model)
+            else:
+                logger.debug(
+                    "Usage object exists but tokens are 0 or None, treating as no usage info"
+                )
+                # Still increment the call counter
+                _increment_llm_call_metric(provider, model)
+        except (AttributeError, KeyError, TypeError) as e:
+            logger.warning(
+                "Failed to extract token usage from response.usage: %s. Usage value: %s",
+                e,
+                usage,
+            )
+            # Still increment the call counter
+            _increment_llm_call_metric(provider, model)
+    else:
+        # No usage information available - this is expected when Llama Stack
+        # internally converts to chat_completions
+        logger.debug(
+            "No usage information in Responses API response, token counts will be 0"
+        )
+        # token_counter already initialized with 0 values
+        # Still increment the call counter
+        _increment_llm_call_metric(provider, model)
+
+    return token_counter
+
+
+def _increment_llm_call_metric(provider: str, model: str) -> None:
+    """Safely increment LLM call metric."""
+    try:
+        metrics.llm_calls_total.labels(provider, model).inc()
+    except (AttributeError, TypeError, ValueError) as e:
+        logger.warning("Failed to update LLM call metric: %s", e)
+
+
+def get_rag_tools(vector_store_ids: list[str]) -> list[dict[str, Any]] | None:
+    """
+    Convert vector store IDs to tools format for Responses API.
+
+    Args:
+        vector_store_ids: List of vector store identifiers
+
+    Returns:
+        list[dict[str, Any]] | None: List containing file_search tool configuration,
+            or None if no vector stores provided
+    """
+    if not vector_store_ids:
+        return None
+
+    return [
+        {
+            "type": "file_search",
+            "vector_store_ids": vector_store_ids,
+            "max_num_results": 10,
+        }
+    ]
+
+
+def get_mcp_tools(
+    mcp_servers: list,
+    token: str | None = None,
+    mcp_headers: dict[str, dict[str, str]] | None = None,
+) -> list[dict[str, Any]]:
+    """
+    Convert MCP servers to tools format for Responses API.
+
+    Args:
+        mcp_servers: List of MCP server configurations
+        token: Optional authentication token for MCP server authorization
+        mcp_headers: Optional per-request headers for MCP servers, keyed by server URL
+
+    Returns:
+        list[dict[str, Any]]: List of MCP tool definitions with server
+            details and optional auth headers
+    """
+    tools = []
+    for mcp_server in mcp_servers:
+        tool_def = {
+            "type": "mcp",
+            "server_label": mcp_server.name,
+            "server_url": mcp_server.url,
+            "require_approval": "never",
+        }
+
+        # Add authentication if headers or token provided (Responses API format)
+        headers = (mcp_headers or {}).get(mcp_server.url)
+        if headers:
+            tool_def["headers"] = headers
+        elif token:
+            tool_def["headers"] = {"Authorization": f"Bearer {token}"}
+        tools.append(tool_def)
+    return tools
diff --git a/src/app/routers.py b/src/app/routers.py
index 66c70766..1ca7044c 100644
--- a/src/app/routers.py
+++ b/src/app/routers.py
@@ -18,6 +18,8 @@
     conversations_v2,
     metrics,
     tools,
+    # V2 endpoints for Responses API support
+    query_v2,
 )


@@ -28,6 +30,7 @@ def include_routers(app: FastAPI) -> None:
         app: The `FastAPI` app instance.
     """
     app.include_router(root.router)
+
     app.include_router(info.router, prefix="/v1")
     app.include_router(models.router, prefix="/v1")
     app.include_router(tools.router, prefix="/v1")
@@ -40,6 +43,9 @@ def include_routers(app: FastAPI) -> None:
     app.include_router(conversations.router, prefix="/v1")
     app.include_router(conversations_v2.router, prefix="/v2")

+    # V2 endpoints - Responses API support
+    app.include_router(query_v2.router, prefix="/v2")
+
     # road-core does not version these endpoints
     app.include_router(health.router)
     app.include_router(authorized.router)
diff --git a/tests/unit/app/endpoints/test_query_v2.py b/tests/unit/app/endpoints/test_query_v2.py
new file mode 100644
index 00000000..0790f151
--- /dev/null
+++ b/tests/unit/app/endpoints/test_query_v2.py
@@ -0,0 +1,423 @@
+# pylint: disable=redefined-outer-name, import-error
+"""Unit tests for the /query (v2) REST API endpoint using Responses API."""
+
+import pytest
+from fastapi import HTTPException, status, Request
+
+from llama_stack_client import APIConnectionError
+
+from models.requests import QueryRequest, Attachment
+from models.config import ModelContextProtocolServer
+
+from app.endpoints.query_v2 import (
+    get_rag_tools,
+    get_mcp_tools,
+    retrieve_response,
+    query_endpoint_handler_v2,
+)
+
+
+@pytest.fixture
+def dummy_request() -> Request:
+    """Create a dummy FastAPI Request object for testing."""
+    req = Request(scope={"type": "http"})
+    return req
+
+
+def test_get_rag_tools():
+    """Test get_rag_tools returns None for empty list and correct tool format for vector stores."""
+    assert get_rag_tools([]) is None
+
+    tools = get_rag_tools(["db1", "db2"])
+    assert isinstance(tools, list)
+    assert tools[0]["type"] == "file_search"
+    assert tools[0]["vector_store_ids"] == ["db1", "db2"]
+    assert tools[0]["max_num_results"] == 10
+
+
+def test_get_mcp_tools_with_and_without_token():
+    """Test get_mcp_tools generates correct tool definitions with and without auth tokens."""
+    servers = [
+        ModelContextProtocolServer(name="fs", url="http://localhost:3000"),
+        ModelContextProtocolServer(name="git", url="https://git.example.com/mcp"),
+    ]
+
+    tools_no_token = get_mcp_tools(servers, token=None)
+    assert len(tools_no_token) == 2
+    assert tools_no_token[0]["type"] == "mcp"
+    assert tools_no_token[0]["server_label"] == "fs"
+    assert tools_no_token[0]["server_url"] == "http://localhost:3000"
+    assert "headers" not in tools_no_token[0]
+
+    tools_with_token = get_mcp_tools(servers, token="abc")
+    assert len(tools_with_token) == 2
+    assert tools_with_token[1]["type"] == "mcp"
+    assert tools_with_token[1]["server_label"] == "git"
+    assert tools_with_token[1]["server_url"] == "https://git.example.com/mcp"
+    assert tools_with_token[1]["headers"] == {"Authorization": "Bearer abc"}
+
+
+@pytest.mark.asyncio
+async def test_retrieve_response_no_tools_bypasses_tools(mocker):
+    """Test that no_tools=True bypasses tool configuration and passes None to responses API."""
+    mock_client = mocker.Mock()
+    # responses.create returns a synthetic OpenAI-like response
+    response_obj = mocker.Mock()
+    response_obj.id = "resp-1"
+    response_obj.output = []
+    response_obj.usage = None  # No usage info
+    mock_client.responses.create = mocker.AsyncMock(return_value=response_obj)
+    # vector_stores.list should not matter when no_tools=True, but keep it valid
+    mock_vector_stores = mocker.Mock()
+    mock_vector_stores.data = []
+    mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
+
+    # Ensure system prompt resolution does not require real config
+    mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
+    mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
+
+    qr = QueryRequest(query="hello", no_tools=True)
+    summary, conv_id, referenced_docs, token_usage = await retrieve_response(
+        mock_client, "model-x", qr, token="tkn"
+    )
+
+    assert conv_id == "resp-1"
+    assert summary.llm_response == ""
+    assert referenced_docs == []
+    assert token_usage.input_tokens == 0  # No usage info, so 0
+    assert token_usage.output_tokens == 0
+    # tools must be passed as None
+    kwargs = mock_client.responses.create.call_args.kwargs
+    assert kwargs["tools"] is None
+    assert kwargs["model"] == "model-x"
+    assert kwargs["instructions"] == "PROMPT"
+
+
+@pytest.mark.asyncio
+async def test_retrieve_response_builds_rag_and_mcp_tools(mocker):
+    """Test that retrieve_response correctly builds RAG and MCP tools from configuration."""
+    mock_client = mocker.Mock()
+    response_obj = mocker.Mock()
+    response_obj.id = "resp-2"
+    response_obj.output = []
+    response_obj.usage = None
+    mock_client.responses.create = mocker.AsyncMock(return_value=response_obj)
+    mock_vector_stores = mocker.Mock()
+    mock_vector_stores.data = [mocker.Mock(id="dbA")]
+    mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
+
+    mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
+    mock_cfg = mocker.Mock()
+    mock_cfg.mcp_servers = [
+        ModelContextProtocolServer(name="fs", url="http://localhost:3000"),
+    ]
+    mocker.patch("app.endpoints.query_v2.configuration", mock_cfg)
+
+    qr = QueryRequest(query="hello")
+    _summary, conv_id, referenced_docs, token_usage = await retrieve_response(
+        mock_client, "model-y", qr, token="mytoken"
+    )
+
+    assert conv_id == "resp-2"
+    assert referenced_docs == []
+    assert token_usage.input_tokens == 0  # No usage info, so 0
+    assert token_usage.output_tokens == 0
+
+    kwargs = mock_client.responses.create.call_args.kwargs
+    tools = kwargs["tools"]
+    assert isinstance(tools, list)
+    # Expect one file_search and one mcp tool
+    tool_types = {t.get("type") for t in tools}
+    assert tool_types == {"file_search", "mcp"}
+    file_search = next(t for t in tools if t["type"] == "file_search")
+    assert file_search["vector_store_ids"] == ["dbA"]
+    mcp_tool = next(t for t in tools if t["type"] == "mcp")
+    assert mcp_tool["server_label"] == "fs"
+    assert mcp_tool["headers"] == {"Authorization": "Bearer mytoken"}
+
+
+@pytest.mark.asyncio
+async def test_retrieve_response_parses_output_and_tool_calls(mocker):
+    """Test that retrieve_response correctly parses output content and tool calls from response."""
+    mock_client = mocker.Mock()
+
+    # Build output with content variants and tool calls
+    output_item_1 = mocker.Mock()
+    output_item_1.type = "message"
+    output_item_1.role = "assistant"
+    output_item_1.content = [mocker.Mock(text="Hello "), mocker.Mock(text="world")]
+
+    output_item_2 = mocker.Mock()
+    output_item_2.type = "message"
+    output_item_2.role = "assistant"
+    output_item_2.content = "!"
+
+    # Tool call as a separate output item (Responses API format)
+    tool_call_item = mocker.Mock()
+    tool_call_item.type = "function_call"
+    tool_call_item.id = "tc-1"
+    tool_call_item.name = "do_something"
+    tool_call_item.arguments = {"x": 1}
+    tool_call_item.status = None  # Explicitly set to avoid Mock auto-creation
+
+    response_obj = mocker.Mock()
+    response_obj.id = "resp-3"
+    response_obj.output = [output_item_1, output_item_2, tool_call_item]
+    response_obj.usage = None
+
+    mock_client.responses.create = mocker.AsyncMock(return_value=response_obj)
+    mock_vector_stores = mocker.Mock()
+    mock_vector_stores.data = []
+    mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
+
+    mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
+    mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
+
+    qr = QueryRequest(query="hello")
+    summary, conv_id, referenced_docs, token_usage = await retrieve_response(
+        mock_client, "model-z", qr, token="tkn"
+    )
+
+    assert conv_id == "resp-3"
+    assert summary.llm_response == "Hello world!"
+    assert len(summary.tool_calls) == 1
+    assert summary.tool_calls[0].id == "tc-1"
+    assert summary.tool_calls[0].name == "do_something"
+    assert summary.tool_calls[0].args == {"x": 1}
+    assert referenced_docs == []
+    assert token_usage.input_tokens == 0  # No usage info, so 0
+    assert token_usage.output_tokens == 0
+
+
+@pytest.mark.asyncio
+async def test_retrieve_response_with_usage_info(mocker):
+    """Test that token usage is extracted when provided by the API as an object."""
+    mock_client = mocker.Mock()
+
+    output_item = mocker.Mock()
+    output_item.type = "message"
+    output_item.role = "assistant"
+    output_item.content = "Test response"
+    output_item.tool_calls = []
+
+    # Mock usage information as object
+    mock_usage = mocker.Mock()
+    mock_usage.input_tokens = 150
+    mock_usage.output_tokens = 75
+
+    response_obj = mocker.Mock()
+    response_obj.id = "resp-with-usage"
+    response_obj.output = [output_item]
+    response_obj.usage = mock_usage
+
+    mock_client.responses.create = mocker.AsyncMock(return_value=response_obj)
+    mock_vector_stores = mocker.Mock()
+    mock_vector_stores.data = []
+    mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
+
+    mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
+    mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
+
+    qr = QueryRequest(query="hello")
+    summary, conv_id, _referenced_docs, token_usage = await retrieve_response(
+        mock_client, "model-usage", qr, token="tkn", provider_id="test-provider"
+    )
+
+    assert conv_id == "resp-with-usage"
+    assert summary.llm_response == "Test response"
+    assert token_usage.input_tokens == 150
+    assert token_usage.output_tokens == 75
+    assert token_usage.llm_calls == 1
+
+
+@pytest.mark.asyncio
+async def test_retrieve_response_with_usage_dict(mocker):
+    """Test that token usage is extracted when provided by the API as a dict."""
+    mock_client = mocker.Mock()
+
+    output_item = mocker.Mock()
+    output_item.type = "message"
+    output_item.role = "assistant"
+    output_item.content = "Test response dict"
+    output_item.tool_calls = []
+
+    # Mock usage information as dict (like Llama Stack does)
+    response_obj = mocker.Mock()
+    response_obj.id = "resp-with-usage-dict"
+    response_obj.output = [output_item]
+    response_obj.usage = {"input_tokens": 200, "output_tokens": 100}
+
+    mock_client.responses.create = mocker.AsyncMock(return_value=response_obj)
+    mock_vector_stores = mocker.Mock()
+    mock_vector_stores.data = []
+    mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
+
+    mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
+    mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
+
+    qr = QueryRequest(query="hello")
+    summary, conv_id, _referenced_docs, token_usage = await retrieve_response(
+        mock_client, "model-usage-dict", qr, token="tkn", provider_id="test-provider"
+    )
+
+    assert conv_id == "resp-with-usage-dict"
+    assert summary.llm_response == "Test response dict"
+    assert token_usage.input_tokens == 200
+    assert token_usage.output_tokens == 100
+    assert token_usage.llm_calls == 1
+
+
+@pytest.mark.asyncio
+async def test_retrieve_response_with_empty_usage_dict(mocker):
+    """Test that empty usage dict is handled gracefully."""
+    mock_client = mocker.Mock()
+
+    output_item = mocker.Mock()
+    output_item.type = "message"
+    output_item.role = "assistant"
+    output_item.content = "Test response empty usage"
+    output_item.tool_calls = []
+
+    # Mock usage information as empty dict (tokens are 0 or missing)
+    response_obj = mocker.Mock()
+    response_obj.id = "resp-empty-usage"
+    response_obj.output = [output_item]
+    response_obj.usage = {}  # Empty dict
+
+    mock_client.responses.create = mocker.AsyncMock(return_value=response_obj)
+    mock_vector_stores = mocker.Mock()
+    mock_vector_stores.data = []
+    mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
+
+    mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
+    mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
+
+    qr = QueryRequest(query="hello")
+    summary, conv_id, _referenced_docs, token_usage = await retrieve_response(
+        mock_client, "model-empty-usage", qr, token="tkn", provider_id="test-provider"
+    )
+
+    assert conv_id == "resp-empty-usage"
+    assert summary.llm_response == "Test response empty usage"
+    assert token_usage.input_tokens == 0
+    assert token_usage.output_tokens == 0
+    assert token_usage.llm_calls == 1  # Always 1, even when no token usage data
+
+
+@pytest.mark.asyncio
+async def test_retrieve_response_validates_attachments(mocker):
+    """Test that retrieve_response validates attachments and includes them in the input string."""
+    mock_client = mocker.Mock()
+    response_obj = mocker.Mock()
+    response_obj.id = "resp-4"
+    response_obj.output = []
+    response_obj.usage = None
+    mock_client.responses.create = mocker.AsyncMock(return_value=response_obj)
+    mock_vector_stores = mocker.Mock()
+    mock_vector_stores.data = []
+    mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
+
+    mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
+    mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
+
+    validate_spy = mocker.patch(
+        "app.endpoints.query_v2.validate_attachments_metadata", return_value=None
+    )
+
+    attachments = [
+        Attachment(attachment_type="log", content_type="text/plain", content="x"),
+    ]
+
+    qr = QueryRequest(query="hello", attachments=attachments)
+    _summary, _cid, _ref_docs, _token_usage = await retrieve_response(
+        mock_client, "model-a", qr, token="tkn"
+    )
+
+    validate_spy.assert_called_once()
+    # Verify that attachments are included in the input
+    kwargs = mock_client.responses.create.call_args.kwargs
+    assert "input" in kwargs
+    # Input should be a string containing both query and attachment
+    assert isinstance(kwargs["input"], str)
+    assert "hello" in kwargs["input"]
+    assert "[Attachment: log]" in kwargs["input"]
+    assert "x" in kwargs["input"]
+
+
+@pytest.mark.asyncio
+async def test_query_endpoint_handler_v2_success(mocker, dummy_request):
+    """Test successful query endpoint handler execution with proper response structure."""
+    # Mock configuration to avoid configuration not loaded errors
+    mock_config = mocker.Mock()
+    mock_config.llama_stack_configuration = mocker.Mock()
+    mock_config.quota_limiters = []
+    mocker.patch("app.endpoints.query_v2.configuration", mock_config)
+
+    mock_client = mocker.Mock()
+    mock_client.models.list = mocker.AsyncMock(return_value=[mocker.Mock()])
+    mocker.patch(
+        "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client
+    )
+    mocker.patch("app.endpoints.query.evaluate_model_hints", return_value=(None, None))
+    mocker.patch(
+        "app.endpoints.query.select_model_and_provider_id",
+        return_value=("llama/m", "m", "p"),
+    )
+
+    summary = mocker.Mock(llm_response="ANSWER", tool_calls=[], rag_chunks=[])
+    token_usage = mocker.Mock(input_tokens=10, output_tokens=20)
+    mocker.patch(
+        "app.endpoints.query_v2.retrieve_response",
+        return_value=(summary, "conv-1", [], token_usage),
+    )
+    mocker.patch("app.endpoints.query_v2.get_topic_summary", return_value="Topic")
+    mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False)
+    mocker.patch("app.endpoints.query.persist_user_conversation_details")
+    mocker.patch("utils.endpoints.store_conversation_into_cache")
+    mocker.patch("app.endpoints.query.get_session")
+
+    # Add missing mocks for quota functions
+    mocker.patch("utils.quota.check_tokens_available")
+    mocker.patch("utils.quota.consume_tokens")
+    mocker.patch("utils.quota.get_available_quotas", return_value={})
+
+    # Mock the request state
+    dummy_request.state.authorized_actions = []
+
+    res = await query_endpoint_handler_v2(
+        request=dummy_request,
+        query_request=QueryRequest(query="hi"),
+        auth=("user123", "", False, "token-abc"),
+        mcp_headers={},
+    )
+
+    assert res.conversation_id == "conv-1"
+    assert res.response == "ANSWER"
+
+
+@pytest.mark.asyncio
+async def test_query_endpoint_handler_v2_api_connection_error(mocker, dummy_request):
+    """Test that query endpoint handler properly handles and reports API connection errors."""
+    # Mock configuration to avoid configuration not loaded errors
+    mock_config = mocker.Mock()
+    mock_config.llama_stack_configuration = mocker.Mock()
+    mocker.patch("app.endpoints.query_v2.configuration", mock_config)
+
+    def _raise(*_args, **_kwargs):
+        raise APIConnectionError(request=None)
+
+    mocker.patch("client.AsyncLlamaStackClientHolder.get_client", side_effect=_raise)
+
+    fail_metric = mocker.patch("metrics.llm_calls_failures_total")
+
+    with pytest.raises(HTTPException) as exc:
+        await query_endpoint_handler_v2(
+            request=dummy_request,
+            query_request=QueryRequest(query="hi"),
+            auth=("user123", "", False, "token-abc"),
+            mcp_headers={},
+        )
+
+    assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
+    assert "Unable to connect to Llama Stack" in str(exc.value.detail)
+    fail_metric.inc.assert_called_once()
diff --git a/tests/unit/app/test_routers.py b/tests/unit/app/test_routers.py
index e466fca4..a6c23a2b 100644
--- a/tests/unit/app/test_routers.py
+++ b/tests/unit/app/test_routers.py
@@ -15,6 +15,7 @@
     shields,
     providers,
     query,
+    query_v2,
     health,
     config,
     feedback,
@@ -64,7 +65,7 @@ def test_include_routers() -> None:
     include_routers(app)

     # are all routers added?
-    assert len(app.routers) == 15
+    assert len(app.routers) == 16
     assert root.router in app.get_routers()
     assert info.router in app.get_routers()
     assert models.router in app.get_routers()
@@ -72,6 +73,7 @@ def test_include_routers() -> None:
     assert shields.router in app.get_routers()
     assert providers.router in app.get_routers()
     assert query.router in app.get_routers()
+    assert query_v2.router in app.get_routers()
     assert streaming_query.router in app.get_routers()
     assert config.router in app.get_routers()
     assert feedback.router in app.get_routers()
@@ -88,7 +90,7 @@ def test_check_prefixes() -> None:
     include_routers(app)

     # are all routers added?
-    assert len(app.routers) == 15
+    assert len(app.routers) == 16
     assert app.get_router_prefix(root.router) == ""
     assert app.get_router_prefix(info.router) == "/v1"
     assert app.get_router_prefix(models.router) == "/v1"
@@ -97,6 +99,7 @@ def test_check_prefixes() -> None:
     assert app.get_router_prefix(providers.router) == "/v1"
     assert app.get_router_prefix(query.router) == "/v1"
    assert app.get_router_prefix(streaming_query.router) == "/v1"
+    assert app.get_router_prefix(query_v2.router) == "/v2"
     assert app.get_router_prefix(config.router) == "/v1"
     assert app.get_router_prefix(feedback.router) == "/v1"
     assert app.get_router_prefix(health.router) == ""
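
Example (illustrative, not part of the patch): a minimal client sketch for the new /v2/query endpoint, including chaining a follow-up turn through the returned conversation_id, which retrieve_response forwards to the Responses API as previous_response_id. The base URL, port, bearer token, and exact JSON field names are assumptions for illustration; the fields mirror QueryRequest and QueryResponse as used in this diff.

    # Hypothetical client for the /v2/query endpoint added above.
    # BASE_URL and the auth header are deployment-specific assumptions.
    import requests

    BASE_URL = "http://localhost:8080"  # assumed service address
    HEADERS = {"Authorization": "Bearer <token>"}  # assumed auth scheme

    # First turn: no conversation_id, so no previous_response_id is chained.
    first = requests.post(
        f"{BASE_URL}/v2/query",
        json={"query": "How do I install an operator with OLM?"},
        headers=HEADERS,
        timeout=60,
    ).json()
    print(first["response"])

    # Follow-up turn: send back the returned conversation_id (the Responses
    # API response id); retrieve_response passes it as previous_response_id.
    followup = requests.post(
        f"{BASE_URL}/v2/query",
        json={
            "query": "What did I just ask about?",
            "conversation_id": first["conversation_id"],
        },
        headers=HEADERS,
        timeout=60,
    ).json()
    print(followup["response"])

Because conversation_id is the Responses API response id, every turn returns a fresh id, so a client must always chain the most recently returned one rather than reusing the first.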