Skip to content

Commit f022501

Browse files
committed
Implements referenced documents on /query
1 parent 2cc494c commit f022501

File tree

3 files changed

+289
-40
lines changed

3 files changed

+289
-40
lines changed

src/app/endpoints/query.py

Lines changed: 90 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""Handler for REST API call to provide answer to query."""
22

3+
import ast
34
from datetime import datetime, UTC
45
import json
56
import logging
67
import os
78
from pathlib import Path
8-
from typing import Annotated, Any
9+
import re
10+
from typing import Annotated, Any, cast
911

1012
from llama_stack_client import APIConnectionError
1113
from llama_stack_client import AsyncLlamaStackClient # type: ignore
@@ -41,10 +43,79 @@
4143
router = APIRouter(tags=["query"])
4244
auth_dependency = get_auth_dependency()
4345

46+
METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n")
47+
48+
49+
def _process_knowledge_search_content(
50+
tool_response: Any, metadata_map: dict[str, dict[str, Any]]
51+
) -> None:
52+
"""Process knowledge search tool response content for metadata."""
53+
for text_content_item in tool_response.content:
54+
if not hasattr(text_content_item, "text"):
55+
continue
56+
57+
for match in METADATA_PATTERN.findall(text_content_item.text):
58+
try:
59+
meta = ast.literal_eval(match)
60+
if "document_id" in meta:
61+
metadata_map[meta["document_id"]] = meta
62+
except Exception: # pylint: disable=broad-except
63+
logger.debug(
64+
"An exception was thrown in processing %s",
65+
match,
66+
)
67+
68+
69+
def extract_referenced_documents_from_steps(steps: list) -> list[dict[str, str]]:
    """Extract referenced documents from tool execution steps.

    Args:
        steps: List of response steps from the agent

    Returns:
        List of referenced documents with doc_url and doc_title
    """
    collected: dict[str, dict[str, Any]] = {}

    # Only tool-execution steps that actually expose tool responses matter.
    relevant_steps = (
        step
        for step in steps
        if step.step_type == "tool_execution" and hasattr(step, "tool_responses")
    )

    for step in relevant_steps:
        for response in step.tool_responses:
            if response.tool_name == "knowledge_search" and response.content:
                _process_knowledge_search_content(response, collected)

    # Keep only entries that carry both a URL and a title.
    return [
        {
            "doc_url": meta["docs_url"],
            "doc_title": meta["title"],
        }
        for meta in collected.values()
        if "docs_url" in meta and "title" in meta
    ]
104+
105+
44106
query_response: dict[int | str, dict[str, Any]] = {
45107
200: {
46108
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
47109
"response": "LLM answer",
110+
"referenced_documents": [
111+
{
112+
"doc_url": (
113+
"https://docs.openshift.com/container-platform/"
114+
"4.15/operators/olm/index.html"
115+
),
116+
"doc_title": "Operator Lifecycle Manager (OLM)",
117+
}
118+
],
48119
},
49120
400: {
50121
"description": "Missing or invalid credentials provided by client",
@@ -189,7 +260,7 @@ async def query_endpoint_handler(
189260
user_conversation=user_conversation, query_request=query_request
190261
),
191262
)
192-
response, conversation_id = await retrieve_response(
263+
response, conversation_id, referenced_documents = await retrieve_response(
193264
client,
194265
llama_stack_model_id,
195266
query_request,
@@ -223,7 +294,11 @@ async def query_endpoint_handler(
223294
provider_id=provider_id,
224295
)
225296

226-
return QueryResponse(conversation_id=conversation_id, response=response)
297+
return QueryResponse(
298+
conversation_id=conversation_id,
299+
response=response,
300+
referenced_documents=referenced_documents,
301+
)
227302

228303
# connection to Llama Stack server
229304
except APIConnectionError as e:
@@ -322,7 +397,7 @@ async def retrieve_response( # pylint: disable=too-many-locals
322397
query_request: QueryRequest,
323398
token: str,
324399
mcp_headers: dict[str, dict[str, str]] | None = None,
325-
) -> tuple[str, str]:
400+
) -> tuple[str, str, list[dict[str, str]]]:
326401
"""Retrieve response from LLMs and agents."""
327402
available_input_shields = [
328403
shield.identifier
@@ -402,15 +477,24 @@ async def retrieve_response( # pylint: disable=too-many-locals
402477
toolgroups=toolgroups,
403478
)
404479

405-
# Check for validation errors in the response
480+
# Check for validation errors and extract referenced documents
406481
steps = getattr(response, "steps", [])
407482
for step in steps:
408483
if step.step_type == "shield_call" and step.violation:
409484
# Metric for LLM validation errors
410485
metrics.llm_calls_validation_errors_total.inc()
411486
break
412487

413-
return str(response.output_message.content), conversation_id # type: ignore[union-attr]
488+
# Extract referenced documents from tool execution steps
489+
referenced_documents = extract_referenced_documents_from_steps(steps)
490+
491+
# When stream=False, response should have output_message attribute
492+
response_obj = cast(Any, response)
493+
return (
494+
str(response_obj.output_message.content),
495+
conversation_id,
496+
referenced_documents,
497+
)
414498

415499

416500
def validate_attachments_metadata(attachments: list[Attachment]) -> None:

src/models/responses.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ class ModelsResponse(BaseModel):
3636

3737
# TODO(lucasagomes): a lot of fields to add to QueryResponse. For now
3838
# we are keeping it simple. The missing fields are:
39-
# - referenced_documents: The optional URLs and titles for the documents used
40-
# to generate the response.
4139
# - truncated: Set to True if conversation history was truncated to be within context window.
4240
# - input_tokens: Number of tokens sent to LLM
4341
# - output_tokens: Number of tokens received from LLM
@@ -51,6 +49,8 @@ class QueryResponse(BaseModel):
5149
Attributes:
5250
conversation_id: The optional conversation ID (UUID).
5351
response: The response.
52+
referenced_documents: The optional URLs and titles for the documents used
53+
to generate the response.
5454
"""
5555

5656
conversation_id: Optional[str] = Field(
@@ -66,13 +66,38 @@ class QueryResponse(BaseModel):
6666
],
6767
)
6868

69+
referenced_documents: list[dict[str, str]] = Field(
70+
default_factory=list,
71+
description="List of documents referenced in generating the response",
72+
examples=[
73+
[
74+
{
75+
"doc_url": (
76+
"https://docs.openshift.com/container-platform/"
77+
"4.15/operators/olm/index.html"
78+
),
79+
"doc_title": "Operator Lifecycle Manager (OLM)",
80+
}
81+
]
82+
],
83+
)
84+
6985
# provides examples for /docs endpoint
7086
model_config = {
7187
"json_schema_extra": {
7288
"examples": [
7389
{
7490
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
7591
"response": "Operator Lifecycle Manager (OLM) helps users install...",
92+
"referenced_documents": [
93+
{
94+
"doc_url": (
95+
"https://docs.openshift.com/container-platform/"
96+
"4.15/operators/olm/index.html"
97+
),
98+
"doc_title": "Operator Lifecycle Manager (OLM)",
99+
}
100+
],
76101
}
77102
]
78103
}

0 commit comments

Comments
 (0)