From 06ad2d9e6766a8f98704008540f5ac1530ec94f3 Mon Sep 17 00:00:00 2001
From: JR Boos
Date: Thu, 24 Jul 2025 13:08:46 -0400
Subject: [PATCH] Implement attachment parsing in conversation handling

- Added a new function `parse_attachments_from_content` to extract attachment
  information from the content structure.
- Updated the `simplify_session_data` function to include parsed attachments
  in the cleaned messages.
- Modified the `retrieve_response` methods in `query.py` and
  `streaming_query.py` to format and include attachment information in the
  messages sent to the agent.
---
 src/app/endpoints/conversations.py   | 62 +++++++++++++++++++++++++++-
 src/app/endpoints/query.py           | 22 +++++++++-
 src/app/endpoints/streaming_query.py | 29 +++++++++++--
 3 files changed, 107 insertions(+), 6 deletions(-)

diff --git a/src/app/endpoints/conversations.py b/src/app/endpoints/conversations.py
index 90208997..33646ead 100644
--- a/src/app/endpoints/conversations.py
+++ b/src/app/endpoints/conversations.py
@@ -1,6 +1,7 @@
 """Handler for REST API calls to manage conversation history."""
 
 import logging
+import re
 from typing import Any
 
 from llama_stack_client import APIConnectionError, NotFoundError
@@ -20,6 +21,60 @@
 
 conversation_id_to_agent_id: dict[str, str] = {}
 
+
+def parse_attachments_from_content(
+    content: list[dict[str, str]],
+) -> list[dict[str, str]] | None:
+    """Parse attachment information from content items.
+
+    The content structure is:
+    - Index 0: User message
+    - Index 1: <attachments> section with metadata
+    - Index 2+: Actual attachment content (one per attachment line)
+
+    Args:
+        content: The content list of TextContentItems
+
+    Returns:
+        List of attachment dictionaries or None if no attachments found
+    """
+    if len(content) < 2:
+        return None
+
+    attachments_info_text = content[1].get("text", "")
+    attachments_pattern = r"<attachments>\n(.*?)\n</attachments>"
+    attachments_match = re.search(attachments_pattern, attachments_info_text, re.DOTALL)
+
+    if not attachments_match:
+        return None
+
+    attachments_text = attachments_match.group(1)
+    attachments_info: list[dict[str, str]] = []
+
+    attachment_lines = []
+    for line in attachments_text.strip().split("\n"):
+        line = line.strip()
+        if line:
+            attachment_lines.append(line)
+            # Parse: "attachment_type: value, content_type: value"
+            attachment_match = re.match(
+                r"attachment_type:\s*([^,]+),\s*content_type:\s*(.+)", line
+            )
+            if attachment_match:
+                attachment_info = {
+                    "attachment_type": attachment_match.group(1).strip(),
+                    "content_type": attachment_match.group(2).strip(),
+                }
+                attachments_info.append(attachment_info)
+
+    for i, attachment_info in enumerate(attachments_info):
+        content_index = 2 + i
+        if content_index < len(content):
+            attachment_info["content"] = content[content_index].get("text", "")
+
+    return attachments_info if attachments_info else None
+
+
 conversation_responses: dict[int | str, dict[str, Any]] = {
     200: {
         "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
@@ -82,9 +137,14 @@ def simplify_session_data(session_data: Any) -> list[dict[str, Any]]:
 
         # Clean up input messages
         cleaned_messages = []
         for msg in turn.get("input_messages", []):
+            content = msg.get("content", "")
+            # Parse attachments from content (handles both string and list of TextContentItems)
+            attachments = parse_attachments_from_content(content)
+
             cleaned_msg = {
-                "content": msg.get("content"),
+                "content": content[0].get("text"),
                 "type": msg.get("role"),  # Rename role to type
+                "attachments": attachments,
             }
             cleaned_messages.append(cleaned_msg)
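The parser added above is the read-side counterpart of the message formatting introduced in query.py and streaming_query.py below. A minimal sketch of how it behaves on a content list shaped the way those endpoints build it; the import path and the sample values are illustrative assumptions, and the <attachments> wrapper mirrors the regex as reconstructed above:

from app.endpoints.conversations import parse_attachments_from_content

# Content list as the query endpoints assemble it: index 0 is the user
# question, index 1 the <attachments> metadata block, index 2+ the attachment
# bodies. All values below are made up for illustration.
content = [
    {"type": "text", "text": "Why did my pod crash?"},
    {
        "type": "text",
        "text": "<attachments>\nattachment_type: log, content_type: text/plain\n</attachments>",
    },
    {"type": "text", "text": "OOMKilled: container exceeded its memory limit"},
]

print(parse_attachments_from_content(content))
# [{'attachment_type': 'log', 'content_type': 'text/plain',
#   'content': 'OOMKilled: container exceeded its memory limit'}]
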
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index d7629a02..1c209117 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -18,6 +18,7 @@
     Toolgroup,
 )
 from llama_stack_client.types.model_list_response import ModelListResponse
+from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
 
 from fastapi import APIRouter, HTTPException, status, Depends
 
@@ -295,8 +296,27 @@ def retrieve_response(  # pylint: disable=too-many-locals
     toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
         mcp_server.name for mcp_server in configuration.mcp_servers
     ]
+
+    attachment_lines = [
+        f"attachment_type: {attachment.attachment_type}, "
+        f"content_type: {attachment.content_type}"
+        for attachment in (query_request.attachments or [])
+    ]
+
     response = agent.create_turn(
-        messages=[UserMessage(role="user", content=query_request.query)],
+        messages=[
+            UserMessage(
+                role="user",
+                content=[
+                    TextContentItem(type="text", text=query_request.query),
+                    TextContentItem(
+                        type="text",
+                        text=f"<attachments>\n{'\n'.join(attachment_lines)}\n"
+                        f"</attachments>",
+                    ),
+                ],
+            )
+        ],
         session_id=conversation_id,
         documents=query_request.get_documents(),
         stream=False,
diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py
index d322738d..919490c6 100644
--- a/src/app/endpoints/streaming_query.py
+++ b/src/app/endpoints/streaming_query.py
@@ -517,8 +517,9 @@ async def retrieve_response(
         available_output_shields,
     )
     # use system prompt from request or default one
-    system_prompt = get_system_prompt(query_request, configuration)
-    logger.debug("Using system prompt: %s", system_prompt)
+    logger.debug(
+        "Using system prompt: %s", get_system_prompt(query_request, configuration)
+    )
 
     # TODO(lucasagomes): redact attachments content before sending to LLM
     # if attachments are provided, validate them
@@ -528,7 +529,7 @@
     agent, conversation_id = await get_agent(
         client,
         model_id,
-        system_prompt,
+        get_system_prompt(query_request, configuration),
         available_input_shields,
         available_output_shields,
         query_request.conversation_id,
     )
@@ -561,8 +562,28 @@
     toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
         mcp_server.name for mcp_server in configuration.mcp_servers
     ]
+
+    # Generate attachment info lines
+    attachment_lines = [
+        f"attachment_type: {attachment.attachment_type}, "
+        f"content_type: {attachment.content_type}"
+        for attachment in (query_request.attachments or [])
+    ]
+
     response = await agent.create_turn(
-        messages=[UserMessage(role="user", content=query_request.query)],
+        messages=[
+            UserMessage(
+                role="user",
+                content=[
+                    TextContentItem(type="text", text=query_request.query),
+                    TextContentItem(
+                        type="text",
+                        text=f"<attachments>\n{'\n'.join(attachment_lines)}\n"
+                        f"</attachments>",
+                    ),
+                ],
+            )
+        ],
         session_id=conversation_id,
         documents=query_request.get_documents(),
         stream=True,
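
For reference, the second TextContentItem that retrieve_response now sends wraps one metadata line per attachment in an <attachments> block, which is the shape the parser in conversations.py expects. A minimal sketch of that formatting, using plain tuples in place of query_request.attachments; the tuple values and the wrapper text are illustrative assumptions rather than part of the patch:

# Stand-ins for query_request.attachments; each entry mimics an object with
# attachment_type and content_type fields.
attachments = [("log", "text/plain"), ("configuration", "application/yaml")]

attachment_lines = [
    f"attachment_type: {attachment_type}, content_type: {content_type}"
    for attachment_type, content_type in attachments
]

# The metadata block sent to the agent alongside the user query.
attachments_block = "<attachments>\n" + "\n".join(attachment_lines) + "\n</attachments>"
print(attachments_block)
# <attachments>
# attachment_type: log, content_type: text/plain
# attachment_type: configuration, content_type: application/yaml
# </attachments>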