diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index d53d3e83..e9df9dd6 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -7,6 +7,7 @@ from pathlib import Path from typing import Annotated, Any + from llama_stack_client import APIConnectionError from llama_stack_client import AsyncLlamaStackClient # type: ignore from llama_stack_client.types import UserMessage, Shield # type: ignore @@ -25,7 +26,12 @@ from app.database import get_session import metrics from models.database.conversations import UserConversation -from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse +from models.responses import ( + QueryResponse, + UnauthorizedResponse, + ForbiddenResponse, + ReferencedDocument, +) from models.requests import QueryRequest, Attachment import constants from utils.endpoints import ( @@ -36,15 +42,28 @@ ) from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups from utils.suid import get_suid +from utils.metadata import ( + extract_referenced_documents_from_steps, +) logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["query"]) auth_dependency = get_auth_dependency() + query_response: dict[int | str, dict[str, Any]] = { 200: { "conversation_id": "123e4567-e89b-12d3-a456-426614174000", "response": "LLM answer", + "referenced_documents": [ + { + "doc_url": ( + "https://docs.openshift.com/container-platform/" + "4.15/operators/olm/index.html" + ), + "doc_title": "Operator Lifecycle Manager (OLM)", + } + ], }, 400: { "description": "Missing or invalid credentials provided by client", @@ -54,7 +73,7 @@ "description": "User is not authorized", "model": ForbiddenResponse, }, - 503: { + 500: { "detail": { "response": "Unable to connect to Llama Stack", "cause": "Connection error.", @@ -203,7 +222,7 @@ async def query_endpoint_handler( user_conversation=user_conversation, query_request=query_request ), ) - response, conversation_id = await retrieve_response( + response, conversation_id, referenced_documents = await retrieve_response( client, llama_stack_model_id, query_request, @@ -237,7 +256,11 @@ async def query_endpoint_handler( provider_id=provider_id, ) - return QueryResponse(conversation_id=conversation_id, response=response) + return QueryResponse( + conversation_id=conversation_id, + response=response, + referenced_documents=referenced_documents, + ) # connection to Llama Stack server except APIConnectionError as e: @@ -381,7 +404,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche query_request: QueryRequest, token: str, mcp_headers: dict[str, dict[str, str]] | None = None, -) -> tuple[str, str]: +) -> tuple[str, str, list[ReferencedDocument]]: """ Retrieve response from LLMs and agents. @@ -404,8 +427,9 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche mcp_headers (dict[str, dict[str, str]], optional): Headers for multi-component processing. Returns: - tuple[str, str]: A tuple containing the LLM or agent's response content - and the conversation ID. + tuple[str, str, list[ReferencedDocument]]: A tuple containing the response + content, the conversation ID, and the list of referenced documents parsed + from tool execution steps. 
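+
+    Example (illustrative; assumes a connected client and a resolved model id):
+
+        response_text, conversation_id, referenced_documents = (
+            await retrieve_response(client, model_id, query_request, token)
+        )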
""" available_input_shields = [ shield.identifier @@ -485,26 +509,39 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche toolgroups=toolgroups, ) - # Check for validation errors in the response + # Check for validation errors and extract referenced documents steps = getattr(response, "steps", []) for step in steps: - if step.step_type == "shield_call" and step.violation: + if getattr(step, "step_type", "") == "shield_call" and getattr( + step, "violation", False + ): # Metric for LLM validation errors metrics.llm_calls_validation_errors_total.inc() - break + # Extract referenced documents from tool execution steps + referenced_documents = extract_referenced_documents_from_steps(steps) + + # When stream=False, response should have output_message attribute output_message = getattr(response, "output_message", None) if output_message is not None: content = getattr(output_message, "content", None) if content is not None: - return str(content), conversation_id + response_text = str(content) + else: + response_text = "" + else: + # fallback + logger.warning( + "Response lacks output_message.content (conversation_id=%s)", + conversation_id, + ) + response_text = "" - # fallback - logger.warning( - "Response lacks output_message.content (conversation_id=%s)", + return ( + response_text, conversation_id, + referenced_documents, ) - return "", conversation_id def validate_attachments_metadata(attachments: list[Attachment]) -> None: diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 329a7230..bf6c0137 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -1,11 +1,11 @@ """Handler for REST API call to provide answer to streaming query.""" -import ast import json -import re import logging from typing import Annotated, Any, AsyncIterator, Iterator +import pydantic + from llama_stack_client import APIConnectionError from llama_stack_client import AsyncLlamaStackClient # type: ignore from llama_stack_client.types import UserMessage # type: ignore @@ -24,8 +24,10 @@ import metrics from models.requests import QueryRequest from models.database.conversations import UserConversation +from models.responses import ReferencedDocument from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups +from utils.metadata import parse_knowledge_search_metadata from app.endpoints.query import ( get_rag_toolgroups, @@ -45,9 +47,6 @@ auth_dependency = get_auth_dependency() -METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n") - - def format_stream_data(d: dict) -> str: """ Format a dictionary as a Server-Sent Events (SSE) data string. @@ -102,20 +101,36 @@ def stream_end_event(metadata_map: dict) -> str: str: A Server-Sent Events (SSE) formatted string representing the end of the data stream. 
""" + # Create ReferencedDocument objects and convert them to serializable dict format + referenced_documents = [] + for v in filter( + lambda v: ("docs_url" in v) and ("title" in v), + metadata_map.values(), + ): + try: + doc = ReferencedDocument(doc_url=v["docs_url"], doc_title=v["title"]) + referenced_documents.append( + { + "doc_url": str( + doc.doc_url + ), # Convert AnyUrl to string for JSON serialization + "doc_title": doc.doc_title, + } + ) + except (pydantic.ValidationError, ValueError) as e: + logger.warning( + "Skipping invalid referenced document with docs_url='%s', title='%s': %s", + v.get("docs_url", ""), + v.get("title", ""), + str(e), + ) + continue + return format_stream_data( { "event": "end", "data": { - "referenced_documents": [ - { - "doc_url": v["docs_url"], - "doc_title": v["title"], - } - for v in filter( - lambda v: ("docs_url" in v) and ("title" in v), - metadata_map.values(), - ) - ], + "referenced_documents": referenced_documents, "truncated": None, # TODO(jboos): implement truncated "input_tokens": 0, # TODO(jboos): implement input tokens "output_tokens": 0, # TODO(jboos): implement output tokens @@ -435,16 +450,16 @@ def _handle_tool_execution_event( newline_pos = summary.find("\n") if newline_pos > 0: summary = summary[:newline_pos] - for match in METADATA_PATTERN.findall(text_content_item.text): - try: - meta = ast.literal_eval(match) - if "document_id" in meta: - metadata_map[meta["document_id"]] = meta - except Exception: # pylint: disable=broad-except - logger.debug( - "An exception was thrown in processing %s", - match, - ) + try: + parsed_metadata = parse_knowledge_search_metadata( + text_content_item.text, strict=False + ) + metadata_map.update(parsed_metadata) + except ValueError as e: + logger.exception( + "Error processing metadata from text; position=%s", + getattr(e, "position", "unknown"), + ) yield format_stream_data( { diff --git a/src/models/responses.py b/src/models/responses.py index cb8ee09c..9cf0c0d3 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -2,7 +2,7 @@ from typing import Any, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, AnyUrl class ModelsResponse(BaseModel): @@ -36,8 +36,6 @@ class ModelsResponse(BaseModel): # TODO(lucasagomes): a lot of fields to add to QueryResponse. For now # we are keeping it simple. The missing fields are: -# - referenced_documents: The optional URLs and titles for the documents used -# to generate the response. # - truncated: Set to True if conversation history was truncated to be within context window. # - input_tokens: Number of tokens sent to LLM # - output_tokens: Number of tokens received from LLM @@ -45,12 +43,23 @@ class ModelsResponse(BaseModel): # - tool_calls: List of tool requests. # - tool_results: List of tool results. # See LLMResponse in ols-service for more details. + + +class ReferencedDocument(BaseModel): + """Model representing a document referenced in generating a response.""" + + doc_url: AnyUrl = Field(description="URL of the referenced document") + doc_title: str = Field(description="Title of the referenced document") + + class QueryResponse(BaseModel): """Model representing LLM response to a query. Attributes: conversation_id: The optional conversation ID (UUID). response: The response. + referenced_documents: The optional URLs and titles for the documents used + to generate the response. 
""" conversation_id: Optional[str] = Field( @@ -66,6 +75,22 @@ class QueryResponse(BaseModel): ], ) + referenced_documents: list[ReferencedDocument] = Field( + default_factory=list, + description="List of documents referenced in generating the response", + examples=[ + [ + { + "doc_url": ( + "https://docs.openshift.com/container-platform/" + "4.15/operators/olm/index.html" + ), + "doc_title": "Operator Lifecycle Manager (OLM)", + } + ] + ], + ) + # provides examples for /docs endpoint model_config = { "json_schema_extra": { @@ -73,6 +98,15 @@ class QueryResponse(BaseModel): { "conversation_id": "123e4567-e89b-12d3-a456-426614174000", "response": "Operator Lifecycle Manager (OLM) helps users install...", + "referenced_documents": [ + { + "doc_url": ( + "https://docs.openshift.com/container-platform/" + "4.15/operators/olm/index.html" + ), + "doc_title": "Operator Lifecycle Manager (OLM)", + } + ], } ] } diff --git a/src/utils/metadata.py b/src/utils/metadata.py new file mode 100644 index 00000000..991946ff --- /dev/null +++ b/src/utils/metadata.py @@ -0,0 +1,212 @@ +"""Shared utilities for parsing metadata from knowledge search responses.""" + +import ast +import json +import logging +import re +from typing import Any + +import pydantic + +from models.responses import ReferencedDocument + +logger = logging.getLogger(__name__) + + +# Case-insensitive pattern to find "Metadata:" labels +METADATA_LABEL_PATTERN = re.compile(r"^\s*metadata:\s*", re.MULTILINE | re.IGNORECASE) + + +def _extract_balanced_braces(text: str, start_pos: int) -> str: + """Extract a balanced brace substring starting from start_pos. + + Args: + text: The text to search + start_pos: Position where the opening brace should be + + Returns: + The balanced brace substring including the braces + + Raises: + ValueError: If no balanced braces are found + """ + if start_pos >= len(text) or text[start_pos] != "{": + raise ValueError("No opening brace found at start position") + + brace_count = 0 + pos = start_pos + + while pos < len(text): + char = text[pos] + if char == "{": + brace_count += 1 + elif char == "}": + brace_count -= 1 + if brace_count == 0: + return text[start_pos : pos + 1] + pos += 1 + + raise ValueError("Unmatched opening brace - no closing brace found") + + +def parse_knowledge_search_metadata( + text: str, *, strict: bool = True +) -> dict[str, dict[str, Any]]: + """Parse metadata from knowledge search text content. + + Args: + text: Text content that may contain metadata patterns + strict: If True (default), raise ValueError on first parsing error. + If False, skip invalid blocks and continue parsing. 
+
+    Returns:
+        Dictionary of document_id -> metadata mappings
+
+    Raises:
+        ValueError: If metadata parsing fails due to invalid Python-literal or JSON-like syntax
+            (only in strict mode)
+    """
+    metadata_map: dict[str, dict[str, Any]] = {}
+
+    # Find all "Metadata:" labels (case-insensitive)
+    for match in METADATA_LABEL_PATTERN.finditer(text):
+        try:
+            # Find the position right after the "Metadata:" label
+            label_end = match.end()
+
+            # Skip any whitespace after the label
+            pos = label_end
+            while pos < len(text) and text[pos].isspace():
+                pos += 1
+
+            # Look for opening brace
+            if pos >= len(text) or text[pos] != "{":
+                continue  # No brace found, skip this match
+
+            # Extract balanced brace content
+            brace_content = _extract_balanced_braces(text, pos)
+
+            # Parse the extracted content
+            meta = ast.literal_eval(brace_content)
+
+            # Verify the result is a dict before accessing keys
+            if isinstance(meta, dict) and "document_id" in meta:
+                metadata_map[meta["document_id"]] = meta
+
+        except (SyntaxError, ValueError) as e:
+            if strict:
+                raise ValueError(
+                    f"Failed to parse metadata at position {match.start()}: {e}"
+                ) from e
+            # non-strict mode: skip bad blocks, keep the rest
+            continue
+
+    return metadata_map
+
+
+def process_knowledge_search_content(tool_response: Any) -> dict[str, dict[str, Any]]:
+    """Process knowledge search tool response content for metadata.
+
+    Args:
+        tool_response: Tool response object containing content to parse
+
+    Returns:
+        Dictionary mapping document_id to metadata dict
+    """
+    metadata_map: dict[str, dict[str, Any]] = {}
+
+    # Guard against missing tool_response or content
+    if not tool_response:
+        return metadata_map
+
+    content = getattr(tool_response, "content", None)
+    if not content:
+        return metadata_map
+
+    # Handle string content by attempting JSON parsing
+    if isinstance(content, str):
+        try:
+            content = json.loads(content, strict=False)
+        except (json.JSONDecodeError, TypeError):
+            # If JSON parsing fails, try parsing as metadata text
+            try:
+                parsed_metadata = parse_knowledge_search_metadata(content, strict=False)
+                metadata_map.update(parsed_metadata)
+            except ValueError as e:
+                # Defensive guard: strict=False skips bad blocks internally
+                logger.exception(
+                    "Error processing string content as metadata: %s", e
+                )
+            return metadata_map
+
+    # Ensure content is iterable (but not a string)
+    if isinstance(content, str):
+        return metadata_map
+    try:
+        iter(content)
+    except TypeError:
+        return metadata_map
+
+    for text_content_item in content:
+        # Skip items that lack a non-empty "text" attribute
+        text = getattr(text_content_item, "text", None)
+        if not text:
+            continue
+
+        try:
+            parsed_metadata = parse_knowledge_search_metadata(text, strict=False)
+            metadata_map.update(parsed_metadata)
+        except ValueError as e:
+            # Defensive guard: strict=False skips bad blocks internally
+            logger.exception("Error processing metadata from text: %s", e)
+
+    return metadata_map
+
+
+def extract_referenced_documents_from_steps(
+    steps: list[Any],
+) -> list[ReferencedDocument]:
+    """Extract referenced documents from tool execution steps.
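+
+    Only "tool_execution" steps whose tool responses come from the
+    "knowledge_search" tool are inspected. Metadata entries carrying both
+    "docs_url" and "title" become ReferencedDocument objects; entries whose
+    docs_url fails URL validation are skipped with a logged warning.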
+ + Args: + steps: List of response steps from the agent + + Returns: + List of referenced documents with doc_url and doc_title, sorted deterministically + """ + metadata_map: dict[str, dict[str, Any]] = {} + + for step in steps: + if getattr(step, "step_type", "") != "tool_execution" or not hasattr( + step, "tool_responses" + ): + continue + + for tool_response in getattr(step, "tool_responses", []) or []: + if getattr( + tool_response, "tool_name", "" + ) != "knowledge_search" or not getattr(tool_response, "content", []): + continue + + response_metadata = process_knowledge_search_content(tool_response) + metadata_map.update(response_metadata) + + # Extract referenced documents from metadata with error handling + referenced_documents = [] + for v in metadata_map.values(): + if "docs_url" in v and "title" in v: + try: + doc = ReferencedDocument(doc_url=v["docs_url"], doc_title=v["title"]) + referenced_documents.append(doc) + except (pydantic.ValidationError, ValueError) as e: + logger.warning( + "Skipping invalid referenced document with docs_url='%s', title='%s': %s", + v.get("docs_url", ""), + v.get("title", ""), + str(e), + ) + continue + + return sorted(referenced_documents, key=lambda d: (d.doc_title, str(d.doc_url))) diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 9f3d065f..4d948595 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -8,6 +8,8 @@ from llama_stack_client import APIConnectionError from llama_stack_client.types import UserMessage # type: ignore +from llama_stack_client.types.shared.interleaved_content_item import TextContentItem +from llama_stack_client.types.tool_response import ToolResponse from configuration import AppConfig from app.endpoints.query import ( @@ -21,13 +23,51 @@ get_rag_toolgroups, evaluate_model_hints, ) +from utils.metadata import ( + process_knowledge_search_content, + extract_referenced_documents_from_steps, +) from models.requests import QueryRequest, Attachment from models.config import ModelContextProtocolServer +from models.responses import ReferencedDocument from models.database.conversations import UserConversation MOCK_AUTH = ("mock_user_id", "mock_username", "mock_token") +SAMPLE_KNOWLEDGE_SEARCH_RESULTS = [ + """knowledge_search tool found 2 chunks: +BEGIN of knowledge_search tool results. +""", + """Result 1 +Content: ABC +Metadata: {'docs_url': 'https://example.com/doc1', 'title': 'Doc1', 'document_id': 'doc-1', \ +'source': None} +""", + """Result 2 +Content: ABC +Metadata: {'docs_url': 'https://example.com/doc2', 'title': 'Doc2', 'document_id': 'doc-2', \ +'source': None} +""", + """Result 2b +Content: ABC + Metadata: {'docs_url': 'https://example.com/doc2b', 'title': 'Doc2b', 'document_id': 'doc-2b', \ +'source': None} +""", + """END of knowledge_search tool results. +""", + # Following metadata contains an intentionally incorrect keyword "Title" (instead of "title") + # and it is not picked as a referenced document. + """Result 3 +Content: ABC +Metadata: {'docs_url': 'https://example.com/doc3', 'Title': 'Doc3', 'document_id': 'doc-3', \ +'source': None} +""", + """The above results were retrieved to help answer the user\'s query: "Sample Query". +Use them as supporting information only in answering this query. 
+""", +] + def mock_database_operations(mocker): """Helper function to mock database operations for query endpoints.""" @@ -70,10 +110,6 @@ def setup_configuration_fixture(): async def test_query_endpoint_handler_configuration_not_loaded(mocker): """Test the query endpoint handler if configuration is not loaded.""" # simulate state when no configuration is loaded - mocker.patch( - "app.endpoints.query.configuration", - return_value=mocker.Mock(), - ) mocker.patch("app.endpoints.query.configuration", None) query = "What is OpenStack?" @@ -128,7 +164,7 @@ async def _test_query_endpoint_handler(mocker, store_transcript_to_file=False): mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(llm_response, conversation_id), + return_value=(llm_response, conversation_id, []), ) mocker.patch( "app.endpoints.query.select_model_and_provider_id", @@ -185,15 +221,74 @@ async def test_query_endpoint_handler_store_transcript(mocker): await _test_query_endpoint_handler(mocker, store_transcript_to_file=True) +@pytest.mark.asyncio +async def test_query_endpoint_handler_with_referenced_documents(mocker): + """Test the query endpoint handler returns referenced documents.""" + mock_metric = mocker.patch("metrics.llm_calls_total") + mock_client = mocker.AsyncMock() + mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") + mock_lsc.return_value = mock_client + mock_client.models.list.return_value = [ + mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), + mocker.Mock(identifier="model2", model_type="llm", provider_id="provider2"), + ] + + mock_config = mocker.Mock() + mock_config.user_data_collection_configuration.transcripts_enabled = False + mocker.patch("app.endpoints.query.configuration", mock_config) + + llm_response = "LLM answer with referenced documents" + conversation_id = "fake_conversation_id" + referenced_documents = [ + ReferencedDocument(doc_url="https://example.com/doc1", doc_title="Doc1"), + ReferencedDocument(doc_url="https://example.com/doc2", doc_title="Doc2"), + ] + query = "What is OpenStack?" 
+ + mocker.patch( + "app.endpoints.query.retrieve_response", + return_value=(llm_response, conversation_id, referenced_documents), + ) + mocker.patch( + "app.endpoints.query.select_model_and_provider_id", + return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), + ) + mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) + + # Mock database operations + mock_database_operations(mocker) + + query_request = QueryRequest(query=query) + + response = await query_endpoint_handler(query_request, auth=MOCK_AUTH) + + # Assert the response contains referenced documents + assert response.response == llm_response + assert response.conversation_id == conversation_id + # Avoid brittle equality on Pydantic models; compare fields instead + assert [(str(d.doc_url), d.doc_title) for d in response.referenced_documents] == [ + ("https://example.com/doc1", "Doc1"), + ("https://example.com/doc2", "Doc2"), + ] + assert all(isinstance(d, ReferencedDocument) for d in response.referenced_documents) + assert len(response.referenced_documents) == 2 + # Titles should be sorted deterministically by doc_title + assert [d.doc_title for d in response.referenced_documents] == sorted( + [d.doc_title for d in response.referenced_documents] + ) + + # Assert the metric for successful LLM calls is incremented + mock_metric.labels("fake_provider_id", "fake_model_id").inc.assert_called_once() + + def test_select_model_and_provider_id_from_request(mocker): """Test the select_model_and_provider_id function.""" mocker.patch( - "metrics.utils.configuration.inference.default_provider", + "app.endpoints.query.configuration.inference.default_provider", "default_provider", ) mocker.patch( - "metrics.utils.configuration.inference.default_model", - "default_model", + "app.endpoints.query.configuration.inference.default_model", "default_model" ) model_list = [ @@ -228,12 +323,11 @@ def test_select_model_and_provider_id_from_request(mocker): def test_select_model_and_provider_id_from_configuration(mocker): """Test the select_model_and_provider_id function.""" mocker.patch( - "metrics.utils.configuration.inference.default_provider", + "app.endpoints.query.configuration.inference.default_provider", "default_provider", ) mocker.patch( - "metrics.utils.configuration.inference.default_model", - "default_model", + "app.endpoints.query.configuration.inference.default_model", "default_model" ) model_list = [ @@ -407,12 +501,13 @@ async def test_retrieve_response_no_returned_message(prepare_agent_mocks, mocker model_id = "fake_model_id" access_token = "test_token" - response, _ = await retrieve_response( + response, _, referenced_documents = await retrieve_response( mock_client, model_id, query_request, access_token ) # fallback mechanism: check that the response is empty assert response == "" + assert referenced_documents == [] @pytest.mark.asyncio @@ -438,12 +533,13 @@ async def test_retrieve_response_message_without_content(prepare_agent_mocks, mo model_id = "fake_model_id" access_token = "test_token" - response, _ = await retrieve_response( + response, _, referenced_documents = await retrieve_response( mock_client, model_id, query_request, access_token ) # fallback mechanism: check that the response is empty assert response == "" + assert referenced_documents == [] @pytest.mark.asyncio @@ -470,7 +566,7 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = await retrieve_response( + response, 
conversation_id, referenced_documents = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -478,6 +574,7 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker mock_metric.inc.assert_not_called() assert response == "LLM answer" assert conversation_id == "fake_conversation_id" + assert referenced_documents == [] # No knowledge_search in this test, so empty list mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -487,6 +584,66 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker ) +@pytest.mark.asyncio +async def test_retrieve_response_with_knowledge_search_extracts_referenced_documents( + prepare_agent_mocks, mocker +): + """Test the retrieve_response function extracts referenced documents from knowledge_search.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_turn.return_value.output_message.content = "LLM answer" + mock_client.shields.list.return_value = [] + mock_client.vector_dbs.list.return_value = [] + + # Mock the response with tool execution steps containing knowledge_search results + mock_tool_response = ToolResponse( + call_id="c1", + tool_name="knowledge_search", + content=[ + TextContentItem(text=s, type="text") + for s in SAMPLE_KNOWLEDGE_SEARCH_RESULTS + ], + ) + + mock_tool_execution_step = mocker.Mock() + mock_tool_execution_step.step_type = "tool_execution" + mock_tool_execution_step.tool_responses = [mock_tool_response] + + mock_agent.create_turn.return_value.steps = [mock_tool_execution_step] + + # Mock configuration with empty MCP servers + mock_config = mocker.Mock() + mock_config.mcp_servers = [] + mocker.patch("app.endpoints.query.configuration", mock_config) + mocker.patch( + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), + ) + + query_request = QueryRequest(query="What is OpenStack?") + model_id = "fake_model_id" + access_token = "test_token" + + response, conversation_id, referenced_documents = await retrieve_response( + mock_client, model_id, query_request, access_token + ) + + assert response == "LLM answer" + assert conversation_id == "fake_conversation_id" + + # Assert referenced documents were extracted correctly + assert len(referenced_documents) == 3 + assert str(referenced_documents[0].doc_url) == "https://example.com/doc1" + assert referenced_documents[0].doc_title == "Doc1" + assert str(referenced_documents[1].doc_url) == "https://example.com/doc2" + assert referenced_documents[1].doc_title == "Doc2" + assert str(referenced_documents[2].doc_url) == "https://example.com/doc2b" + assert referenced_documents[2].doc_title == "Doc2b" + + # Doc3 should not be included because it has "Title" instead of "title" + doc_titles = [doc.doc_title for doc in referenced_documents] + assert "Doc3" not in doc_titles + + @pytest.mark.asyncio async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" @@ -508,7 +665,7 @@ async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocke model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = await retrieve_response( + response, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -557,7 +714,7 @@ def __repr__(self): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = await 
retrieve_response( + response, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -609,7 +766,7 @@ def __repr__(self): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = await retrieve_response( + response, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -663,7 +820,7 @@ def __repr__(self): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = await retrieve_response( + response, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -719,7 +876,7 @@ async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = await retrieve_response( + response, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -773,7 +930,7 @@ async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocke model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = await retrieve_response( + response, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -828,7 +985,7 @@ async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): model_id = "fake_model_id" access_token = "test_token_123" - response, conversation_id = await retrieve_response( + response, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -897,7 +1054,7 @@ async def test_retrieve_response_with_mcp_servers_empty_token( model_id = "fake_model_id" access_token = "" # Empty token - response, conversation_id = await retrieve_response( + response, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -955,8 +1112,6 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers( ) query_request = QueryRequest(query="What is OpenStack?") - model_id = "fake_model_id" - access_token = "" mcp_headers = { "filesystem-server": {"Authorization": "Bearer test_token_123"}, "git-server": {"Authorization": "Bearer test_token_456"}, @@ -968,11 +1123,11 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers( }, } - response, conversation_id = await retrieve_response( + response, conversation_id, _ = await retrieve_response( mock_client, - model_id, + "fake_model_id", query_request, - access_token, + "", mcp_headers=mcp_headers, ) @@ -982,7 +1137,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers( # Verify get_agent was called with the correct parameters mock_get_agent.assert_called_once_with( mock_client, - model_id, + "fake_model_id", mocker.ANY, # system_prompt [], # available_input_shields [], # available_output_shields @@ -990,20 +1145,20 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers( False, # no_tools ) - expected_mcp_headers = { - "http://localhost:3000": {"Authorization": "Bearer test_token_123"}, - "https://git.example.com/mcp": {"Authorization": "Bearer test_token_456"}, - "http://another-server-mcp-server:3000": { - "Authorization": "Bearer test_token_789" - }, - # we do not put "unknown-mcp-server" url as it's unknown to lightspeed-stack - } - # Check that the agent's extra_headers property was set correctly expected_extra_headers = { "X-LlamaStack-Provider-Data": json.dumps( { - "mcp_headers": expected_mcp_headers, + 
"mcp_headers": { + "http://localhost:3000": {"Authorization": "Bearer test_token_123"}, + "https://git.example.com/mcp": { + "Authorization": "Bearer test_token_456" + }, + "http://another-server-mcp-server:3000": { + "Authorization": "Bearer test_token_789" + }, + # we do not put "unknown-mcp-server" url as it's unknown to lightspeed-stack + }, } ) } @@ -1049,7 +1204,7 @@ async def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker): query_request = QueryRequest(query="What is OpenStack?") - _, conversation_id = await retrieve_response( + _, conversation_id, _ = await retrieve_response( mock_client, "fake_model_id", query_request, "test_token" ) @@ -1201,7 +1356,7 @@ async def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker): mock_retrieve_response = mocker.patch( "app.endpoints.query.retrieve_response", - return_value=("test response", "test_conversation_id"), + return_value=("test response", "test_conversation_id", []), ) mocker.patch( @@ -1232,7 +1387,6 @@ async def test_query_endpoint_handler_no_tools_true(mocker): ] mock_config = mocker.Mock() - mock_config.user_data_collection_configuration.transcripts_disabled = True mocker.patch("app.endpoints.query.configuration", mock_config) llm_response = "LLM answer without tools" @@ -1241,7 +1395,7 @@ async def test_query_endpoint_handler_no_tools_true(mocker): mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(llm_response, conversation_id), + return_value=(llm_response, conversation_id, []), ) mocker.patch( "app.endpoints.query.select_model_and_provider_id", @@ -1271,7 +1425,6 @@ async def test_query_endpoint_handler_no_tools_false(mocker): ] mock_config = mocker.Mock() - mock_config.user_data_collection_configuration.transcripts_disabled = True mocker.patch("app.endpoints.query.configuration", mock_config) llm_response = "LLM answer with tools" @@ -1280,7 +1433,7 @@ async def test_query_endpoint_handler_no_tools_false(mocker): mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(llm_response, conversation_id), + return_value=(llm_response, conversation_id, []), ) mocker.patch( "app.endpoints.query.select_model_and_provider_id", @@ -1329,7 +1482,7 @@ async def test_retrieve_response_no_tools_bypasses_mcp_and_rag( model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = await retrieve_response( + response, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -1379,7 +1532,7 @@ async def test_retrieve_response_no_tools_false_preserves_functionality( model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = await retrieve_response( + response, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -1505,3 +1658,507 @@ def test_evaluate_model_hints( assert provider_id == expected_provider assert model_id == expected_model + + +def test_process_knowledge_search_content_with_valid_metadata(mocker): + """Test process_knowledge_search_content with valid metadata.""" + # Mock tool response with valid metadata + text_content_item = mocker.Mock() + text_content_item.text = """Result 1 +Content: Test content +Metadata: {'docs_url': 'https://example.com/doc1', 'title': 'Test Doc', 'document_id': 'doc-1'} +""" + + tool_response = mocker.Mock() + tool_response.content = [text_content_item] + + metadata_map = process_knowledge_search_content(tool_response) + + # Verify metadata was correctly parsed and added + assert "doc-1" in metadata_map + 
assert metadata_map["doc-1"]["docs_url"] == "https://example.com/doc1" + assert metadata_map["doc-1"]["title"] == "Test Doc" + assert metadata_map["doc-1"]["document_id"] == "doc-1" + + +def test_process_knowledge_search_content_with_invalid_metadata_syntax_error(mocker): + """Test process_knowledge_search_content gracefully handles SyntaxError.""" + # Mock tool response with invalid metadata (invalid Python syntax) + text_content_item = mocker.Mock() + text_content_item.text = """Result 1 +Content: Test content +Metadata: {'docs_url': 'https://example.com/doc1' 'title': 'Test Doc', 'document_id': 'doc-1'} +""" # Missing comma between 'doc1' and 'title' - will cause SyntaxError + + tool_response = mocker.Mock() + tool_response.content = [text_content_item] + + metadata_map = process_knowledge_search_content(tool_response) + + # Verify metadata_map remains empty due to invalid syntax (gracefully handled) + assert len(metadata_map) == 0 + + +def test_process_knowledge_search_content_with_invalid_metadata_value_error(mocker): + """Test process_knowledge_search_content gracefully handles ValueError from invalid metadata.""" + # Mock tool response with invalid metadata containing complex expressions + text_content_item = mocker.Mock() + text_content_item.text = """Result 1 +Content: Test content +Metadata: {func_call(): 'value', 'title': 'Test Doc', 'document_id': 'doc-1'} +""" # Function call in dict - will cause ValueError since it's not a literal + + tool_response = mocker.Mock() + tool_response.content = [text_content_item] + + metadata_map = process_knowledge_search_content(tool_response) + + # Verify metadata_map remains empty due to invalid expression (gracefully handled) + assert len(metadata_map) == 0 + + +def test_process_knowledge_search_content_with_non_dict_metadata(mocker): + """Test process_knowledge_search_content handles non-dict metadata gracefully.""" + mock_logger = mocker.patch("utils.metadata.logger") + + # Mock tool response with metadata that's not a dict + text_content_item = mocker.Mock() + text_content_item.text = """Result 1 +Content: Test content +Metadata: "just a string" +""" + + tool_response = mocker.Mock() + tool_response.content = [text_content_item] + + metadata_map = process_knowledge_search_content(tool_response) + + # Verify metadata_map remains empty (no document_id in string) + assert len(metadata_map) == 0 + + # No exception should be logged since string is a valid literal and simply ignored + mock_logger.exception.assert_not_called() + + +def test_process_knowledge_search_content_with_metadata_missing_document_id(mocker): + """Test process_knowledge_search_content skips metadata without document_id.""" + # Mock tool response with valid metadata but missing document_id + text_content_item = mocker.Mock() + text_content_item.text = """Result 1 +Content: Test content +Metadata: {'docs_url': 'https://example.com/doc1', 'title': 'Test Doc'} +""" # No document_id field + + tool_response = mocker.Mock() + tool_response.content = [text_content_item] + + metadata_map = process_knowledge_search_content(tool_response) + + # Verify metadata_map remains empty since document_id is missing + assert len(metadata_map) == 0 + + +def test_process_knowledge_search_content_with_no_text_attribute(mocker): + """Test process_knowledge_search_content skips content items without text attribute.""" + # Mock tool response with content item that has no text attribute + text_content_item = mocker.Mock(spec=[]) # spec=[] means no attributes + + tool_response = mocker.Mock() + 
tool_response.content = [text_content_item] + + metadata_map = process_knowledge_search_content(tool_response) + + # Verify metadata_map remains empty since text attribute is missing + assert len(metadata_map) == 0 + + +def test_process_knowledge_search_content_with_none_content(mocker): + """Test process_knowledge_search_content handles tool_response with content=None.""" + # Mock tool response with content = None + tool_response = mocker.Mock() + tool_response.content = None + + metadata_map = process_knowledge_search_content(tool_response) + + # Verify metadata_map remains empty when content is None + assert len(metadata_map) == 0 + + +def test_process_knowledge_search_content_duplicate_document_id_last_wins(mocker): + """The last metadata block for a given document_id should win.""" + text_items = [ + mocker.Mock( + text="Metadata: {'docs_url': 'https://example.com/first', " + "'title': 'First', 'document_id': 'doc-x'}" + ), + mocker.Mock( + text="Metadata: {'docs_url': 'https://example.com/second', " + "'title': 'Second', 'document_id': 'doc-x'}" + ), + ] + tool_response = mocker.Mock() + tool_response.tool_name = "knowledge_search" + tool_response.content = text_items + + # Process content + metadata_map = process_knowledge_search_content(tool_response) + assert metadata_map["doc-x"]["docs_url"] == "https://example.com/second" + assert metadata_map["doc-x"]["title"] == "Second" + + # Ensure extraction reflects last-wins as well + step = mocker.Mock() + step.step_type = "tool_execution" + step.tool_responses = [tool_response] + docs = extract_referenced_documents_from_steps([step]) + assert len(docs) == 1 + assert str(docs[0].doc_url) == "https://example.com/second" + assert docs[0].doc_title == "Second" + + +def test_process_knowledge_search_content_with_braces_inside_strings(mocker): + """Test that braces inside strings are handled correctly.""" + text_content_item = mocker.Mock() + text_content_item.text = ( + "Result 1\n" + "Content: Test content\n" + "Metadata: {'document_id': 'doc-100', 'title': 'A {weird} title', " + "'docs_url': 'https://example.com/100', 'extra': {'note': 'contains {braces}'}}" + ) + tool_response = mocker.Mock() + tool_response.content = [text_content_item] + + metadata_map = process_knowledge_search_content(tool_response) + assert "doc-100" in metadata_map + assert metadata_map["doc-100"]["title"] == "A {weird} title" + assert metadata_map["doc-100"]["docs_url"] == "https://example.com/100" + assert metadata_map["doc-100"]["extra"]["note"] == "contains {braces}" + + +def test_process_knowledge_search_content_with_nested_objects(mocker): + """Test that nested objects are parsed correctly.""" + text_content_item = mocker.Mock() + text_content_item.text = ( + "Result 1\n" + "Content: Test content\n" + 'Metadata: {"document_id": "doc-200", "title": "Nested JSON", ' + '"docs_url": "https://example.com/200", "meta": {"k": {"inner": 1}}}' + ) + tool_response = mocker.Mock() + tool_response.content = [text_content_item] + + metadata_map = process_knowledge_search_content(tool_response) + assert "doc-200" in metadata_map + assert metadata_map["doc-200"]["title"] == "Nested JSON" + assert metadata_map["doc-200"]["docs_url"] == "https://example.com/200" + assert metadata_map["doc-200"]["meta"]["k"]["inner"] == 1 + + +def test_process_knowledge_search_content_with_string_fallback_parsing(mocker): + """Test that string content uses parse_knowledge_search_metadata as fallback.""" + # Create a tool response with string content containing metadata + string_content = 
"""Result 1 +Content: Test content +Metadata: {'docs_url': 'https://example.com/fallback', 'title': 'Fallback Doc', 'document_id': 'fallback-1'} + +Result 2 +Content: More content +Metadata: {'docs_url': 'https://example.com/fallback2', 'title': 'Fallback Doc 2', 'document_id': 'fallback-2'} +""" + + tool_response = mocker.Mock() + tool_response.content = string_content # String instead of list + + metadata_map = process_knowledge_search_content(tool_response) + + # Verify fallback parsing worked correctly + assert len(metadata_map) == 2 + assert "fallback-1" in metadata_map + assert "fallback-2" in metadata_map + assert metadata_map["fallback-1"]["title"] == "Fallback Doc" + assert metadata_map["fallback-1"]["docs_url"] == "https://example.com/fallback" + assert metadata_map["fallback-2"]["title"] == "Fallback Doc 2" + assert metadata_map["fallback-2"]["docs_url"] == "https://example.com/fallback2" + + +def test_process_knowledge_search_content_metadata_label_case_insensitive(mocker): + """Test that metadata labels are detected case-insensitively.""" + text_content_item = mocker.Mock() + text_content_item.text = ( + "Result 1\n" + "Content: Test content\n" + "metadata: {'document_id': 'doc-ci', 'title': 'Case Insensitive', " + "'docs_url': 'https://example.com/ci'}\n" + ) + tool_response = mocker.Mock() + tool_response.content = [text_content_item] + + metadata_map = process_knowledge_search_content(tool_response) + + assert "doc-ci" in metadata_map + assert metadata_map["doc-ci"]["title"] == "Case Insensitive" + assert metadata_map["doc-ci"]["docs_url"] == "https://example.com/ci" + + +@pytest.mark.asyncio +async def test_retrieve_response_with_none_content(prepare_agent_mocks, mocker): + """Test retrieve_response handles None content gracefully.""" + mock_client, mock_agent = prepare_agent_mocks + + # Mock response with None content + mock_response = mocker.Mock() + mock_response.output_message.content = None + mock_response.steps = [] + mock_agent.create_turn.return_value = mock_response + + mock_client.shields.list.return_value = [] + mock_client.vector_dbs.list.return_value = [] + + # Mock configuration with empty MCP servers + mock_config = mocker.Mock() + mock_config.mcp_servers = [] + mocker.patch("app.endpoints.query.configuration", mock_config) + mocker.patch( + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), + ) + + query_request = QueryRequest(query="What is OpenStack?") + model_id = "fake_model_id" + access_token = "test_token" + + response, conversation_id, _ = await retrieve_response( + mock_client, model_id, query_request, access_token + ) + + # Should return empty string instead of "None" + assert response == "" + assert conversation_id == "fake_conversation_id" + + +@pytest.mark.asyncio +async def test_retrieve_response_with_missing_output_message( + prepare_agent_mocks, mocker +): + """Test retrieve_response handles missing output_message gracefully.""" + mock_client, mock_agent = prepare_agent_mocks + + # Mock response without output_message attribute + mock_response = mocker.Mock(spec=["steps"]) # Only has steps attribute + mock_response.steps = [] + mock_agent.create_turn.return_value = mock_response + + mock_client.shields.list.return_value = [] + mock_client.vector_dbs.list.return_value = [] + + # Mock configuration with empty MCP servers + mock_config = mocker.Mock() + mock_config.mcp_servers = [] + mocker.patch("app.endpoints.query.configuration", mock_config) + mocker.patch( + 
"app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), + ) + + query_request = QueryRequest(query="What is OpenStack?") + model_id = "fake_model_id" + access_token = "test_token" + + response, conversation_id, _ = await retrieve_response( + mock_client, model_id, query_request, access_token + ) + + # Should return empty string when output_message is missing + assert response == "" + assert conversation_id == "fake_conversation_id" + + +@pytest.mark.asyncio +async def test_retrieve_response_with_missing_content_attribute( + prepare_agent_mocks, mocker +): + """Test retrieve_response handles missing content attribute gracefully.""" + mock_client, mock_agent = prepare_agent_mocks + + # Mock response with output_message but no content attribute + mock_response = mocker.Mock() + mock_response.output_message = mocker.Mock(spec=[]) # No content attribute + mock_response.steps = [] + mock_agent.create_turn.return_value = mock_response + + mock_client.shields.list.return_value = [] + mock_client.vector_dbs.list.return_value = [] + + # Mock configuration with empty MCP servers + mock_config = mocker.Mock() + mock_config.mcp_servers = [] + mocker.patch("app.endpoints.query.configuration", mock_config) + mocker.patch( + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), + ) + + query_request = QueryRequest(query="What is OpenStack?") + model_id = "fake_model_id" + access_token = "test_token" + + response, conversation_id, _ = await retrieve_response( + mock_client, model_id, query_request, access_token + ) + + # Should return empty string when content attribute is missing + assert response == "" + assert conversation_id == "fake_conversation_id" + + +@pytest.mark.asyncio +async def test_retrieve_response_with_structured_content_object( + prepare_agent_mocks, mocker +): + """Test retrieve_response handles structured content objects properly.""" + mock_client, mock_agent = prepare_agent_mocks + + # Mock response with a structured content object + structured_content = {"type": "text", "value": "This is structured content"} + mock_response = mocker.Mock() + mock_response.output_message.content = structured_content + mock_response.steps = [] + mock_agent.create_turn.return_value = mock_response + + mock_client.shields.list.return_value = [] + mock_client.vector_dbs.list.return_value = [] + + # Mock configuration with empty MCP servers + mock_config = mocker.Mock() + mock_config.mcp_servers = [] + mocker.patch("app.endpoints.query.configuration", mock_config) + mocker.patch( + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), + ) + + query_request = QueryRequest(query="What is OpenStack?") + model_id = "fake_model_id" + access_token = "test_token" + + response, conversation_id, _ = await retrieve_response( + mock_client, model_id, query_request, access_token + ) + + # Should convert the structured object to string representation + assert response == str(structured_content) + assert conversation_id == "fake_conversation_id" + + +@pytest.mark.asyncio +async def test_retrieve_response_skips_invalid_docs_url(prepare_agent_mocks, mocker): + """Test that retrieve_response skips entries with invalid docs_url.""" + # Mock logger to capture warning logs + mock_logger = mocker.patch("utils.metadata.logger") + + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_turn.return_value.output_message.content = "LLM answer" + 
mock_client.shields.list.return_value = [] + mock_client.vector_dbs.list.return_value = [] + + # Mock tool response with valid and invalid docs_url entries + invalid_docs_url_results = [ + """knowledge_search tool found 2 chunks: +BEGIN of knowledge_search tool results. +""", + """Result 1 +Content: Valid content +Metadata: {'docs_url': 'https://example.com/doc1', 'title': 'Valid Doc', 'document_id': 'doc-1'} +""", + """Result 2 +Content: Invalid content +Metadata: {'docs_url': 'not-a-valid-url', 'title': 'Invalid Doc', 'document_id': 'doc-2'} +""", + """END of knowledge_search tool results. +""", + ] + + mock_tool_response = mocker.Mock() + mock_tool_response.call_id = "c1" + mock_tool_response.tool_name = "knowledge_search" + mock_tool_response.content = [ + TextContentItem(text=s, type="text") for s in invalid_docs_url_results + ] + + mock_tool_execution_step = mocker.Mock() + mock_tool_execution_step.step_type = "tool_execution" + mock_tool_execution_step.tool_responses = [mock_tool_response] + + mock_agent.create_turn.return_value.steps = [mock_tool_execution_step] + + # Mock configuration with empty MCP servers + mock_config = mocker.Mock() + mock_config.mcp_servers = [] + mocker.patch("app.endpoints.query.configuration", mock_config) + mocker.patch( + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), + ) + + query_request = QueryRequest(query="What is OpenStack?") + model_id = "fake_model_id" + access_token = "test_token" + + response, conversation_id, referenced_documents = await retrieve_response( + mock_client, model_id, query_request, access_token + ) + + assert response == "LLM answer" + assert conversation_id == "fake_conversation_id" + + # Assert only the valid document is included, invalid one is skipped + assert len(referenced_documents) == 1 + assert str(referenced_documents[0].doc_url) == "https://example.com/doc1" + assert referenced_documents[0].doc_title == "Valid Doc" + + # Ensure we logged a warning for the invalid docs_url + assert any( + call[0][0].startswith("Skipping invalid referenced document") + or "Skipping invalid referenced document" in str(call) + for call in mock_logger.warning.call_args_list + ) + # Verify the bad URL is included in the log message for extra confidence + assert any( + "not-a-valid-url" in str(call) for call in mock_logger.warning.call_args_list + ) + + +@pytest.mark.asyncio +async def test_extract_referenced_documents_from_steps_handles_validation_errors( + mocker, +): + """Test that extract_referenced_documents_from_steps handles validation errors gracefully.""" + # Mock tool response with invalid docs_url that will cause pydantic validation error + mock_tool_response = mocker.Mock() + mock_tool_response.tool_name = "knowledge_search" + mock_tool_response.content = [ + mocker.Mock( + text="""Result 1 +Content: Valid content +Metadata: {'docs_url': 'https://example.com/doc1', 'title': 'Valid Doc', 'document_id': 'doc-1'} +""" + ), + mocker.Mock( + text="""Result 2 +Content: Invalid content +Metadata: {'docs_url': 'invalid-url', 'title': 'Invalid Doc', 'document_id': 'doc-2'} +""" + ), + ] + + mock_tool_execution_step = mocker.Mock() + mock_tool_execution_step.step_type = "tool_execution" + mock_tool_execution_step.tool_responses = [mock_tool_response] + + steps = [mock_tool_execution_step] + + referenced_documents = extract_referenced_documents_from_steps(steps) + + # Should only return the valid document, skipping the invalid one + assert len(referenced_documents) == 1 + assert 
str(referenced_documents[0].doc_url) == "https://example.com/doc1" + assert referenced_documents[0].doc_title == "Valid Doc" diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 8e03aa9c..94678af6 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -40,6 +40,7 @@ streaming_query_endpoint_handler, retrieve_response, stream_build_event, + stream_end_event, ) from models.requests import QueryRequest, Attachment @@ -996,6 +997,149 @@ def test_stream_build_event_step_complete(): assert '"id": 0' in result +def test_stream_build_event_knowledge_search_with_invalid_metadata(mocker): + """Test stream_build_event handles invalid metadata in knowledge_search tool response.""" + mock_logger = mocker.patch("app.endpoints.streaming_query.logger") + + # Create a mock chunk with knowledge_search tool response containing invalid metadata + chunk = AgentTurnResponseStreamChunk( + event=TurnResponseEvent( + payload=AgentTurnResponseStepCompletePayload( + event_type="step_complete", + step_id="s1", + step_type="tool_execution", + step_details=ToolExecutionStep( + turn_id="t1", + step_id="s2", + step_type="tool_execution", + tool_responses=[ + ToolResponse( + call_id="c1", + tool_name="knowledge_search", + content=[ + TextContentItem( + text="""Result 1 +Content: Test content +Metadata: {'docs_url': 'https://example.com/doc1' 'title': 'Test Doc', 'document_id': 'doc-1'} +""", # Missing comma - invalid syntax + type="text", + ) + ], + ) + ], + tool_calls=[ + ToolCall( + call_id="t1", tool_name="knowledge_search", arguments={} + ) + ], + ), + ) + ) + ) + + metadata_map = {} + result_list = list(stream_build_event(chunk, 0, metadata_map)) + + # Verify metadata_map remains empty due to invalid metadata + assert len(metadata_map) == 0 + + # Verify the function still returns tool execution events + assert len(result_list) == 2 # One for tool_calls, one for tool_responses + + # Verify no exception logging was called in non-strict mode + mock_logger.exception.assert_not_called() + + +def test_stream_end_event_with_referenced_documents(): + """Test stream_end_event creates proper JSON with ReferencedDocument validation.""" + metadata_map = { + "doc-1": { + "docs_url": "https://example.com/doc1", + "title": "Test Document 1", + "document_id": "doc-1", + }, + "doc-2": { + "docs_url": "https://example.com/doc2", + "title": "Test Document 2", + "document_id": "doc-2", + }, + "doc-3": { + # Missing title - should be filtered out + "docs_url": "https://example.com/doc3", + "document_id": "doc-3", + }, + "doc-4": { + # Missing docs_url - should be filtered out + "title": "Test Document 4", + "document_id": "doc-4", + }, + } + + result = stream_end_event(metadata_map) + + # Parse the JSON response + parsed = json.loads(result.replace("data: ", "")) + + # Verify structure + assert parsed["event"] == "end" + assert "referenced_documents" in parsed["data"] + + # Verify only valid documents are included + referenced_docs = parsed["data"]["referenced_documents"] + assert len(referenced_docs) == 2 + + # Verify document structure and URL validation + doc_urls = [doc["doc_url"] for doc in referenced_docs] + doc_titles = [doc["doc_title"] for doc in referenced_docs] + + assert "https://example.com/doc1" in doc_urls + assert "https://example.com/doc2" in doc_urls + assert "Test Document 1" in doc_titles + assert "Test Document 2" in doc_titles + + # Verify filtered documents are not included + assert 
"https://example.com/doc3" not in doc_urls + assert "Test Document 4" not in doc_titles + + +def test_stream_end_event_skips_invalid_docs_url(): + """Test stream_end_event skips entries with invalid docs_url.""" + metadata_map = { + "doc-1": { + "docs_url": "https://example.com/doc1", + "title": "Valid Document", + "document_id": "doc-1", + }, + "doc-2": { + "docs_url": "not-a-valid-url", # Invalid URL that will cause ValidationError + "title": "Invalid Document", + "document_id": "doc-2", + }, + "doc-3": { + "docs_url": "", # Empty URL that will cause ValidationError + "title": "Empty URL Document", + "document_id": "doc-3", + }, + } + + result = stream_end_event(metadata_map) + + # Parse the JSON response + parsed = json.loads(result.replace("data: ", "")) + + # Verify structure + assert parsed["event"] == "end" + assert "referenced_documents" in parsed["data"] + + # Verify only valid documents are included, invalid ones are skipped + referenced_docs = parsed["data"]["referenced_documents"] + assert len(referenced_docs) == 1 + + # Verify the valid document is included + assert referenced_docs[0]["doc_url"] == "https://example.com/doc1" + assert referenced_docs[0]["doc_title"] == "Valid Document" + + def test_stream_build_event_error(): """Test stream_build_event function returns a 'error' when chunk contains error information.""" # Create a mock chunk without an expected payload structure diff --git a/tests/unit/utils/test_metadata.py b/tests/unit/utils/test_metadata.py new file mode 100644 index 00000000..40f2d43a --- /dev/null +++ b/tests/unit/utils/test_metadata.py @@ -0,0 +1,374 @@ +"""Unit tests for utils.metadata module.""" + +import pytest + +from utils.metadata import parse_knowledge_search_metadata, METADATA_LABEL_PATTERN + + +def test_metadata_pattern_exists(): + """Test that METADATA_LABEL_PATTERN is properly defined and captures labels correctly.""" + assert METADATA_LABEL_PATTERN is not None + assert hasattr(METADATA_LABEL_PATTERN, "finditer") + + # Test that the pattern captures metadata labels case-insensitively + sample = "Foo\nMetadata: {'a': 1}\nMETADATA: {'b': 2}\nBar" + matches = list(METADATA_LABEL_PATTERN.finditer(sample)) + assert len(matches) == 2 + + # Check that the matches are at the expected positions + assert sample[matches[0].start() : matches[0].end()] == "Metadata: " + assert sample[matches[1].start() : matches[1].end()] == "METADATA: " + + +def test_parse_knowledge_search_metadata_valid_single(): + """Test parsing valid metadata with single entry.""" + text = """Result 1 +Content: Some content +Metadata: {'docs_url': 'https://example.com/doc1', 'title': 'Doc1', 'document_id': 'doc-1'} +""" + result = parse_knowledge_search_metadata(text) + + assert len(result) == 1 + assert "doc-1" in result + assert result["doc-1"]["docs_url"] == "https://example.com/doc1" + assert result["doc-1"]["title"] == "Doc1" + assert result["doc-1"]["document_id"] == "doc-1" + + +def test_parse_knowledge_search_metadata_valid_multiple(): + """Test parsing valid metadata with multiple entries.""" + text = """Result 1 +Content: Some content +Metadata: {'docs_url': 'https://example.com/doc1', 'title': 'Doc1', 'document_id': 'doc-1'} + +Result 2 +Content: More content +Metadata: {'docs_url': 'https://example.com/doc2', 'title': 'Doc2', 'document_id': 'doc-2'} +""" + result = parse_knowledge_search_metadata(text) + + assert len(result) == 2 + assert "doc-1" in result + assert "doc-2" in result + assert result["doc-1"]["title"] == "Doc1" + assert result["doc-2"]["title"] == "Doc2" + + 
+def test_parse_knowledge_search_metadata_no_metadata():
+    """Test parsing text with no metadata."""
+    text = """Result 1
+Content: Some content without metadata
+"""
+    result = parse_knowledge_search_metadata(text)
+
+    assert len(result) == 0
+
+
+def test_parse_knowledge_search_metadata_missing_document_id():
+    """Test that metadata without document_id is ignored."""
+    text = """Result 1
+Content: Some content
+Metadata: {'docs_url': 'https://example.com/doc1', 'title': 'Doc1'}
+"""
+    result = parse_knowledge_search_metadata(text)
+
+    assert len(result) == 0
+
+
+def test_parse_knowledge_search_metadata_malformed_literal():
+    """Test that a malformed Python literal raises ValueError."""
+    text = """Result 1
+Content: Some content
+Metadata: {'docs_url': 'https://example.com/doc1' 'title': 'Doc1', 'document_id': 'doc-1'}
+"""
+    with pytest.raises(ValueError) as exc_info:
+        parse_knowledge_search_metadata(text)
+
+    assert "Failed to parse metadata" in str(exc_info.value)
+
+
+def test_parse_knowledge_search_metadata_invalid_syntax():
+    """Test that invalid Python syntax raises ValueError."""
+    text = """Result 1
+Content: Some content
+Metadata: {func_call(): 'value', 'title': 'Doc1', 'document_id': 'doc-1'}
+"""
+    with pytest.raises(ValueError) as exc_info:
+        parse_knowledge_search_metadata(text)
+
+    assert "Failed to parse metadata" in str(exc_info.value)
+
+
+def test_parse_knowledge_search_metadata_non_dict():
+    """Test that non-dict metadata is ignored."""
+    text = """Result 1
+Content: Some content
+Metadata: "just a string"
+"""
+    result = parse_knowledge_search_metadata(text)
+
+    assert len(result) == 0
+
+
+def test_parse_knowledge_search_metadata_mixed_valid_invalid():
+    """Test parsing text with both valid and invalid metadata."""
+    text = """Result 1
+Content: Some content
+Metadata: {'docs_url': 'https://example.com/doc1', 'title': 'Doc1', 'document_id': 'doc-1'}
+
+Result 2
+Content: Bad content
+Metadata: {'docs_url': 'https://example.com/doc2' 'title': 'Doc2', 'document_id': 'doc-2'}
+"""
+    with pytest.raises(ValueError) as exc_info:
+        parse_knowledge_search_metadata(text)
+
+    assert "Failed to parse metadata" in str(exc_info.value)
+
+
+def test_parse_knowledge_search_metadata_whitespace_handling():
+    """Test parsing metadata with leading whitespace before the label."""
+    text = """Result 1
+Content: Some content
+ Metadata: {'docs_url': 'https://example.com/doc1', 'title': 'Doc1', 'document_id': 'doc-1'}
+"""
+    result = parse_knowledge_search_metadata(text)
+
+    assert len(result) == 1
+    assert "doc-1" in result
+    assert result["doc-1"]["title"] == "Doc1"
+
+
+def test_parse_metadata_duplicate_document_id_last_wins():
+    """Test that duplicate document_id entries overwrite (last wins)."""
+    text = (
+        "Metadata: {'docs_url': 'https://example.com/doc1', 'title': 'Doc1a', "
+        "'document_id': 'dup'}\n"
+        "Metadata: {'docs_url': 'https://example.com/doc1', 'title': 'Doc1b', "
+        "'document_id': 'dup'}"
+    )
+    result = parse_knowledge_search_metadata(text)
+
+    assert len(result) == 1
+    assert set(result.keys()) == {"dup"}
+    assert result["dup"]["title"] == "Doc1b"
+
+
+def test_parse_knowledge_search_metadata_non_strict_mode():
+    """Test that non-strict mode skips invalid blocks and continues parsing."""
+    text = """Result 1
+Content: Valid content
+Metadata: {'docs_url': 'https://example.com/doc1', 'title': 'Doc1', 'document_id': 'doc-1'}
+
+Result 2
+Content: Bad content
+Metadata: {'docs_url': 'https://example.com/doc2' 'title': 'Doc2', 'document_id': 'doc-2'}
+
+Result 3
+Content: More valid content
+Metadata: {'docs_url': 'https://example.com/doc3', 'title': 'Doc3', 'document_id': 'doc-3'}
+"""
+    result = parse_knowledge_search_metadata(text, strict=False)
+
+    # Should have 2 valid documents, skipping the malformed one
+    assert len(result) == 2
+    assert "doc-1" in result
+    assert "doc-3" in result
+    assert "doc-2" not in result  # malformed entry should be skipped
+    assert result["doc-1"]["title"] == "Doc1"
+    assert result["doc-3"]["title"] == "Doc3"
+
+
+def test_parse_knowledge_search_metadata_strict_mode_default():
+    """Test that strict mode is the default behavior."""
+    text = """Result 1
+Content: Valid content
+Metadata: {'docs_url': 'https://example.com/doc1', 'title': 'Doc1', 'document_id': 'doc-1'}
+
+Result 2
+Content: Bad content
+Metadata: {'docs_url': 'https://example.com/doc2' 'title': 'Doc2', 'document_id': 'doc-2'}
+"""
+    # Should raise ValueError in strict mode (default)
+    with pytest.raises(ValueError) as exc_info:
+        parse_knowledge_search_metadata(text)
+
+    assert "Failed to parse metadata" in str(exc_info.value)
+
+    # Explicitly setting strict=True should behave the same
+    with pytest.raises(ValueError) as exc_info:
+        parse_knowledge_search_metadata(text, strict=True)
+
+    assert "Failed to parse metadata" in str(exc_info.value)
+
+
+def test_metadata_pattern_case_insensitive_and_nested():
+    """Test case-insensitive matching and nested payloads."""
+    text = """Result
+Content
+METADATA: {'document_id': 'doc-1', 'nested': {'k': [1, 2, 3]}, 'title': 'Nested Doc'}
+Another result
+metadata: {'document_id': 'doc-2', 'complex': {'a': {'b': {'c': 42}}, 'list': [{'x': 1}, {'y': 2}]}, 'title': 'Complex Doc'}
+"""
+    result = parse_knowledge_search_metadata(text)
+
+    assert len(result) == 2
+    assert "doc-1" in result
+    assert "doc-2" in result
+
+    # Verify the nested structure was parsed correctly
+    assert result["doc-1"]["nested"]["k"] == [1, 2, 3]
+    assert result["doc-1"]["title"] == "Nested Doc"
+
+    assert result["doc-2"]["complex"]["a"]["b"]["c"] == 42
+    assert result["doc-2"]["complex"]["list"][0]["x"] == 1
+    assert result["doc-2"]["complex"]["list"][1]["y"] == 2
+    assert result["doc-2"]["title"] == "Complex Doc"
+
+
+def test_metadata_pattern_various_case_variations():
+    """Test different case variations of the metadata label."""
+    text = """
+Metadata: {'document_id': 'doc-1', 'title': 'Standard'}
+METADATA: {'document_id': 'doc-2', 'title': 'Uppercase'}
+metadata: {'document_id': 'doc-3', 'title': 'Lowercase'}
+MetaData: {'document_id': 'doc-4', 'title': 'Mixed Case'}
+"""
+    result = parse_knowledge_search_metadata(text)
+
+    assert len(result) == 4
+    assert result["doc-1"]["title"] == "Standard"
+    assert result["doc-2"]["title"] == "Uppercase"
+    assert result["doc-3"]["title"] == "Lowercase"
+    assert result["doc-4"]["title"] == "Mixed Case"
+
+
+def test_balanced_braces_with_nested_dicts_and_strings():
+    """Test balanced-brace parsing with complex nested structures."""
+    text = """
+Metadata: {'document_id': 'doc-1', 'data': {'nested': 'value with {braces} in string'}, 'array': [{'inner': 'val'}]}
+"""
+    result = parse_knowledge_search_metadata(text)
+
+    assert len(result) == 1
+    assert result["doc-1"]["data"]["nested"] == "value with {braces} in string"
+    assert result["doc-1"]["array"][0]["inner"] == "val"
+
+
+def test_unmatched_braces_handling():
+    """Test handling of unmatched braces in strict and non-strict modes."""
+    text = """
+Metadata: {'document_id': 'doc-1', 'incomplete': 'missing brace'
+Valid after: some text
+Metadata: {'document_id': 'doc-2', 'title': 'Valid Doc'}
+"""
+    # Strict mode should raise an error
+    with pytest.raises(ValueError) as exc_info:
+        parse_knowledge_search_metadata(text, strict=True)
+
+    assert "Failed to parse metadata" in str(exc_info.value)
+
+    # Non-strict mode should skip the invalid entry and parse the valid one
+    result = parse_knowledge_search_metadata(text, strict=False)
+    assert len(result) == 1
+    assert "doc-2" in result
+    assert result["doc-2"]["title"] == "Valid Doc"
+
+
+def test_no_opening_brace_after_metadata_label():
+    """Test handling when no opening brace follows the metadata label."""
+    text = """
+Metadata: not a dict
+Some other content
+Metadata: {'document_id': 'doc-1', 'title': 'Valid'}
+"""
+    result = parse_knowledge_search_metadata(text)
+
+    # Should only find the valid metadata entry
+    assert len(result) == 1
+    assert "doc-1" in result
+    assert result["doc-1"]["title"] == "Valid"
+
+
+@pytest.mark.parametrize(
+    "text, strict, expected_ids, description",
+    [
+        # Valid cases
+        (
+            "Metadata: {'document_id': 'a', 'title': 'Doc A'}",
+            True,
+            {"a"},
+            "single valid metadata",
+        ),
+        (
+            "Metadata: {'document_id': 'a', 'title': 'Doc A'}\n"
+            "Metadata: {'document_id': 'b', 'title': 'Doc B'}",
+            True,
+            {"a", "b"},
+            "multiple valid metadata",
+        ),
+        (
+            "METADATA: {'document_id': 'upper', 'title': 'Upper'}\n"
+            "metadata: {'document_id': 'lower', 'title': 'Lower'}",
+            True,
+            {"upper", "lower"},
+            "case-insensitive labels",
+        ),
+        # Error handling - non-strict mode
+        (
+            "Metadata: {'document_id': 'a', 'title': 'Doc A'}\n"
+            "Metadata: {'document_id': 'b' 'oops': 1}",
+            False,
+            {"a"},
+            "malformed metadata skipped in non-strict mode",
+        ),
+        (
+            "Metadata: not_a_dict\nMetadata: {'document_id': 'valid', 'title': 'Valid'}",
+            True,
+            {"valid"},
+            "non-dict content ignored",
+        ),
+        # No metadata cases
+        (
+            "Some text without metadata",
+            True,
+            set(),
+            "no metadata found",
+        ),
+        (
+            "Metadata: {'title': 'No ID'}",
+            True,
+            set(),
+            "metadata without document_id ignored",
+        ),
+    ],
+)
+def test_parse_metadata_parametrized(text, strict, expected_ids, description):
+    """Parametrized test for various metadata parsing scenarios."""
+    # Every case here is expected to succeed: the malformed entry runs with
+    # strict=False, and strict-mode failures are covered by the dedicated
+    # tests above, so no pytest.raises branch is needed.
+    result = parse_knowledge_search_metadata(text, strict=strict)
+    assert set(result.keys()) == expected_ids, f"Failed for: {description}"
+
+
+@pytest.mark.parametrize(
+    "metadata_label, expected_matches",
+    [
+        ("Metadata:", 1),
+        ("METADATA:", 1),
+        ("metadata:", 1),
+        ("MetaData:", 1),
+        ("MetaDaTa:", 1),
+        (" Metadata: ", 1),  # with whitespace
+        ("NotMetadata:", 0),  # should not match
+        ("metadata", 0),  # missing colon
+    ],
+)
+def test_label_pattern_matching(metadata_label, expected_matches):
+    """Test that the label pattern matches various case variations correctly."""
+    sample = f"Some text\n{metadata_label} {{'document_id': 'test'}}\nMore text"
+    matches = list(METADATA_LABEL_PATTERN.finditer(sample))
+    assert len(matches) == expected_matches
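
Taken together, these tests pin down the parser's end-to-end behavior. As a quick usage
check against the sketch given earlier (the payload text reuses the example.com fixtures
above and is illustrative only):

    sample = (
        "Result 1\n"
        "Content: Some retrieved chunk text.\n"
        "Metadata: {'docs_url': 'https://example.com/doc1', "
        "'title': 'Doc1', 'document_id': 'doc-1'}\n"
    )
    docs = parse_knowledge_search_metadata(sample)
    assert docs == {
        "doc-1": {
            "docs_url": "https://example.com/doc1",
            "title": "Doc1",
            "document_id": "doc-1",
        }
    }

    # Malformed payloads raise in strict mode and are skipped otherwise:
    bad = "Metadata: {'document_id': 'x' 'oops': 1}"
    assert parse_knowledge_search_metadata(bad, strict=False) == {}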