56 changes: 45 additions & 11 deletions src/app/endpoints/query.py
@@ -214,26 +214,33 @@ async def get_topic_summary(
 )
 
 
-@router.post("/query", responses=query_response)
-@authorize(Action.QUERY)
-async def query_endpoint_handler(  # pylint: disable=R0914
+async def query_endpoint_handler_base(  # pylint: disable=R0914
     request: Request,
     query_request: QueryRequest,
     auth: Annotated[AuthTuple, Depends(get_auth_dependency())],
-    mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
+    mcp_headers: dict[str, dict[str, str]],
+    retrieve_response_func: Any,
+    get_topic_summary_func: Any,
Comment on lines +222 to +223

🛠️ Refactor suggestion | 🟠 Major

Use explicit Callable types instead of Any for function parameters.

The retrieve_response_func and get_topic_summary_func parameters are typed as Any, which violates the coding guideline requiring complete type annotations for all function parameters. This reduces type safety and makes the dependency injection pattern less clear.

As per coding guidelines

Apply this diff to add proper type hints:

+from typing import Annotated, Any, Awaitable, Callable, Optional, cast
+
 async def query_endpoint_handler_base(  # pylint: disable=R0914
     request: Request,
     query_request: QueryRequest,
     auth: Annotated[AuthTuple, Depends(get_auth_dependency())],
     mcp_headers: dict[str, dict[str, str]],
-    retrieve_response_func: Any,
-    get_topic_summary_func: Any,
+    retrieve_response_func: Callable[
+        [AsyncLlamaStackClient, str, QueryRequest, str, dict[str, dict[str, str]] | None, str],
+        Awaitable[tuple[TurnSummary, str, list[ReferencedDocument], TokenCounter]],
+    ],
+    get_topic_summary_func: Callable[[str, AsyncLlamaStackClient, str], Awaitable[str]],
 ) -> QueryResponse:

Note: The retrieve_response_func signature matches the actual function signature, including the positional provider_id parameter, and both callables are awaited, hence the Awaitable return types.

🤖 Prompt for AI Agents
In src/app/endpoints/query.py around lines 222-223, replace the Any types for
retrieve_response_func and get_topic_summary_func with explicit typing: import
Callable and Awaitable (and any specific model/response types used) and annotate
retrieve_response_func as a Callable whose first positional parameter is
provider_id (match the actual function parameter order and types) and whose
return is the appropriate Awaitable/return type; likewise annotate
get_topic_summary_func with its exact parameter types and return type. Ensure
the signatures exactly mirror the real functions (positional provider_id first
for retrieve_response_func), update imports accordingly, and run type checks.
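
If the inline Callable annotation grows unwieldy, a typing.Protocol with an async __call__ is a common alternative for typing injected async functions. A minimal sketch, not taken from this PR — the parameter names after query_request and the project-local types' availability are assumptions:

from typing import Protocol

# Real package: the async client class exported by llama_stack_client.
from llama_stack_client import AsyncLlamaStackClient


class RetrieveResponseFunc(Protocol):
    """Structural type for an injected retrieve_response implementation."""

    async def __call__(
        self,
        client: AsyncLlamaStackClient,
        model_id: str,
        query_request: "QueryRequest",  # project-local type, import path assumed
        token: str,  # hypothetical parameter name
        mcp_headers: dict[str, dict[str, str]] | None = None,
        provider_id: str = "",  # positional in the real function, per the note above
    ) -> tuple["TurnSummary", str, list["ReferencedDocument"], "TokenCounter"]: ...

Both the Agent API and Responses API implementations would then satisfy retrieve_response_func: RetrieveResponseFunc structurally, and the Awaitable return type falls out of the async def rather than being spelled inside a nested Callable[...] annotation.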

 ) -> QueryResponse:
     """
-    Handle request to the /query endpoint.
+    Handle query endpoints (shared by Agent API and Responses API).
 
-    Processes a POST request to the /query endpoint, forwarding the
-    user's query to a selected Llama Stack LLM or agent and
-    returning the generated response.
+    Processes a POST request to a query endpoint, forwarding the
+    user's query to a selected Llama Stack LLM and returning the generated response.
 
     Validates configuration and authentication, selects the appropriate model
     and provider, retrieves the LLM response, updates metrics, and optionally
     stores a transcript of the interaction. Handles connection errors to the
     Llama Stack service by returning an HTTP 500 error.
 
+    Args:
+        request: The FastAPI request object
+        query_request: The query request containing the user's question
+        auth: Authentication tuple from dependency
+        mcp_headers: MCP headers from dependency
+        retrieve_response_func: The retrieve_response function to use (Agent or Responses API)
+        get_topic_summary_func: The get_topic_summary function to use (Agent or Responses API)
+
     Returns:
         QueryResponse: Contains the conversation ID and the LLM-generated response.
     """
@@ -288,7 +295,7 @@ async def query_endpoint_handler(  # pylint: disable=R0914
         ),
     )
     summary, conversation_id, referenced_documents, token_usage = (
-        await retrieve_response(
+        await retrieve_response_func(
             client,
             llama_stack_model_id,
             query_request,
@@ -305,8 +312,8 @@ async def query_endpoint_handler(  # pylint: disable=R0914
             session.query(UserConversation).filter_by(id=conversation_id).first()
         )
         if not existing_conversation:
-            topic_summary = await get_topic_summary(
-                query_request.query, client, model_id
+            topic_summary = await get_topic_summary_func(
+                query_request.query, client, llama_stack_model_id
             )
         # Convert RAG chunks to dictionary format once for reuse
         logger.info("Processing RAG chunks...")
@@ -416,6 +423,33 @@ async def query_endpoint_handler(  # pylint: disable=R0914
         ) from e
 
 
+@router.post("/query", responses=query_response)
+@authorize(Action.QUERY)
+async def query_endpoint_handler(
+    request: Request,
+    query_request: QueryRequest,
+    auth: Annotated[AuthTuple, Depends(get_auth_dependency())],
+    mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
+) -> QueryResponse:
+    """
+    Handle request to the /query endpoint using Agent API.
+
+    This is a wrapper around query_endpoint_handler_base that provides
+    the Agent API specific retrieve_response and get_topic_summary functions.
+
+    Returns:
+        QueryResponse: Contains the conversation ID and the LLM-generated response.
+    """
+    return await query_endpoint_handler_base(
+        request=request,
+        query_request=query_request,
+        auth=auth,
+        mcp_headers=mcp_headers,
+        retrieve_response_func=retrieve_response,
+        get_topic_summary_func=get_topic_summary,
+    )
+
+
 def select_model_and_provider_id(
     models: ModelListResponse, model_id: str | None, provider_id: str | None
 ) -> tuple[str, str, str]:
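
Since query_endpoint_handler_base is shared by the Agent API and the Responses API, the Responses API side would presumably get its own thin wrapper injecting its implementations the same way. A hedged sketch — the route path, handler name, and the *_responses_api function names are hypothetical, not taken from this PR:

@router.post("/v2/query", responses=query_response)  # hypothetical route
@authorize(Action.QUERY)
async def query_v2_endpoint_handler(
    request: Request,
    query_request: QueryRequest,
    auth: Annotated[AuthTuple, Depends(get_auth_dependency())],
    mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
) -> QueryResponse:
    """Handle a query via the Responses API by delegating to the shared base."""
    return await query_endpoint_handler_base(
        request=request,
        query_request=query_request,
        auth=auth,
        mcp_headers=mcp_headers,
        retrieve_response_func=retrieve_response_responses_api,  # hypothetical name
        get_topic_summary_func=get_topic_summary_responses_api,  # hypothetical name
    )

Either way, the decorators stay on the thin wrappers, so routing and authorization remain per-endpoint while the shared body stays undecorated.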