|
1 | 1 | """Handler for REST API call to provide answer to query.""" |
2 | 2 |
|
| 3 | +import ast |
3 | 4 | from datetime import datetime, UTC |
4 | 5 | import json |
5 | 6 | import logging |
6 | 7 | import os |
7 | 8 | from pathlib import Path |
8 | | -from typing import Annotated, Any |
| 9 | +import re |
| 10 | +from typing import Annotated, Any, cast |
9 | 11 |
|
10 | 12 | from llama_stack_client import APIConnectionError |
11 | 13 | from llama_stack_client import AsyncLlamaStackClient # type: ignore |
|
41 | 43 | router = APIRouter(tags=["query"]) |
42 | 44 | auth_dependency = get_auth_dependency() |
43 | 45 |
|
# Matches "Metadata: {...}" markers embedded in knowledge-search text chunks;
# the capture group is the dict literal itself (single line, since "." does
# not cross newlines).
METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n")


def _process_knowledge_search_content(
    tool_response: Any, metadata_map: dict[str, dict[str, Any]]
) -> None:
    """Process knowledge search tool response content for metadata.

    Scans every text chunk of *tool_response* for embedded metadata dict
    literals and records each one that carries a "document_id" into
    *metadata_map*, keyed by that id. Malformed entries are logged at
    debug level and skipped.
    """
    text_items = (item for item in tool_response.content if hasattr(item, "text"))
    for item in text_items:
        for raw in METADATA_PATTERN.findall(item.text):
            try:
                # literal_eval safely parses the captured dict literal;
                # best-effort: any malformed entry is skipped, not fatal.
                parsed = ast.literal_eval(raw)
                if "document_id" in parsed:
                    metadata_map[parsed["document_id"]] = parsed
            except Exception:  # pylint: disable=broad-except
                logger.debug(
                    "An exception was thrown in processing %s",
                    raw,
                )
| 67 | + |
| 68 | + |
def extract_referenced_documents_from_steps(steps: list) -> list[dict[str, str]]:
    """Extract referenced documents from tool execution steps.

    Args:
        steps: List of response steps from the agent

    Returns:
        List of referenced documents with doc_url and doc_title
    """
    docs_by_id: dict[str, dict[str, Any]] = {}

    # Only tool-execution steps that actually carry tool responses matter.
    relevant_steps = (
        step
        for step in steps
        if step.step_type == "tool_execution" and hasattr(step, "tool_responses")
    )
    for step in relevant_steps:
        for response in step.tool_responses:
            if response.tool_name == "knowledge_search" and response.content:
                _process_knowledge_search_content(response, docs_by_id)

    # Keep only entries that have both a URL and a title to report.
    referenced: list[dict[str, str]] = []
    for meta in docs_by_id.values():
        if "docs_url" in meta and "title" in meta:
            referenced.append(
                {
                    "doc_url": meta["docs_url"],
                    "doc_title": meta["title"],
                }
            )
    return referenced
| 104 | + |
| 105 | + |
44 | 106 | query_response: dict[int | str, dict[str, Any]] = { |
45 | 107 | 200: { |
46 | 108 | "conversation_id": "123e4567-e89b-12d3-a456-426614174000", |
47 | 109 | "response": "LLM answer", |
| 110 | + "referenced_documents": [ |
| 111 | + { |
| 112 | + "doc_url": ( |
| 113 | + "https://docs.openshift.com/container-platform/" |
| 114 | + "4.15/operators/olm/index.html" |
| 115 | + ), |
| 116 | + "doc_title": "Operator Lifecycle Manager (OLM)", |
| 117 | + } |
| 118 | + ], |
48 | 119 | }, |
49 | 120 | 400: { |
50 | 121 | "description": "Missing or invalid credentials provided by client", |
@@ -189,7 +260,7 @@ async def query_endpoint_handler( |
189 | 260 | user_conversation=user_conversation, query_request=query_request |
190 | 261 | ), |
191 | 262 | ) |
192 | | - response, conversation_id = await retrieve_response( |
| 263 | + response, conversation_id, referenced_documents = await retrieve_response( |
193 | 264 | client, |
194 | 265 | llama_stack_model_id, |
195 | 266 | query_request, |
@@ -223,7 +294,11 @@ async def query_endpoint_handler( |
223 | 294 | provider_id=provider_id, |
224 | 295 | ) |
225 | 296 |
|
226 | | - return QueryResponse(conversation_id=conversation_id, response=response) |
| 297 | + return QueryResponse( |
| 298 | + conversation_id=conversation_id, |
| 299 | + response=response, |
| 300 | + referenced_documents=referenced_documents, |
| 301 | + ) |
227 | 302 |
|
228 | 303 | # connection to Llama Stack server |
229 | 304 | except APIConnectionError as e: |
@@ -322,7 +397,7 @@ async def retrieve_response( # pylint: disable=too-many-locals |
322 | 397 | query_request: QueryRequest, |
323 | 398 | token: str, |
324 | 399 | mcp_headers: dict[str, dict[str, str]] | None = None, |
325 | | -) -> tuple[str, str]: |
| 400 | +) -> tuple[str, str, list[dict[str, str]]]: |
326 | 401 | """Retrieve response from LLMs and agents.""" |
327 | 402 | available_input_shields = [ |
328 | 403 | shield.identifier |
@@ -402,15 +477,24 @@ async def retrieve_response( # pylint: disable=too-many-locals |
402 | 477 | toolgroups=toolgroups, |
403 | 478 | ) |
404 | 479 |
|
405 | | - # Check for validation errors in the response |
| 480 | + # Check for validation errors and extract referenced documents |
406 | 481 | steps = getattr(response, "steps", []) |
407 | 482 | for step in steps: |
408 | 483 | if step.step_type == "shield_call" and step.violation: |
409 | 484 | # Metric for LLM validation errors |
410 | 485 | metrics.llm_calls_validation_errors_total.inc() |
411 | 486 | break |
412 | 487 |
|
413 | | - return str(response.output_message.content), conversation_id # type: ignore[union-attr] |
| 488 | + # Extract referenced documents from tool execution steps |
| 489 | + referenced_documents = extract_referenced_documents_from_steps(steps) |
| 490 | + |
| 491 | + # When stream=False, response should have output_message attribute |
| 492 | + response_obj = cast(Any, response) |
| 493 | + return ( |
| 494 | + str(response_obj.output_message.content), |
| 495 | + conversation_id, |
| 496 | + referenced_documents, |
| 497 | + ) |
414 | 498 |
|
415 | 499 |
|
416 | 500 | def validate_attachments_metadata(attachments: list[Attachment]) -> None: |
|
0 commit comments