diff --git a/src/app/endpoints/query_v2.py b/src/app/endpoints/query_v2.py index 0365d9fd..a4d2a2f3 100644 --- a/src/app/endpoints/query_v2.py +++ b/src/app/endpoints/query_v2.py @@ -18,7 +18,7 @@ from authentication import get_auth_dependency from authentication.interface import AuthTuple from authorization.middleware import authorize -from configuration import configuration +from configuration import AppConfig, configuration import metrics from models.config import Action from models.requests import QueryRequest @@ -350,31 +350,9 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche validate_attachments_metadata(query_request.attachments) # Prepare tools for responses API - toolgroups: list[dict[str, Any]] | None = None - if not query_request.no_tools: - toolgroups = [] - # Get vector stores for RAG tools - vector_store_ids = [ - vector_store.id for vector_store in (await client.vector_stores.list()).data - ] - - # Add RAG tools if vector stores are available - rag_tools = get_rag_tools(vector_store_ids) - if rag_tools: - toolgroups.extend(rag_tools) - - # Add MCP server tools - mcp_tools = get_mcp_tools(configuration.mcp_servers, token, mcp_headers) - if mcp_tools: - toolgroups.extend(mcp_tools) - logger.debug( - "Configured %d MCP tools: %s", - len(mcp_tools), - [tool.get("server_label", "unknown") for tool in mcp_tools], - ) - # Convert empty list to None for consistency with existing behavior - if not toolgroups: - toolgroups = None + toolgroups = await prepare_tools_for_responses_api( + client, query_request, token, configuration, mcp_headers + ) # Prepare input for Responses API # Convert attachments to text and concatenate with query @@ -620,11 +598,71 @@ def get_mcp_tools( "require_approval": "never", } - # Add authentication if headers or token provided (Response API format) - headers = (mcp_headers or {}).get(mcp_server.url) - if headers: + # Build headers: start with token auth, then merge in per-server headers + if token or mcp_headers: + headers = {} + # Add token-based auth if available + if token: + headers["Authorization"] = f"Bearer {token}" + # Merge in per-server headers (can override Authorization if needed) + server_headers = (mcp_headers or {}).get(mcp_server.url) + if server_headers: + headers.update(server_headers) tool_def["headers"] = headers - elif token: - tool_def["headers"] = {"Authorization": f"Bearer {token}"} + tools.append(tool_def) return tools + + +async def prepare_tools_for_responses_api( + client: AsyncLlamaStackClient, + query_request: QueryRequest, + token: str, + config: AppConfig, + mcp_headers: dict[str, dict[str, str]] | None = None, +) -> list[dict[str, Any]] | None: + """ + Prepare tools for Responses API including RAG and MCP tools. + + This function retrieves vector stores and combines them with MCP + server tools to create a unified toolgroups list for the Responses API. 
+ + Args: + client: The Llama Stack client instance + query_request: The user's query request + token: Authentication token for MCP tools + config: Configuration object containing MCP server settings + mcp_headers: Per-request headers for MCP servers + + Returns: + list[dict[str, Any]] | None: List of tool configurations for the + Responses API, or None if no_tools is True or no tools are available + """ + if query_request.no_tools: + return None + + toolgroups = [] + # Get vector stores for RAG tools + vector_store_ids = [ + vector_store.id for vector_store in (await client.vector_stores.list()).data + ] + + # Add RAG tools if vector stores are available + rag_tools = get_rag_tools(vector_store_ids) + if rag_tools: + toolgroups.extend(rag_tools) + + # Add MCP server tools + mcp_tools = get_mcp_tools(config.mcp_servers, token, mcp_headers) + if mcp_tools: + toolgroups.extend(mcp_tools) + logger.debug( + "Configured %d MCP tools: %s", + len(mcp_tools), + [tool.get("server_label", "unknown") for tool in mcp_tools], + ) + # Convert empty list to None for consistency with existing behavior + if not toolgroups: + return None + + return toolgroups diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index d4ad3088..c324016b 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -5,6 +5,7 @@ import logging import re import uuid +from collections.abc import Callable from datetime import UTC, datetime from typing import Annotated, Any, AsyncGenerator, AsyncIterator, Iterator, cast @@ -23,7 +24,6 @@ from llama_stack_client.types.shared.interleaved_content_item import TextContentItem from llama_stack_client.types.agents.turn_create_params import Document -from app.database import get_session from app.endpoints.query import ( get_rag_toolgroups, is_input_shield, @@ -44,18 +44,17 @@ from constants import DEFAULT_RAG_TOOL, MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT import metrics from metrics.utils import update_llm_token_count_from_turn -from models.cache_entry import CacheEntry from models.config import Action +from models.context import ResponseGeneratorContext from models.database.conversations import UserConversation from models.requests import QueryRequest from models.responses import ForbiddenResponse, UnauthorizedResponse from utils.endpoints import ( check_configuration_loaded, - create_referenced_documents_with_metadata, + cleanup_after_streaming, create_rag_chunks_dict, get_agent, get_system_prompt, - store_conversation_into_cache, validate_model_provider_override, ) from utils.mcp_headers import handle_mcp_headers_with_toolgroups, mcp_headers_dependency @@ -695,31 +694,137 @@ def _handle_heartbeat_event( ) -@router.post("/streaming_query", responses=streaming_query_responses) -@authorize(Action.STREAMING_QUERY) -async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals,too-many-statements +def create_agent_response_generator( # pylint: disable=too-many-locals + context: ResponseGeneratorContext, +) -> Any: + """ + Create a response generator function for Agent API streaming. + + This factory function returns an async generator that processes streaming + responses from the Agent API and yields Server-Sent Events (SSE). 
+ + Args: + context: Context object containing all necessary parameters for response generation + + Returns: + An async generator function that yields SSE-formatted strings + """ + + async def response_generator( + turn_response: AsyncIterator[AgentTurnResponseStreamChunk], + ) -> AsyncIterator[str]: + """ + Generate SSE formatted streaming response. + + Asynchronously generates a stream of Server-Sent Events + (SSE) representing incremental responses from a + language model turn. + + Yields start, token, tool call, turn completion, and + end events as SSE-formatted strings. Collects the + complete response for transcript storage if enabled. + """ + chunk_id = 0 + summary = TurnSummary(llm_response="No response from the model", tool_calls=[]) + + # Determine media type for response formatting + media_type = context.query_request.media_type or MEDIA_TYPE_JSON + + # Send start event at the beginning of the stream + yield stream_start_event(context.conversation_id) + + latest_turn: Any | None = None + + async for chunk in turn_response: + if chunk.event is None: + continue + p = chunk.event.payload + if p.event_type == "turn_complete": + summary.llm_response = interleaved_content_as_str( + p.turn.output_message.content + ) + latest_turn = p.turn + system_prompt = get_system_prompt(context.query_request, configuration) + try: + update_llm_token_count_from_turn( + p.turn, context.model_id, context.provider_id, system_prompt + ) + except Exception: # pylint: disable=broad-except + logger.exception("Failed to update token usage metrics") + elif p.event_type == "step_complete": + if p.step_details.step_type == "tool_execution": + summary.append_tool_calls_from_llama(p.step_details) + + for event in stream_build_event( + chunk, + chunk_id, + context.metadata_map, + media_type, + context.conversation_id, + ): + chunk_id += 1 + yield event + + # Extract token usage from the turn + token_usage = ( + extract_token_usage_from_turn(latest_turn) + if latest_turn is not None + else TokenCounter() + ) + + yield stream_end_event(context.metadata_map, summary, token_usage, media_type) + + # Perform cleanup tasks (database and cache operations) + await cleanup_after_streaming( + user_id=context.user_id, + conversation_id=context.conversation_id, + model_id=context.model_id, + provider_id=context.provider_id, + llama_stack_model_id=context.llama_stack_model_id, + query_request=context.query_request, + summary=summary, + metadata_map=context.metadata_map, + started_at=context.started_at, + client=context.client, + config=configuration, + skip_userid_check=context.skip_userid_check, + get_topic_summary_func=get_topic_summary, + is_transcripts_enabled_func=is_transcripts_enabled, + store_transcript_func=store_transcript, + persist_user_conversation_details_func=persist_user_conversation_details, + rag_chunks=create_rag_chunks_dict(summary), + ) + + return response_generator + + +async def streaming_query_endpoint_handler_base( # pylint: disable=too-many-locals,too-many-statements,too-many-arguments,too-many-positional-arguments request: Request, query_request: QueryRequest, - auth: Annotated[AuthTuple, Depends(get_auth_dependency())], - mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), + auth: AuthTuple, + mcp_headers: dict[str, dict[str, str]], + retrieve_response_func: Callable[..., Any], + create_response_generator_func: Callable[..., Any], ) -> StreamingResponse: """ - Handle request to the /streaming_query endpoint. + Handle streaming query endpoints with common logic. 
- This endpoint receives a query request, authenticates the user, - selects the appropriate model and provider, and streams - incremental response events from the Llama Stack backend to the - client. Events include start, token updates, tool calls, turn - completions, errors, and end-of-stream metadata. Optionally - stores the conversation transcript if enabled in configuration. + This base handler contains all the common logic for streaming query endpoints + and accepts functions for API-specific behavior (Agent API vs Responses API). + + Args: + request: The FastAPI request object + query_request: The query request from the user + auth: Authentication tuple (user_id, username, skip_check, token) + mcp_headers: MCP headers for tool integrations + retrieve_response_func: Function to retrieve the streaming response + create_response_generator_func: Function factory that creates the response generator Returns: - StreamingResponse: An HTTP streaming response yielding - SSE-formatted events for the query lifecycle. + StreamingResponse: An HTTP streaming response yielding SSE-formatted events Raises: - HTTPException: Returns HTTP 500 if unable to connect to the - Llama Stack server. + HTTPException: Returns HTTP 500 if unable to connect to Llama Stack """ # Nothing interesting in the request _ = request @@ -764,7 +869,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals,t user_conversation=user_conversation, query_request=query_request ), ) - response, conversation_id = await retrieve_response( + response, conversation_id = await retrieve_response_func( client, llama_stack_model_id, query_request, @@ -773,133 +878,22 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals,t ) metadata_map: dict[str, dict[str, Any]] = {} - async def response_generator( - turn_response: AsyncIterator[AgentTurnResponseStreamChunk], - ) -> AsyncIterator[str]: - """ - Generate SSE formatted streaming response. - - Asynchronously generates a stream of Server-Sent Events - (SSE) representing incremental responses from a - language model turn. - - Yields start, token, tool call, turn completion, and - end events as SSE-formatted strings. Collects the - complete response for transcript storage if enabled. 
- """ - chunk_id = 0 - summary = TurnSummary( - llm_response="No response from the model", tool_calls=[] - ) - - # Determine media type for response formatting - media_type = query_request.media_type or MEDIA_TYPE_JSON - - # Send start event at the beginning of the stream - yield stream_start_event(conversation_id) - - latest_turn: Any | None = None - - async for chunk in turn_response: - if chunk.event is None: - continue - p = chunk.event.payload - if p.event_type == "turn_complete": - summary.llm_response = interleaved_content_as_str( - p.turn.output_message.content - ) - latest_turn = p.turn - system_prompt = get_system_prompt(query_request, configuration) - try: - update_llm_token_count_from_turn( - p.turn, model_id, provider_id, system_prompt - ) - except Exception: # pylint: disable=broad-except - logger.exception("Failed to update token usage metrics") - elif p.event_type == "step_complete": - if p.step_details.step_type == "tool_execution": - summary.append_tool_calls_from_llama(p.step_details) - - for event in stream_build_event( - chunk, chunk_id, metadata_map, media_type, conversation_id - ): - chunk_id += 1 - yield event - - # Extract token usage from the turn - token_usage = ( - extract_token_usage_from_turn(latest_turn) - if latest_turn is not None - else TokenCounter() - ) - - yield stream_end_event(metadata_map, summary, token_usage, media_type) - - if not is_transcripts_enabled(): - logger.debug("Transcript collection is disabled in the configuration") - else: - store_transcript( - user_id=user_id, - conversation_id=conversation_id, - model_id=model_id, - provider_id=provider_id, - query_is_valid=True, # TODO(lucasagomes): implement as part of query validation - query=query_request.query, - query_request=query_request, - summary=summary, - rag_chunks=create_rag_chunks_dict(summary), - truncated=False, # TODO(lucasagomes): implement truncation as part - # of quota work - attachments=query_request.attachments or [], - ) - - # Get the initial topic summary for the conversation - topic_summary = None - with get_session() as session: - existing_conversation = ( - session.query(UserConversation) - .filter_by(id=conversation_id) - .first() - ) - if not existing_conversation: - topic_summary = await get_topic_summary( - query_request.query, client, model_id - ) - - completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ") - - referenced_documents = create_referenced_documents_with_metadata( - summary, metadata_map - ) - - cache_entry = CacheEntry( - query=query_request.query, - response=summary.llm_response, - provider=provider_id, - model=model_id, - started_at=started_at, - completed_at=completed_at, - referenced_documents=( - referenced_documents if referenced_documents else None - ), - ) - - store_conversation_into_cache( - configuration, - user_id, - conversation_id, - cache_entry, - _skip_userid_check, - topic_summary, - ) + # Create context object for response generator + context = ResponseGeneratorContext( + conversation_id=conversation_id, + user_id=user_id, + skip_userid_check=_skip_userid_check, + model_id=model_id, + provider_id=provider_id, + llama_stack_model_id=llama_stack_model_id, + query_request=query_request, + started_at=started_at, + client=client, + metadata_map=metadata_map, + ) - persist_user_conversation_details( - user_id=user_id, - conversation_id=conversation_id, - model=model_id, - provider_id=provider_id, - topic_summary=topic_summary, - ) + # Create the response generator using the provided factory function + response_generator = 
create_response_generator_func(context) # Update metrics for the LLM call metrics.llm_calls_total.labels(provider_id, model_id).inc() @@ -939,6 +933,38 @@ async def error_generator() -> AsyncGenerator[str, None]: return StreamingResponse(error_generator(), media_type=content_type) +@router.post("/streaming_query", responses=streaming_query_responses) +@authorize(Action.STREAMING_QUERY) +async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals,too-many-statements + request: Request, + query_request: QueryRequest, + auth: Annotated[AuthTuple, Depends(get_auth_dependency())], + mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), +) -> StreamingResponse: + """ + Handle request to the /streaming_query endpoint using Agent API. + + This is a wrapper around streaming_query_endpoint_handler_base that provides + the Agent API specific retrieve_response and response generator functions. + + Returns: + StreamingResponse: An HTTP streaming response yielding + SSE-formatted events for the query lifecycle. + + Raises: + HTTPException: Returns HTTP 500 if unable to connect to the + Llama Stack server. + """ + return await streaming_query_endpoint_handler_base( + request=request, + query_request=query_request, + auth=auth, + mcp_headers=mcp_headers, + retrieve_response_func=retrieve_response, + create_response_generator_func=create_agent_response_generator, + ) + + async def retrieve_response( client: AsyncLlamaStackClient, model_id: str, diff --git a/src/app/endpoints/streaming_query_v2.py b/src/app/endpoints/streaming_query_v2.py new file mode 100644 index 00000000..908a7237 --- /dev/null +++ b/src/app/endpoints/streaming_query_v2.py @@ -0,0 +1,410 @@ +"""Streaming query handler using Responses API (v2).""" + +import logging +from typing import Annotated, Any, AsyncIterator, cast + +from llama_stack_client import AsyncLlamaStackClient # type: ignore +from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseObjectStream, +) + +from fastapi import APIRouter, Depends, Request +from fastapi.responses import StreamingResponse + +from app.endpoints.query import ( + is_transcripts_enabled, + persist_user_conversation_details, + validate_attachments_metadata, +) +from app.endpoints.query_v2 import ( + extract_token_usage_from_responses_api, + get_topic_summary, + prepare_tools_for_responses_api, +) +from app.endpoints.streaming_query import ( + format_stream_data, + stream_end_event, + stream_start_event, + streaming_query_endpoint_handler_base, +) +from authentication import get_auth_dependency +from authentication.interface import AuthTuple +from authorization.middleware import authorize +from configuration import configuration +from constants import MEDIA_TYPE_JSON +from models.config import Action +from models.context import ResponseGeneratorContext +from models.requests import QueryRequest +from models.responses import ForbiddenResponse, UnauthorizedResponse +from utils.endpoints import ( + cleanup_after_streaming, + get_system_prompt, +) +from utils.mcp_headers import mcp_headers_dependency +from utils.token_counter import TokenCounter +from utils.transcripts import store_transcript +from utils.types import TurnSummary, ToolCallSummary + +logger = logging.getLogger("app.endpoints.handlers") +router = APIRouter(tags=["streaming_query_v2"]) +auth_dependency = get_auth_dependency() + +streaming_query_v2_responses: dict[int | str, dict[str, Any]] = { + 200: { + "description": "Streaming response with Server-Sent Events", + "content": { + 
"application/json": { + "schema": { + "type": "string", + "example": ( + 'data: {"event": "start", ' + '"data": {"conversation_id": "123e4567-e89b-12d3-a456-426614174000"}}\n\n' + 'data: {"event": "token", "data": {"id": 0, "token": "Hello"}}\n\n' + 'data: {"event": "end", "data": {"referenced_documents": [], ' + '"truncated": null, "input_tokens": 0, "output_tokens": 0}, ' + '"available_quotas": {}}\n\n' + ), + } + }, + "text/plain": { + "schema": { + "type": "string", + "example": "Hello world!\n\n---\n\nReference: https://example.com/doc", + } + }, + }, + }, + 400: { + "description": "Missing or invalid credentials provided by client", + "model": UnauthorizedResponse, + }, + 401: { + "description": "Unauthorized: Invalid or missing Bearer token for k8s auth", + "model": UnauthorizedResponse, + }, + 403: { + "description": "User is not authorized", + "model": ForbiddenResponse, + }, + 500: { + "detail": { + "response": "Unable to connect to Llama Stack", + "cause": "Connection error.", + } + }, +} + + +def create_responses_response_generator( # pylint: disable=too-many-locals,too-many-statements + context: ResponseGeneratorContext, +) -> Any: + """ + Create a response generator function for Responses API streaming. + + This factory function returns an async generator that processes streaming + responses from the Responses API and yields Server-Sent Events (SSE). + + Args: + context: Context object containing all necessary parameters for response generation + + Returns: + An async generator function that yields SSE-formatted strings + """ + + async def response_generator( # pylint: disable=too-many-branches,too-many-statements + turn_response: AsyncIterator[OpenAIResponseObjectStream], + ) -> AsyncIterator[str]: + """ + Generate SSE formatted streaming response. + + Asynchronously generates a stream of Server-Sent Events + (SSE) representing incremental responses from a + language model turn. + + Yields start, token, tool call, turn completion, and + end events as SSE-formatted strings. Collects the + complete response for transcript storage if enabled. 
+ """ + chunk_id = 0 + summary = TurnSummary(llm_response="", tool_calls=[]) + + # Determine media type for response formatting + media_type = context.query_request.media_type or MEDIA_TYPE_JSON + + # Accumulators for Responses API + text_parts: list[str] = [] + tool_item_registry: dict[str, dict[str, str]] = {} + emitted_turn_complete = False + + # Handle conversation id and start event in-band on response.created + conv_id = context.conversation_id + + # Track the latest response object from response.completed event + latest_response_object: Any | None = None + + logger.debug("Starting streaming response (Responses API) processing") + + async for chunk in turn_response: + event_type = getattr(chunk, "type", None) + logger.debug("Processing chunk %d, type: %s", chunk_id, event_type) + + # Emit start on response.created + if event_type == "response.created": + try: + conv_id = getattr(chunk, "response").id + except Exception: # pylint: disable=broad-except + logger.warning("Missing response id!") + conv_id = "" + yield stream_start_event(conv_id) + continue + + # Text streaming + if event_type == "response.output_text.delta": + delta = getattr(chunk, "delta", "") + if delta: + text_parts.append(delta) + yield format_stream_data( + { + "event": "token", + "data": { + "id": chunk_id, + "token": delta, + }, + } + ) + chunk_id += 1 + + # Final text of the output (capture, but emit at response.completed) + elif event_type == "response.output_text.done": + final_text = getattr(chunk, "text", "") + if final_text: + summary.llm_response = final_text + + # Content part started - emit an empty token to kick off UI streaming if desired + elif event_type == "response.content_part.added": + yield format_stream_data( + { + "event": "token", + "data": { + "id": chunk_id, + "token": "", + }, + } + ) + chunk_id += 1 + + # Track tool call items as they are added so we can build a summary later + elif event_type == "response.output_item.added": + item = getattr(chunk, "item", None) + item_type = getattr(item, "type", None) + if item and item_type == "function_call": + item_id = getattr(item, "id", "") + name = getattr(item, "name", "function_call") + call_id = getattr(item, "call_id", item_id) + if item_id: + tool_item_registry[item_id] = { + "name": name, + "call_id": call_id, + } + + # Stream tool call arguments as tool_call events + elif event_type == "response.function_call_arguments.delta": + delta = getattr(chunk, "delta", "") + yield format_stream_data( + { + "event": "tool_call", + "data": { + "id": chunk_id, + "role": "tool_execution", + "token": delta, + }, + } + ) + chunk_id += 1 + + # Finalize tool call arguments and append to summary + elif event_type in ( + "response.function_call_arguments.done", + "response.mcp_call.arguments.done", + ): + item_id = getattr(chunk, "item_id", "") + arguments = getattr(chunk, "arguments", "") + meta = tool_item_registry.get(item_id, {}) + summary.tool_calls.append( + ToolCallSummary( + id=meta.get("call_id", item_id or "unknown"), + name=meta.get("name", "tool_call"), + args=arguments, + response=None, + ) + ) + + # Completed response - capture final text and response object + elif event_type == "response.completed": + # Capture the response object for token usage extraction + latest_response_object = getattr(chunk, "response", None) + if not emitted_turn_complete: + final_message = summary.llm_response or "".join(text_parts) + if not final_message: + final_message = "No response from the model" + summary.llm_response = final_message + yield format_stream_data( 
+ { + "event": "turn_complete", + "data": { + "id": chunk_id, + "token": final_message, + }, + } + ) + chunk_id += 1 + emitted_turn_complete = True + + # Ignore other event types for now; could add heartbeats if desired + + logger.debug( + "Streaming complete - Tool calls: %d, Response chars: %d", + len(summary.tool_calls), + len(summary.llm_response), + ) + + # Extract token usage from the response object + token_usage = ( + extract_token_usage_from_responses_api( + latest_response_object, context.model_id, context.provider_id + ) + if latest_response_object is not None + else TokenCounter() + ) + + yield stream_end_event(context.metadata_map, summary, token_usage, media_type) + + # Perform cleanup tasks (database and cache operations) + await cleanup_after_streaming( + user_id=context.user_id, + conversation_id=conv_id, + model_id=context.model_id, + provider_id=context.provider_id, + llama_stack_model_id=context.llama_stack_model_id, + query_request=context.query_request, + summary=summary, + metadata_map=context.metadata_map, + started_at=context.started_at, + client=context.client, + config=configuration, + skip_userid_check=context.skip_userid_check, + get_topic_summary_func=get_topic_summary, + is_transcripts_enabled_func=is_transcripts_enabled, + store_transcript_func=store_transcript, + persist_user_conversation_details_func=persist_user_conversation_details, + rag_chunks=[], # Responses API uses empty list for rag_chunks + ) + + return response_generator + + +@router.post("/streaming_query", responses=streaming_query_v2_responses) +@authorize(Action.STREAMING_QUERY) +async def streaming_query_endpoint_handler_v2( # pylint: disable=too-many-locals + request: Request, + query_request: QueryRequest, + auth: Annotated[AuthTuple, Depends(auth_dependency)], + mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), +) -> StreamingResponse: + """ + Handle request to the /streaming_query endpoint using Responses API. + + This is a wrapper around streaming_query_endpoint_handler_base that provides + the Responses API specific retrieve_response and response generator functions. + + Returns: + StreamingResponse: An HTTP streaming response yielding + SSE-formatted events for the query lifecycle. + + Raises: + HTTPException: Returns HTTP 500 if unable to connect to the + Llama Stack server. + """ + return await streaming_query_endpoint_handler_base( + request=request, + query_request=query_request, + auth=auth, + mcp_headers=mcp_headers, + retrieve_response_func=retrieve_response, + create_response_generator_func=create_responses_response_generator, + ) + + +async def retrieve_response( + client: AsyncLlamaStackClient, + model_id: str, + query_request: QueryRequest, + token: str, + mcp_headers: dict[str, dict[str, str]] | None = None, +) -> tuple[AsyncIterator[OpenAIResponseObjectStream], str]: + """ + Retrieve response from LLMs and agents. + + Asynchronously retrieves a streaming response and conversation + ID from the Llama Stack agent for a given user query. + + This function configures input/output shields, system prompt, + and tool usage based on the request and environment. It + prepares the agent with appropriate headers and toolgroups, + validates attachments if present, and initiates a streaming + turn with the user's query and any provided documents. + + Parameters: + model_id (str): Identifier of the model to use for the query. + query_request (QueryRequest): The user's query and associated metadata. + token (str): Authentication token for downstream services. 
+ mcp_headers (dict[str, dict[str, str]], optional): + Multi-cluster proxy headers for tool integrations. + + Returns: + tuple: A tuple containing the streaming response object + and the conversation ID. + """ + logger.info("Shields are not yet supported in Responses API.") + + # use system prompt from request or default one + system_prompt = get_system_prompt(query_request, configuration) + logger.debug("Using system prompt: %s", system_prompt) + + # TODO(lucasagomes): redact attachments content before sending to LLM + # if attachments are provided, validate them + if query_request.attachments: + validate_attachments_metadata(query_request.attachments) + + # Prepare tools for responses API + toolgroups = await prepare_tools_for_responses_api( + client, query_request, token, configuration, mcp_headers + ) + + # Prepare input for Responses API + # Convert attachments to text and concatenate with query + input_text = query_request.query + if query_request.attachments: + for attachment in query_request.attachments: + input_text += ( + f"\n\n[Attachment: {attachment.attachment_type}]\n" + f"{attachment.content}" + ) + + create_params: dict[str, Any] = { + "input": input_text, + "model": model_id, + "instructions": system_prompt, + "stream": True, + "store": True, + "tools": toolgroups, + } + if query_request.conversation_id: + create_params["previous_response_id"] = query_request.conversation_id + + response = await client.responses.create(**create_params) + response_stream = cast(AsyncIterator[OpenAIResponseObjectStream], response) + + # For streaming responses, the ID arrives in the first 'response.created' chunk + # Return empty conversation_id here; it will be set once the first chunk is received + return response_stream, "" diff --git a/src/app/routers.py b/src/app/routers.py index 1ca7044c..521db30c 100644 --- a/src/app/routers.py +++ b/src/app/routers.py @@ -13,6 +13,7 @@ config, feedback, streaming_query, + streaming_query_v2, authorized, conversations, conversations_v2, @@ -45,6 +46,7 @@ def include_routers(app: FastAPI) -> None: # V2 endpoints - Response API support app.include_router(query_v2.router, prefix="/v2") + app.include_router(streaming_query_v2.router, prefix="/v2") # road-core does not version these endpoints app.include_router(health.router) diff --git a/src/models/context.py b/src/models/context.py new file mode 100644 index 00000000..a6785167 --- /dev/null +++ b/src/models/context.py @@ -0,0 +1,48 @@ +"""Context objects for internal operations.""" + +from dataclasses import dataclass +from typing import Any + +from llama_stack_client import AsyncLlamaStackClient + +from models.requests import QueryRequest + + +@dataclass +class ResponseGeneratorContext: # pylint: disable=too-many-instance-attributes + """ + Context object for response generator creation. + + This class groups all the parameters needed to create a response generator + for streaming query endpoints, reducing function parameter count from 10 to 1. 
+ + Attributes: + conversation_id: The conversation identifier + user_id: The user identifier + skip_userid_check: Whether to skip user ID validation + model_id: The model identifier + provider_id: The provider identifier + llama_stack_model_id: The full llama stack model ID + query_request: The query request object + started_at: Timestamp when the request started (ISO 8601 format) + client: The Llama Stack client for API interactions + metadata_map: Dictionary for storing metadata from tool responses + """ + + # Conversation & User context + conversation_id: str + user_id: str + skip_userid_check: bool + + # Model & Provider info + model_id: str + provider_id: str + llama_stack_model_id: str + + # Request & Timing + query_request: QueryRequest + started_at: str + + # Dependencies & State + client: AsyncLlamaStackClient + metadata_map: dict[str, dict[str, Any]] diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index de2e9bec..80b3b6e5 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -1,6 +1,7 @@ """Utility functions for endpoint handlers.""" from contextlib import suppress +from datetime import UTC, datetime from typing import Any from fastapi import HTTPException, status from llama_stack_client._client import AsyncLlamaStackClient @@ -591,3 +592,120 @@ def create_referenced_documents_from_chunks( ReferencedDocument(doc_url=doc_url, doc_title=doc_title) for doc_url, doc_title in document_entries ] + + +# pylint: disable=R0913,R0917,too-many-locals +async def cleanup_after_streaming( + user_id: str, + conversation_id: str, + model_id: str, + provider_id: str, + llama_stack_model_id: str, + query_request: QueryRequest, + summary: TurnSummary, + metadata_map: dict[str, Any], + started_at: str, + client: AsyncLlamaStackClient, + config: AppConfig, + skip_userid_check: bool, + get_topic_summary_func: Any, + is_transcripts_enabled_func: Any, + store_transcript_func: Any, + persist_user_conversation_details_func: Any, + rag_chunks: list[dict[str, Any]] | None = None, +) -> None: + """ + Perform cleanup tasks after streaming is complete. + + This function handles all database and cache operations after the streaming + response has been sent to the client. It is shared between Agent API and + Responses API streaming implementations. 
+ + Args: + user_id: ID of the user making the request + conversation_id: ID of the conversation + model_id: ID of the model used + provider_id: ID of the provider used + llama_stack_model_id: Full Llama Stack model ID (provider/model format) + query_request: The original query request + summary: Summary of the turn including LLM response and tool calls + metadata_map: Metadata about referenced documents + started_at: Timestamp when the request started + client: AsyncLlamaStackClient instance + config: Application configuration + skip_userid_check: Whether to skip user ID checks + get_topic_summary_func: Function to get topic summary (API-specific) + is_transcripts_enabled_func: Function to check if transcripts are enabled + store_transcript_func: Function to store transcript + persist_user_conversation_details_func: Function to persist conversation details + rag_chunks: Optional RAG chunks dict (for Agent API, None for Responses API) + """ + # Store transcript if enabled + if not is_transcripts_enabled_func(): + logger.debug("Transcript collection is disabled in the configuration") + else: + # Prepare attachments + attachments = query_request.attachments or [] + + # Determine rag_chunks: use provided value or empty list + transcript_rag_chunks = rag_chunks if rag_chunks is not None else [] + + store_transcript_func( + user_id=user_id, + conversation_id=conversation_id, + model_id=model_id, + provider_id=provider_id, + query_is_valid=True, + query=query_request.query, + query_request=query_request, + summary=summary, + rag_chunks=transcript_rag_chunks, + truncated=False, + attachments=attachments, + ) + + # Get the initial topic summary for the conversation + topic_summary = None + with get_session() as session: + existing_conversation = ( + session.query(UserConversation).filter_by(id=conversation_id).first() + ) + if not existing_conversation: + topic_summary = await get_topic_summary_func( + query_request.query, + client, + llama_stack_model_id, + ) + + completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ") + + referenced_documents = create_referenced_documents_with_metadata( + summary, metadata_map + ) + + cache_entry = CacheEntry( + query=query_request.query, + response=summary.llm_response, + provider=provider_id, + model=model_id, + started_at=started_at, + completed_at=completed_at, + referenced_documents=referenced_documents if referenced_documents else None, + ) + + store_conversation_into_cache( + config, + user_id, + conversation_id, + cache_entry, + skip_userid_check, + topic_summary, + ) + + persist_user_conversation_details_func( + user_id=user_id, + conversation_id=conversation_id, + model=model_id, + provider_id=provider_id, + topic_summary=topic_summary, + ) diff --git a/tests/unit/app/endpoints/test_query_v2.py b/tests/unit/app/endpoints/test_query_v2.py index 0790f151..b51c84bb 100644 --- a/tests/unit/app/endpoints/test_query_v2.py +++ b/tests/unit/app/endpoints/test_query_v2.py @@ -2,6 +2,7 @@ """Unit tests for the /query (v2) REST API endpoint using Responses API.""" import pytest +from pytest_mock import MockerFixture from fastapi import HTTPException, status, Request from llama_stack_client import APIConnectionError @@ -24,7 +25,7 @@ def dummy_request() -> Request: return req -def test_get_rag_tools(): +def test_get_rag_tools() -> None: """Test get_rag_tools returns None for empty list and correct tool format for vector stores.""" assert get_rag_tools([]) is None @@ -35,7 +36,7 @@ def test_get_rag_tools(): assert tools[0]["max_num_results"] == 10 -def 
test_get_mcp_tools_with_and_without_token(): +def test_get_mcp_tools_with_and_without_token() -> None: """Test get_mcp_tools generates correct tool definitions with and without auth tokens.""" servers = [ ModelContextProtocolServer(name="fs", url="http://localhost:3000"), @@ -57,8 +58,46 @@ def test_get_mcp_tools_with_and_without_token(): assert tools_with_token[1]["headers"] == {"Authorization": "Bearer abc"} +def test_get_mcp_tools_with_mcp_headers() -> None: + """Test get_mcp_tools merges token auth and per-server headers correctly.""" + servers = [ + ModelContextProtocolServer(name="fs", url="http://localhost:3000"), + ModelContextProtocolServer(name="git", url="https://git.example.com/mcp"), + ] + + # Test with mcp_headers only (no token) + mcp_headers = { + "http://localhost:3000": {"X-Custom-Header": "value1"}, + "https://git.example.com/mcp": {"X-API-Key": "secret123"}, + } + tools = get_mcp_tools(servers, token=None, mcp_headers=mcp_headers) + assert len(tools) == 2 + assert tools[0]["headers"] == {"X-Custom-Header": "value1"} + assert tools[1]["headers"] == {"X-API-Key": "secret123"} + + # Test with both token and mcp_headers (should merge) + tools_merged = get_mcp_tools(servers, token="abc", mcp_headers=mcp_headers) + assert len(tools_merged) == 2 + assert tools_merged[0]["headers"] == { + "Authorization": "Bearer abc", + "X-Custom-Header": "value1", + } + assert tools_merged[1]["headers"] == { + "Authorization": "Bearer abc", + "X-API-Key": "secret123", + } + + # Test mcp_headers can override Authorization + override_headers = { + "http://localhost:3000": {"Authorization": "Custom auth"}, + } + tools_override = get_mcp_tools(servers, token="abc", mcp_headers=override_headers) + assert tools_override[0]["headers"] == {"Authorization": "Custom auth"} + assert tools_override[1]["headers"] == {"Authorization": "Bearer abc"} + + @pytest.mark.asyncio -async def test_retrieve_response_no_tools_bypasses_tools(mocker): +async def test_retrieve_response_no_tools_bypasses_tools(mocker: MockerFixture) -> None: """Test that no_tools=True bypasses tool configuration and passes None to responses API.""" mock_client = mocker.Mock() # responses.create returns a synthetic OpenAI-like response @@ -94,7 +133,9 @@ async def test_retrieve_response_no_tools_bypasses_tools(mocker): @pytest.mark.asyncio -async def test_retrieve_response_builds_rag_and_mcp_tools(mocker): +async def test_retrieve_response_builds_rag_and_mcp_tools( + mocker: MockerFixture, +) -> None: """Test that retrieve_response correctly builds RAG and MCP tools from configuration.""" mock_client = mocker.Mock() response_obj = mocker.Mock() @@ -137,7 +178,9 @@ async def test_retrieve_response_builds_rag_and_mcp_tools(mocker): @pytest.mark.asyncio -async def test_retrieve_response_parses_output_and_tool_calls(mocker): +async def test_retrieve_response_parses_output_and_tool_calls( + mocker: MockerFixture, +) -> None: """Test that retrieve_response correctly parses output content and tool calls from response.""" mock_client = mocker.Mock() @@ -190,7 +233,7 @@ async def test_retrieve_response_parses_output_and_tool_calls(mocker): @pytest.mark.asyncio -async def test_retrieve_response_with_usage_info(mocker): +async def test_retrieve_response_with_usage_info(mocker: MockerFixture) -> None: """Test that token usage is extracted when provided by the API as an object.""" mock_client = mocker.Mock() @@ -231,7 +274,7 @@ async def test_retrieve_response_with_usage_info(mocker): @pytest.mark.asyncio -async def 
test_retrieve_response_with_usage_dict(mocker): +async def test_retrieve_response_with_usage_dict(mocker: MockerFixture) -> None: """Test that token usage is extracted when provided by the API as a dict.""" mock_client = mocker.Mock() @@ -268,7 +311,7 @@ async def test_retrieve_response_with_usage_dict(mocker): @pytest.mark.asyncio -async def test_retrieve_response_with_empty_usage_dict(mocker): +async def test_retrieve_response_with_empty_usage_dict(mocker: MockerFixture) -> None: """Test that empty usage dict is handled gracefully.""" mock_client = mocker.Mock() @@ -305,7 +348,7 @@ async def test_retrieve_response_with_empty_usage_dict(mocker): @pytest.mark.asyncio -async def test_retrieve_response_validates_attachments(mocker): +async def test_retrieve_response_validates_attachments(mocker: MockerFixture) -> None: """Test that retrieve_response validates attachments and includes them in the input string.""" mock_client = mocker.Mock() response_obj = mocker.Mock() @@ -345,7 +388,9 @@ async def test_retrieve_response_validates_attachments(mocker): @pytest.mark.asyncio -async def test_query_endpoint_handler_v2_success(mocker, dummy_request): +async def test_query_endpoint_handler_v2_success( + mocker: MockerFixture, dummy_request: Request +) -> None: """Test successful query endpoint handler execution with proper response structure.""" # Mock configuration to avoid configuration not loaded errors mock_config = mocker.Mock() @@ -396,15 +441,17 @@ async def test_query_endpoint_handler_v2_success(mocker, dummy_request): @pytest.mark.asyncio -async def test_query_endpoint_handler_v2_api_connection_error(mocker, dummy_request): +async def test_query_endpoint_handler_v2_api_connection_error( + mocker: MockerFixture, dummy_request: Request +) -> None: """Test that query endpoint handler properly handles and reports API connection errors.""" # Mock configuration to avoid configuration not loaded errors mock_config = mocker.Mock() mock_config.llama_stack_configuration = mocker.Mock() mocker.patch("app.endpoints.query_v2.configuration", mock_config) - def _raise(*_args, **_kwargs): - raise APIConnectionError(request=None) + def _raise(*_args: object, **_kwargs: object) -> None: + raise APIConnectionError(request=None) # type: ignore[arg-type] mocker.patch("client.AsyncLlamaStackClientHolder.get_client", side_effect=_raise) diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index bca6d2f2..2d577c49 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -52,12 +52,10 @@ from authorization.resolvers import NoopRolesResolver from constants import MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT -from models.cache_entry import CacheEntry from models.config import ModelContextProtocolServer, Action from models.requests import QueryRequest, Attachment -from models.responses import RAGChunk from utils.token_counter import TokenCounter -from utils.types import ToolCallSummary, TurnSummary +from utils.types import TurnSummary from tests.unit.conftest import AgentFixtures @@ -75,14 +73,11 @@ def mock_database_operations(mocker: MockerFixture) -> None: "app.endpoints.streaming_query.validate_conversation_ownership", return_value=True, ) - mocker.patch("app.endpoints.streaming_query.persist_user_conversation_details") - - # Mock the database session and query - mock_session = mocker.Mock() - mock_session.query.return_value.filter_by.return_value.first.return_value = None - mock_session.__enter__ = 
mocker.Mock(return_value=mock_session) - mock_session.__exit__ = mocker.Mock(return_value=None) - mocker.patch("app.endpoints.streaming_query.get_session", return_value=mock_session) + # Mock the cleanup function that handles all post-streaming database/cache work + mocker.patch( + "app.endpoints.streaming_query.cleanup_after_streaming", + mocker.AsyncMock(return_value=None), + ) def mock_metrics(mocker: MockerFixture) -> None: @@ -217,9 +212,7 @@ async def test_streaming_query_endpoint_on_connection_error( # pylint: disable=too-many-locals -async def _test_streaming_query_endpoint_handler( - mocker: MockerFixture, store_transcript: bool = False -) -> None: +async def _test_streaming_query_endpoint_handler(mocker: MockerFixture) -> None: """Test the streaming query endpoint handler.""" mock_client = mocker.AsyncMock() mock_async_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") @@ -307,9 +300,6 @@ async def _test_streaming_query_endpoint_handler( ), ] - mock_store_in_cache = mocker.patch( - "app.endpoints.streaming_query.store_conversation_into_cache" - ) query = "What is OpenStack?" mocker.patch( "app.endpoints.streaming_query.retrieve_response", @@ -319,17 +309,6 @@ async def _test_streaming_query_endpoint_handler( "app.endpoints.streaming_query.select_model_and_provider_id", return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), ) - mocker.patch( - "app.endpoints.streaming_query.is_transcripts_enabled", - return_value=store_transcript, - ) - mock_transcript = mocker.patch("app.endpoints.streaming_query.store_transcript") - - # Mock get_topic_summary function - mocker.patch( - "app.endpoints.streaming_query.get_topic_summary", - return_value="Test topic summary", - ) mock_database_operations(mocker) @@ -371,79 +350,21 @@ async def _test_streaming_query_endpoint_handler( assert len(referenced_documents) == 2 assert referenced_documents[1]["doc_title"] == "Doc2" - # Assert that mock was called and get the arguments - mock_store_in_cache.assert_called_once() - call_args = mock_store_in_cache.call_args[0] - # Extract CacheEntry object from the call arguments, - # it's the 4th argument from the func signature - cached_entry = call_args[3] - - # Assert that the CacheEntry was constructed correctly - assert isinstance(cached_entry, CacheEntry) - assert cached_entry.response == "LLM answer" - assert cached_entry.referenced_documents is not None - assert len(cached_entry.referenced_documents) == 2 - assert cached_entry.referenced_documents[0].doc_title == "Doc1" - assert ( - str(cached_entry.referenced_documents[1].doc_url) == "https://example.com/doc2" - ) - - # Assert the store_transcript function is called if transcripts are enabled - if store_transcript: - mock_transcript.assert_called_once_with( - user_id="017adfa4-7cc6-46e4-b663-3653e1ae69df", - conversation_id="00000000-0000-0000-0000-000000000000", - model_id="fake_model_id", - provider_id="fake_provider_id", - query_is_valid=True, - query=query, - query_request=query_request, - summary=TurnSummary( - llm_response="LLM answer", - tool_calls=[ - ToolCallSummary( - id="t1", - name="knowledge_search", - args={}, - response=" ".join(SAMPLE_KNOWLEDGE_SEARCH_RESULTS), - ) - ], - rag_chunks=[ - RAGChunk( - content=" ".join(SAMPLE_KNOWLEDGE_SEARCH_RESULTS), - source="knowledge_search", - score=None, - ) - ], - ), - attachments=[], - rag_chunks=[ - { - "content": " ".join(SAMPLE_KNOWLEDGE_SEARCH_RESULTS), - "source": "knowledge_search", - "score": None, - } - ], - truncated=False, - ) - else: - 
mock_transcript.assert_not_called() - @pytest.mark.asyncio async def test_streaming_query_endpoint_handler(mocker: MockerFixture) -> None: - """Test the streaming query endpoint handler with transcript storage disabled.""" + """Test the streaming query endpoint handler.""" mock_metrics(mocker) - await _test_streaming_query_endpoint_handler(mocker, store_transcript=False) + await _test_streaming_query_endpoint_handler(mocker) @pytest.mark.asyncio async def test_streaming_query_endpoint_handler_store_transcript( mocker: MockerFixture, ) -> None: - """Test the streaming query endpoint handler with transcript storage enabled.""" + """Test the streaming query endpoint handler (backwards compatibility).""" mock_metrics(mocker) - await _test_streaming_query_endpoint_handler(mocker, store_transcript=True) + await _test_streaming_query_endpoint_handler(mocker) async def test_retrieve_response_vector_db_available( diff --git a/tests/unit/app/endpoints/test_streaming_query_v2.py b/tests/unit/app/endpoints/test_streaming_query_v2.py new file mode 100644 index 00000000..8d81c9e4 --- /dev/null +++ b/tests/unit/app/endpoints/test_streaming_query_v2.py @@ -0,0 +1,224 @@ +# pylint: disable=redefined-outer-name, import-error +"""Unit tests for the /streaming_query (v2) endpoint using Responses API.""" + +from types import SimpleNamespace +from typing import Any, AsyncIterator +import pytest +from pytest_mock import MockerFixture +from fastapi import HTTPException, status, Request +from fastapi.responses import StreamingResponse + +from llama_stack_client import APIConnectionError + +from models.requests import QueryRequest +from models.config import Action, ModelContextProtocolServer + +from app.endpoints.streaming_query_v2 import ( + retrieve_response, + streaming_query_endpoint_handler_v2, +) + + +@pytest.fixture +def dummy_request() -> Request: + """Create a dummy FastAPI Request for testing with authorized actions.""" + req = Request(scope={"type": "http"}) + # Provide a permissive authorized_actions set to satisfy RBAC check + req.state.authorized_actions = set(Action) + return req + + +@pytest.mark.asyncio +async def test_retrieve_response_builds_rag_and_mcp_tools( + mocker: MockerFixture, +) -> None: + """Test that retrieve_response correctly builds RAG and MCP tools.""" + mock_client = mocker.Mock() + mock_vector_stores = mocker.Mock() + mock_vector_stores.data = [mocker.Mock(id="db1")] + mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) + mock_client.responses.create = mocker.AsyncMock(return_value=mocker.Mock()) + + mocker.patch( + "app.endpoints.streaming_query_v2.get_system_prompt", return_value="PROMPT" + ) + + mock_cfg = mocker.Mock() + mock_cfg.mcp_servers = [ + ModelContextProtocolServer(name="fs", url="http://localhost:3000"), + ] + mocker.patch("app.endpoints.streaming_query_v2.configuration", mock_cfg) + + qr = QueryRequest(query="hello") + await retrieve_response(mock_client, "model-z", qr, token="tok") + + kwargs = mock_client.responses.create.call_args.kwargs + assert kwargs["stream"] is True + tools = kwargs["tools"] + assert isinstance(tools, list) + types = {t.get("type") for t in tools} + assert types == {"file_search", "mcp"} + + +@pytest.mark.asyncio +async def test_retrieve_response_no_tools_passes_none(mocker: MockerFixture) -> None: + """Test that retrieve_response passes None for tools when no_tools=True.""" + mock_client = mocker.Mock() + mock_vector_stores = mocker.Mock() + mock_vector_stores.data = [] + mock_client.vector_stores.list = 
mocker.AsyncMock(return_value=mock_vector_stores) + mock_client.responses.create = mocker.AsyncMock(return_value=mocker.Mock()) + + mocker.patch( + "app.endpoints.streaming_query_v2.get_system_prompt", return_value="PROMPT" + ) + mocker.patch( + "app.endpoints.streaming_query_v2.configuration", mocker.Mock(mcp_servers=[]) + ) + + qr = QueryRequest(query="hello", no_tools=True) + await retrieve_response(mock_client, "model-z", qr, token="tok") + + kwargs = mock_client.responses.create.call_args.kwargs + assert kwargs["tools"] is None + assert kwargs["stream"] is True + + +@pytest.mark.asyncio +async def test_streaming_query_endpoint_handler_v2_success_yields_events( + mocker: MockerFixture, dummy_request: Request +) -> None: + """Test that streaming_query_endpoint_handler_v2 yields correct SSE events.""" + # Skip real config checks - patch in streaming_query where the base handler is + mocker.patch("app.endpoints.streaming_query.check_configuration_loaded") + + # Model selection plumbing + mock_client = mocker.Mock() + mock_client.models.list = mocker.AsyncMock(return_value=[mocker.Mock()]) + mocker.patch( + "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client + ) + mocker.patch( + "app.endpoints.streaming_query.evaluate_model_hints", + return_value=(None, None), + ) + mocker.patch( + "app.endpoints.streaming_query.select_model_and_provider_id", + return_value=("llama/m", "m", "p"), + ) + + # Replace SSE helpers for deterministic output + mocker.patch( + "app.endpoints.streaming_query_v2.stream_start_event", + lambda conv_id: f"START:{conv_id}\n", + ) + mocker.patch( + "app.endpoints.streaming_query_v2.format_stream_data", + lambda obj: f"EV:{obj['event']}:{obj['data'].get('token','')}\n", + ) + mocker.patch( + "app.endpoints.streaming_query_v2.stream_end_event", + lambda _m, _s, _t, _media: "END\n", + ) + + # Mock the cleanup function that handles all post-streaming database/cache work + cleanup_spy = mocker.patch( + "app.endpoints.streaming_query_v2.cleanup_after_streaming", + mocker.AsyncMock(return_value=None), + ) + + # Build a fake async stream of chunks + async def fake_stream() -> AsyncIterator[SimpleNamespace]: + yield SimpleNamespace( + type="response.created", response=SimpleNamespace(id="conv-xyz") + ) + yield SimpleNamespace(type="response.content_part.added") + yield SimpleNamespace(type="response.output_text.delta", delta="Hello ") + yield SimpleNamespace(type="response.output_text.delta", delta="world") + yield SimpleNamespace( + type="response.output_item.added", + item=SimpleNamespace( + type="function_call", id="item1", name="search", call_id="call1" + ), + ) + yield SimpleNamespace( + type="response.function_call_arguments.delta", delta='{"q":"x"}' + ) + yield SimpleNamespace( + type="response.function_call_arguments.done", + item_id="item1", + arguments='{"q":"x"}', + ) + yield SimpleNamespace(type="response.output_text.done", text="Hello world") + yield SimpleNamespace(type="response.completed") + + mocker.patch( + "app.endpoints.streaming_query_v2.retrieve_response", + return_value=(fake_stream(), ""), + ) + + metric = mocker.patch("metrics.llm_calls_total") + + resp = await streaming_query_endpoint_handler_v2( + request=dummy_request, + query_request=QueryRequest(query="hi"), + auth=("user123", "", True, "token-abc"), # skip_userid_check=True + mcp_headers={}, + ) + + assert isinstance(resp, StreamingResponse) + metric.labels("p", "m").inc.assert_called_once() + + # Collect emitted events + events: list[str] = [] + async for chunk in 
resp.body_iterator: + s = chunk.decode() if isinstance(chunk, (bytes, bytearray)) else str(chunk) + events.append(s) + + # Validate event sequence and content + assert events[0] == "START:conv-xyz\n" + # content_part.added triggers empty token + assert events[1] == "EV:token:\n" + assert events[2] == "EV:token:Hello \n" + assert events[3] == "EV:token:world\n" + # tool call delta + assert events[4].startswith("EV:tool_call:") + # turn complete and end + assert "EV:turn_complete:Hello world\n" in events + assert events[-1] == "END\n" + + # Verify cleanup function was invoked after streaming + assert cleanup_spy.call_count == 1 + # Verify cleanup was called with correct user_id and conversation_id + call_args = cleanup_spy.call_args + assert call_args.kwargs["user_id"] == "user123" + assert call_args.kwargs["conversation_id"] == "conv-xyz" + assert call_args.kwargs["model_id"] == "m" + assert call_args.kwargs["provider_id"] == "p" + + +@pytest.mark.asyncio +async def test_streaming_query_endpoint_handler_v2_api_connection_error( + mocker: MockerFixture, dummy_request: Request +) -> None: + """Test that streaming_query_endpoint_handler_v2 handles API connection errors.""" + mocker.patch("app.endpoints.streaming_query.check_configuration_loaded") + + def _raise(*_a: Any, **_k: Any) -> None: + raise APIConnectionError(request=None) # type: ignore[arg-type] + + mocker.patch("client.AsyncLlamaStackClientHolder.get_client", side_effect=_raise) + + fail_metric = mocker.patch("metrics.llm_calls_failures_total") + + with pytest.raises(HTTPException) as exc: + await streaming_query_endpoint_handler_v2( + request=dummy_request, + query_request=QueryRequest(query="hi"), + auth=("user123", "", False, "tok"), + mcp_headers={}, + ) + + assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert "Unable to connect to Llama Stack" in str(exc.value.detail) + fail_metric.inc.assert_called_once() diff --git a/tests/unit/app/test_routers.py b/tests/unit/app/test_routers.py index ac7b6aeb..01264b17 100644 --- a/tests/unit/app/test_routers.py +++ b/tests/unit/app/test_routers.py @@ -20,6 +20,7 @@ config, feedback, streaming_query, + streaming_query_v2, authorized, metrics, tools, @@ -65,7 +66,7 @@ def test_include_routers() -> None: include_routers(app) # are all routers added? - assert len(app.routers) == 16 + assert len(app.routers) == 17 assert root.router in app.get_routers() assert info.router in app.get_routers() assert models.router in app.get_routers() @@ -75,6 +76,7 @@ def test_include_routers() -> None: assert query.router in app.get_routers() assert query_v2.router in app.get_routers() assert streaming_query.router in app.get_routers() + assert streaming_query_v2.router in app.get_routers() assert config.router in app.get_routers() assert feedback.router in app.get_routers() assert health.router in app.get_routers() @@ -90,7 +92,7 @@ def test_check_prefixes() -> None: include_routers(app) # are all routers added? 
- assert len(app.routers) == 16 + assert len(app.routers) == 17 assert app.get_router_prefix(root.router) == "" assert app.get_router_prefix(info.router) == "/v1" assert app.get_router_prefix(models.router) == "/v1" @@ -100,6 +102,7 @@ def test_check_prefixes() -> None: assert app.get_router_prefix(query.router) == "/v1" assert app.get_router_prefix(streaming_query.router) == "/v1" assert app.get_router_prefix(query_v2.router) == "/v2" + assert app.get_router_prefix(streaming_query_v2.router) == "/v2" assert app.get_router_prefix(config.router) == "/v1" assert app.get_router_prefix(feedback.router) == "/v1" assert app.get_router_prefix(health.router) == ""
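
Note on the MCP header handling introduced above: `get_mcp_tools` now applies token-based auth first and then merges per-server headers on top, so a per-request header keyed by the server URL can override the `Authorization` value. The standalone sketch below mirrors only that merge order; `merge_mcp_headers` is a hypothetical helper used for illustration and is not part of the patch (the real function also builds the full `"mcp"` tool definition with `server_label`, `server_url`, and `require_approval`).

```python
# Illustrative sketch of the header-merge order used by get_mcp_tools();
# merge_mcp_headers is a hypothetical helper, not part of this patch.

def merge_mcp_headers(
    server_url: str,
    token: str | None,
    mcp_headers: dict[str, dict[str, str]] | None,
) -> dict[str, str] | None:
    """Build the headers attached to a single MCP tool definition."""
    if not (token or mcp_headers):
        return None  # no auth information at all -> no "headers" key is set
    headers: dict[str, str] = {}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    # Per-server headers are merged last, so they win over the token auth.
    server_headers = (mcp_headers or {}).get(server_url)
    if server_headers:
        headers.update(server_headers)
    return headers


# Mirrors the expectations in test_get_mcp_tools_with_mcp_headers:
assert merge_mcp_headers("http://localhost:3000", "abc", None) == {
    "Authorization": "Bearer abc"
}
assert merge_mcp_headers(
    "http://localhost:3000",
    "abc",
    {"http://localhost:3000": {"Authorization": "Custom auth"}},
) == {"Authorization": "Custom auth"}
```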
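
The v2 streaming generator translates Responses API chunk types into the same SSE event names the Agent API path already emits (`start`, `token`, `tool_call`, `turn_complete`, `end`), each framed as a `data: {...}` line as shown in `streaming_query_v2_responses`. A minimal client-side sketch that consumes such a stream is below; the base URL and the use of the `requests` library are assumptions for illustration only.

```python
# Client-side sketch: consuming the /v2/streaming_query SSE stream.
# The base URL and the use of `requests` are assumptions, not part of the patch.
import json

import requests

payload = {"query": "What is OpenStack?"}
with requests.post(
    "http://localhost:8080/v2/streaming_query", json=payload, stream=True
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip blank separator lines between SSE events
        event = json.loads(line[len("data: "):])
        if event["event"] == "start":
            # Conversation id arrives in-band from the response.created chunk.
            conversation_id = event["data"]["conversation_id"]
        elif event["event"] == "token":
            print(event["data"]["token"], end="", flush=True)
        elif event["event"] == "end":
            # Referenced documents and token counts live in event["data"].
            print()
```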