From e1d5c401dbfbc9ec044c2fcb4e65eb7ce610c456 Mon Sep 17 00:00:00 2001 From: Maor Friedman Date: Wed, 30 Jul 2025 13:23:56 +0300 Subject: [PATCH 01/19] refactor conversation to agent and conversation mapping --- src/app/endpoints/conversations.py | 25 ++----------- src/app/endpoints/query.py | 24 ++++--------- src/app/endpoints/streaming_query.py | 26 +++++--------- .../unit/app/endpoints/test_conversations.py | 35 ------------------- tests/unit/app/endpoints/test_query.py | 10 ------ .../app/endpoints/test_streaming_query.py | 10 ------ 6 files changed, 18 insertions(+), 112 deletions(-) diff --git a/src/app/endpoints/conversations.py b/src/app/endpoints/conversations.py index 0e761c60..7eff9bb5 100644 --- a/src/app/endpoints/conversations.py +++ b/src/app/endpoints/conversations.py @@ -18,8 +18,6 @@ router = APIRouter(tags=["conversations"]) auth_dependency = get_auth_dependency() -conversation_id_to_agent_id: dict[str, str] = {} - conversation_responses: dict[int | str, dict[str, Any]] = { 200: { "conversation_id": "123e4567-e89b-12d3-a456-426614174000", @@ -131,17 +129,7 @@ def get_conversation_endpoint_handler( }, ) - agent_id = conversation_id_to_agent_id.get(conversation_id) - if not agent_id: - logger.error("Agent ID not found for conversation %s", conversation_id) - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail={ - "response": "conversation ID not found", - "cause": f"conversation ID {conversation_id} not found!", - }, - ) - + agent_id = conversation_id logger.info("Retrieving conversation %s", conversation_id) try: @@ -211,16 +199,7 @@ def delete_conversation_endpoint_handler( "cause": f"Conversation ID {conversation_id} is not a valid UUID", }, ) - agent_id = conversation_id_to_agent_id.get(conversation_id) - if not agent_id: - logger.error("Agent ID not found for conversation %s", conversation_id) - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail={ - "response": "conversation ID not found", - "cause": f"conversation ID {conversation_id} not found!", - }, - ) + agent_id = conversation_id logger.info("Deleting conversation %s", conversation_id) try: diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 8701aa6b..4f1f1249 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -1,6 +1,7 @@ """Handler for REST API call to provide answer to query.""" from datetime import datetime, UTC +from functools import lru_cache import json import logging import os @@ -23,7 +24,6 @@ from client import LlamaStackClientHolder from configuration import configuration -from app.endpoints.conversations import conversation_id_to_agent_id import metrics from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse from models.requests import QueryRequest, Attachment @@ -39,9 +39,6 @@ router = APIRouter(tags=["query"]) auth_dependency = get_auth_dependency() -# Global agent registry to persist agents across requests -_agent_cache: TTLCache[str, Agent] = TTLCache(maxsize=1000, ttl=3600) - query_response: dict[int | str, dict[str, Any]] = { 200: { "conversation_id": "123e4567-e89b-12d3-a456-426614174000", @@ -73,6 +70,7 @@ def is_transcripts_enabled() -> bool: return configuration.user_data_collection_configuration.transcripts_enabled +@lru_cache def get_agent( # pylint: disable=too-many-arguments,too-many-positional-arguments client: LlamaStackClient, model_id: str, @@ -83,15 +81,6 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen no_tools: bool = False, ) -> tuple[Agent, str]: """Get existing agent or create a new one with session persistence.""" - if conversation_id is not None: - agent = _agent_cache.get(conversation_id) - if agent: - logger.debug( - "Reusing existing agent with conversation_id: %s", conversation_id - ) - return agent, conversation_id - logger.debug("No existing agent found for conversation_id: %s", conversation_id) - logger.debug("Creating new agent") # TODO(lucasagomes): move to ReActAgent agent = Agent( @@ -103,10 +92,11 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id), enable_session_persistence=True, ) - conversation_id = agent.create_session(get_suid()) - logger.debug("Created new agent and conversation_id: %s", conversation_id) - _agent_cache[conversation_id] = agent - conversation_id_to_agent_id[conversation_id] = agent.agent_id + if conversation_id: + agent.agent_id = conversation_id + else: + agent.create_session(get_suid()) + conversation_id = agent.agent_id return agent, conversation_id diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index d27388e2..619831d8 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -1,6 +1,7 @@ """Handler for REST API call to provide answer to streaming query.""" import ast +from functools import lru_cache import json import re import logging @@ -31,7 +32,6 @@ from utils.suid import get_suid from utils.types import GraniteToolParser -from app.endpoints.conversations import conversation_id_to_agent_id from app.endpoints.query import ( get_rag_toolgroups, is_input_shield, @@ -46,11 +46,9 @@ router = APIRouter(tags=["streaming_query"]) auth_dependency = get_auth_dependency() -# Global agent registry to persist agents across requests -_agent_cache: TTLCache[str, AsyncAgent] = TTLCache(maxsize=1000, ttl=3600) - # # pylint: disable=R0913,R0917 +@lru_cache async def get_agent( client: AsyncLlamaStackClient, model_id: str, @@ -61,15 +59,6 @@ async def get_agent( no_tools: bool = False, ) -> tuple[AsyncAgent, str]: """Get existing agent or create a new one with session persistence.""" - if conversation_id is not None: - agent = _agent_cache.get(conversation_id) - if agent: - logger.debug( - "Reusing existing agent with conversation_id: %s", conversation_id - ) - return agent, conversation_id - logger.debug("No existing agent found for conversation_id: %s", conversation_id) - logger.debug("Creating new agent") agent = AsyncAgent( client, # type: ignore[arg-type] @@ -80,10 +69,13 @@ async def get_agent( tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id), enable_session_persistence=True, ) - conversation_id = await agent.create_session(get_suid()) - logger.debug("Created new agent and conversation_id: %s", conversation_id) - _agent_cache[conversation_id] = agent - conversation_id_to_agent_id[conversation_id] = agent.agent_id + + if conversation_id: + agent._agent_id = conversation_id + else: + conversation_id = agent.agent_id + await agent.create_session(get_suid()) + return agent, conversation_id diff --git a/tests/unit/app/endpoints/test_conversations.py b/tests/unit/app/endpoints/test_conversations.py index 0f5456b7..0f7e33f2 100644 --- a/tests/unit/app/endpoints/test_conversations.py +++ b/tests/unit/app/endpoints/test_conversations.py @@ -7,7 +7,6 @@ from app.endpoints.conversations import ( get_conversation_endpoint_handler, delete_conversation_endpoint_handler, - conversation_id_to_agent_id, simplify_session_data, ) from models.responses import ConversationResponse, ConversationDeleteResponse @@ -48,16 +47,6 @@ def setup_configuration_fixture(): return cfg -@pytest.fixture(autouse=True) -def setup_conversation_mapping(): - """Set up and clean up the conversation ID to agent ID mapping.""" - # Clear the mapping before each test - conversation_id_to_agent_id.clear() - yield - # Clean up after each test - conversation_id_to_agent_id.clear() - - @pytest.fixture(name="mock_session_data") def mock_session_data_fixture(): """Create mock session data for testing.""" @@ -243,9 +232,6 @@ def test_llama_stack_connection_error(self, mocker, setup_configuration): mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Set up conversation mapping - conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID - # Mock LlamaStackClientHolder to raise APIConnectionError mock_client = mocker.Mock() mock_client.agents.session.retrieve.side_effect = APIConnectionError( @@ -268,9 +254,6 @@ def test_llama_stack_not_found_error(self, mocker, setup_configuration): mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Set up conversation mapping - conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID - # Mock LlamaStackClientHolder to raise NotFoundError mock_client = mocker.Mock() mock_client.agents.session.retrieve.side_effect = NotFoundError( @@ -294,9 +277,6 @@ def test_session_retrieve_exception(self, mocker, setup_configuration): mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Set up conversation mapping - conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID - # Mock LlamaStackClientHolder to raise a general exception mock_client = mocker.Mock() mock_client.agents.session.retrieve.side_effect = Exception( @@ -323,9 +303,6 @@ def test_successful_conversation_retrieval( mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Set up conversation mapping - conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID - # Mock session data with model_dump method mock_session_obj = mocker.Mock() mock_session_obj.model_dump.return_value = mock_session_data @@ -394,9 +371,6 @@ def test_llama_stack_connection_error(self, mocker, setup_configuration): mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Set up conversation mapping - conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID - # Mock LlamaStackClientHolder to raise APIConnectionError mock_client = mocker.Mock() mock_client.agents.session.delete.side_effect = APIConnectionError(request=None) @@ -416,9 +390,6 @@ def test_llama_stack_not_found_error(self, mocker, setup_configuration): mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Set up conversation mapping - conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID - # Mock LlamaStackClientHolder to raise NotFoundError mock_client = mocker.Mock() mock_client.agents.session.delete.side_effect = NotFoundError( @@ -442,9 +413,6 @@ def test_session_deletion_exception(self, mocker, setup_configuration): mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Set up conversation mapping - conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID - # Mock LlamaStackClientHolder to raise a general exception mock_client = mocker.Mock() mock_client.agents.session.delete.side_effect = Exception( @@ -470,9 +438,6 @@ def test_successful_conversation_deletion(self, mocker, setup_configuration): mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Set up conversation mapping - conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID - # Mock LlamaStackClientHolder mock_client = mocker.Mock() mock_client.agents.session.delete.return_value = None # Successful deletion diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 03e32563..7d309f74 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -20,7 +20,6 @@ store_transcript, get_rag_toolgroups, get_agent, - _agent_cache, ) from models.requests import QueryRequest, Attachment @@ -65,8 +64,6 @@ def prepare_agent_mocks_fixture(mocker): mock_agent = mocker.Mock() mock_agent.create_turn.return_value.steps = [] yield mock_client, mock_agent - # cleanup agent cache after tests - _agent_cache.clear() def test_query_endpoint_handler_configuration_not_loaded(mocker): @@ -1079,7 +1076,6 @@ def test_get_agent_cache_hit(prepare_agent_mocks): # Set up cache with existing agent conversation_id = "test_conversation_id" - _agent_cache[conversation_id] = mock_agent result_agent, result_conversation_id = get_agent( client=mock_client, @@ -1146,9 +1142,6 @@ def test_get_agent_cache_miss_with_conversation_id( enable_session_persistence=True, ) - # Verify agent was stored in cache - assert _agent_cache["new_session_id"] == mock_agent - def test_get_agent_no_conversation_id(setup_configuration, prepare_agent_mocks, mocker): """Test get_agent function when conversation_id is None.""" @@ -1199,9 +1192,6 @@ def test_get_agent_no_conversation_id(setup_configuration, prepare_agent_mocks, enable_session_persistence=True, ) - # Verify agent was stored in cache - assert _agent_cache["new_session_id"] == mock_agent - def test_get_agent_empty_shields(setup_configuration, prepare_agent_mocks, mocker): """Test get_agent function with empty shields list.""" diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 8ff286ad..368d4990 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -41,7 +41,6 @@ retrieve_response, stream_build_event, get_agent, - _agent_cache, ) from models.requests import QueryRequest, Attachment @@ -113,8 +112,6 @@ def prepare_agent_mocks_fixture(mocker): mock_client = mocker.AsyncMock() mock_agent = mocker.AsyncMock() yield mock_client, mock_agent - # cleanup agent cache after tests - _agent_cache.clear() @pytest.mark.asyncio @@ -1221,7 +1218,6 @@ async def test_get_agent_cache_hit(prepare_agent_mocks): # Set up cache with existing agent conversation_id = "test_conversation_id" - _agent_cache[conversation_id] = mock_agent result_agent, result_conversation_id = await get_agent( client=mock_client, @@ -1292,9 +1288,6 @@ async def test_get_agent_cache_miss_with_conversation_id( enable_session_persistence=True, ) - # Verify agent was stored in cache - assert _agent_cache["new_session_id"] == mock_agent - @pytest.mark.asyncio async def test_get_agent_no_conversation_id( @@ -1351,9 +1344,6 @@ async def test_get_agent_no_conversation_id( enable_session_persistence=True, ) - # Verify agent was stored in cache - assert _agent_cache["new_session_id"] == mock_agent - @pytest.mark.asyncio async def test_get_agent_empty_shields( From 20ead26d8d67d2785354a421cb5593a7d7f8346f Mon Sep 17 00:00:00 2001 From: Maor Friedman Date: Wed, 30 Jul 2025 15:01:01 +0300 Subject: [PATCH 02/19] delete orphan agents --- src/app/endpoints/query.py | 4 +++- src/app/endpoints/streaming_query.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 4f1f1249..3759e773 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -93,10 +93,12 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen enable_session_persistence=True, ) if conversation_id: + orphan_agent_id = agent.agent_id agent.agent_id = conversation_id + client.agents.delete(agent_id=orphan_agent_id) else: - agent.create_session(get_suid()) conversation_id = agent.agent_id + agent.create_session(get_suid()) return agent, conversation_id diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 619831d8..e927d63a 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -71,7 +71,9 @@ async def get_agent( ) if conversation_id: + orphan_agent_id = agent.agent_id agent._agent_id = conversation_id + await client.agents.delete(agent_id=orphan_agent_id) else: conversation_id = agent.agent_id await agent.create_session(get_suid()) From 26e2de53cb28fa03f481dbfbae775aa23a41f85e Mon Sep 17 00:00:00 2001 From: Maor Friedman Date: Wed, 30 Jul 2025 15:42:59 +0300 Subject: [PATCH 03/19] check if agent exists --- src/app/endpoints/query.py | 7 ++++++- src/app/endpoints/streaming_query.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 3759e773..a9459406 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -81,6 +81,11 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen no_tools: bool = False, ) -> tuple[Agent, str]: """Get existing agent or create a new one with session persistence.""" + existing_agent_id = None + if conversation_id: + agent_reponse = client.agents.retrieve(agent_id=conversation_id) + existing_agent_id = agent_reponse.agent_id + logger.debug("Creating new agent") # TODO(lucasagomes): move to ReActAgent agent = Agent( @@ -92,7 +97,7 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id), enable_session_persistence=True, ) - if conversation_id: + if existing_agent_id and conversation_id: orphan_agent_id = agent.agent_id agent.agent_id = conversation_id client.agents.delete(agent_id=orphan_agent_id) diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index e927d63a..bde1986b 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -59,6 +59,11 @@ async def get_agent( no_tools: bool = False, ) -> tuple[AsyncAgent, str]: """Get existing agent or create a new one with session persistence.""" + existing_agent_id = None + if conversation_id: + agent_reponse = await client.agents.retrieve(agent_id=conversation_id) + existing_agent_id = agent_reponse.agent_id + logger.debug("Creating new agent") agent = AsyncAgent( client, # type: ignore[arg-type] @@ -70,7 +75,7 @@ async def get_agent( enable_session_persistence=True, ) - if conversation_id: + if existing_agent_id and conversation_id: orphan_agent_id = agent.agent_id agent._agent_id = conversation_id await client.agents.delete(agent_id=orphan_agent_id) From a11fcf4c1ca0d455f2bf9072823a0e311db129fb Mon Sep 17 00:00:00 2001 From: Maor Friedman Date: Wed, 30 Jul 2025 15:46:58 +0300 Subject: [PATCH 04/19] remove cache for now --- src/app/endpoints/query.py | 2 -- src/app/endpoints/streaming_query.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index a9459406..0886ee0d 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -1,7 +1,6 @@ """Handler for REST API call to provide answer to query.""" from datetime import datetime, UTC -from functools import lru_cache import json import logging import os @@ -70,7 +69,6 @@ def is_transcripts_enabled() -> bool: return configuration.user_data_collection_configuration.transcripts_enabled -@lru_cache def get_agent( # pylint: disable=too-many-arguments,too-many-positional-arguments client: LlamaStackClient, model_id: str, diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index bde1986b..722b8eb7 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -1,7 +1,6 @@ """Handler for REST API call to provide answer to streaming query.""" import ast -from functools import lru_cache import json import re import logging @@ -48,7 +47,6 @@ # # pylint: disable=R0913,R0917 -@lru_cache async def get_agent( client: AsyncLlamaStackClient, model_id: str, From 89e8678eaf38de4c1d1da225bb33a4ddcd27f3d2 Mon Sep 17 00:00:00 2001 From: Maor Friedman Date: Wed, 30 Jul 2025 15:51:10 +0300 Subject: [PATCH 05/19] handle value error from API to implement "not found" --- src/app/endpoints/query.py | 5 +++-- src/app/endpoints/streaming_query.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 0886ee0d..73ca053f 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -1,5 +1,6 @@ """Handler for REST API call to provide answer to query.""" +from contextlib import suppress from datetime import datetime, UTC import json import logging @@ -81,8 +82,8 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen """Get existing agent or create a new one with session persistence.""" existing_agent_id = None if conversation_id: - agent_reponse = client.agents.retrieve(agent_id=conversation_id) - existing_agent_id = agent_reponse.agent_id + with suppress(ValueError): + existing_agent_id = client.agents.retrieve(agent_id=conversation_id).agent_id logger.debug("Creating new agent") # TODO(lucasagomes): move to ReActAgent diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 722b8eb7..cfa510fd 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -1,6 +1,7 @@ """Handler for REST API call to provide answer to streaming query.""" import ast +from contextlib import suppress import json import re import logging @@ -59,8 +60,8 @@ async def get_agent( """Get existing agent or create a new one with session persistence.""" existing_agent_id = None if conversation_id: - agent_reponse = await client.agents.retrieve(agent_id=conversation_id) - existing_agent_id = agent_reponse.agent_id + with suppress(ValueError): + existing_agent_id = (await client.agents.retrieve(agent_id=conversation_id)).agent_id logger.debug("Creating new agent") agent = AsyncAgent( From 3c368c07d766b10c4eb93a96f6a734553387c2e4 Mon Sep 17 00:00:00 2001 From: Maor Friedman Date: Wed, 30 Jul 2025 17:18:54 +0300 Subject: [PATCH 06/19] internal session work --- src/app/endpoints/conversations.py | 9 +++------ src/app/endpoints/query.py | 2 +- src/app/endpoints/streaming_query.py | 2 +- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/app/endpoints/conversations.py b/src/app/endpoints/conversations.py index 7eff9bb5..38cfd25c 100644 --- a/src/app/endpoints/conversations.py +++ b/src/app/endpoints/conversations.py @@ -67,16 +67,15 @@ } -def simplify_session_data(session_data: Any) -> list[dict[str, Any]]: +def simplify_session_data(session_dict: dict) -> list[dict[str, Any]]: """Simplify session data to include only essential conversation information. Args: - session_data: The full session data from llama-stack + session_dict: The full session data dict from llama-stack Returns: Simplified session data with only input_messages and output_message per turn """ - session_dict = session_data.model_dump() # Create simplified structure chat_history = [] @@ -135,9 +134,7 @@ def get_conversation_endpoint_handler( try: client = LlamaStackClientHolder().get_client() - session_data = client.agents.session.retrieve( - agent_id=agent_id, session_id=conversation_id - ) + session_data = client.agents.session.list(agent_id=agent_id).data[0] logger.info("Successfully retrieved conversation %s", conversation_id) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 73ca053f..b108021c 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -322,7 +322,7 @@ def retrieve_response( # pylint: disable=too-many-locals response = agent.create_turn( messages=[UserMessage(role="user", content=query_request.query)], - session_id=conversation_id, + session_id=agent.session_id, documents=query_request.get_documents(), stream=False, toolgroups=toolgroups, diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index cfa510fd..0b616300 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -574,7 +574,7 @@ async def retrieve_response( logger.debug("Session ID: %s", conversation_id) response = await agent.create_turn( messages=[UserMessage(role="user", content=query_request.query)], - session_id=conversation_id, + session_id=agent.session_id, documents=query_request.get_documents(), stream=True, toolgroups=toolgroups, From f360f038064536677bbc36d5ec257296d4b5c5c8 Mon Sep 17 00:00:00 2001 From: Maor Friedman Date: Wed, 30 Jul 2025 17:24:36 +0300 Subject: [PATCH 07/19] sessions array --- src/app/endpoints/query.py | 2 +- src/app/endpoints/streaming_query.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index b108021c..94e6ffeb 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -322,7 +322,7 @@ def retrieve_response( # pylint: disable=too-many-locals response = agent.create_turn( messages=[UserMessage(role="user", content=query_request.query)], - session_id=agent.session_id, + session_id=agent.sessions[0], documents=query_request.get_documents(), stream=False, toolgroups=toolgroups, diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 0b616300..371705c7 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -574,7 +574,7 @@ async def retrieve_response( logger.debug("Session ID: %s", conversation_id) response = await agent.create_turn( messages=[UserMessage(role="user", content=query_request.query)], - session_id=agent.session_id, + session_id=agent.sessions[0], documents=query_request.get_documents(), stream=True, toolgroups=toolgroups, From f4728144efde16da424aabeecf92e95a1e60f141 Mon Sep 17 00:00:00 2001 From: Maor Friedman Date: Wed, 30 Jul 2025 17:32:33 +0300 Subject: [PATCH 08/19] do not pass session id (internal handling exists) --- src/app/endpoints/query.py | 1 - src/app/endpoints/streaming_query.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 94e6ffeb..7f7e7b6a 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -322,7 +322,6 @@ def retrieve_response( # pylint: disable=too-many-locals response = agent.create_turn( messages=[UserMessage(role="user", content=query_request.query)], - session_id=agent.sessions[0], documents=query_request.get_documents(), stream=False, toolgroups=toolgroups, diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 371705c7..50918929 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -574,7 +574,6 @@ async def retrieve_response( logger.debug("Session ID: %s", conversation_id) response = await agent.create_turn( messages=[UserMessage(role="user", content=query_request.query)], - session_id=agent.sessions[0], documents=query_request.get_documents(), stream=True, toolgroups=toolgroups, From 3318bcdd778c887fd57a54fda377581b821ef885 Mon Sep 17 00:00:00 2001 From: Maor Friedman Date: Wed, 30 Jul 2025 23:05:26 +0300 Subject: [PATCH 09/19] pass session_id --- src/app/endpoints/query.py | 10 ++++++---- src/app/endpoints/streaming_query.py | 10 ++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 7f7e7b6a..1153f995 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -78,7 +78,7 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen available_output_shields: list[str], conversation_id: str | None, no_tools: bool = False, -) -> tuple[Agent, str]: +) -> tuple[Agent, str, str]: """Get existing agent or create a new one with session persistence.""" existing_agent_id = None if conversation_id: @@ -99,12 +99,13 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen if existing_agent_id and conversation_id: orphan_agent_id = agent.agent_id agent.agent_id = conversation_id + session_id = agent.session_id client.agents.delete(agent_id=orphan_agent_id) else: conversation_id = agent.agent_id - agent.create_session(get_suid()) + session_id = agent.create_session(get_suid()) - return agent, conversation_id + return agent, conversation_id, session_id @router.post("/query", responses=query_response) @@ -278,7 +279,7 @@ def retrieve_response( # pylint: disable=too-many-locals if query_request.attachments: validate_attachments_metadata(query_request.attachments) - agent, conversation_id = get_agent( + agent, conversation_id, session_id = get_agent( client, model_id, system_prompt, @@ -322,6 +323,7 @@ def retrieve_response( # pylint: disable=too-many-locals response = agent.create_turn( messages=[UserMessage(role="user", content=query_request.query)], + session_id=session_id, documents=query_request.get_documents(), stream=False, toolgroups=toolgroups, diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 50918929..d10f4685 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -56,7 +56,7 @@ async def get_agent( available_output_shields: list[str], conversation_id: str | None, no_tools: bool = False, -) -> tuple[AsyncAgent, str]: +) -> tuple[AsyncAgent, str, str]: """Get existing agent or create a new one with session persistence.""" existing_agent_id = None if conversation_id: @@ -77,12 +77,13 @@ async def get_agent( if existing_agent_id and conversation_id: orphan_agent_id = agent.agent_id agent._agent_id = conversation_id + session_id = agent.session_id await client.agents.delete(agent_id=orphan_agent_id) else: conversation_id = agent.agent_id - await agent.create_session(get_suid()) + session_id = await agent.create_session(get_suid()) - return agent, conversation_id + return agent, conversation_id, session_id METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n") @@ -524,7 +525,7 @@ async def retrieve_response( if query_request.attachments: validate_attachments_metadata(query_request.attachments) - agent, conversation_id = await get_agent( + agent, conversation_id, session_id = await get_agent( client, model_id, system_prompt, @@ -574,6 +575,7 @@ async def retrieve_response( logger.debug("Session ID: %s", conversation_id) response = await agent.create_turn( messages=[UserMessage(role="user", content=query_request.query)], + session_id=session_id, documents=query_request.get_documents(), stream=True, toolgroups=toolgroups, From abb138443a2ca4fe9dcc099aebf188422b582a9d Mon Sep 17 00:00:00 2001 From: Maor Friedman Date: Wed, 30 Jul 2025 23:12:08 +0300 Subject: [PATCH 10/19] use agent sessions simply --- src/app/endpoints/query.py | 2 +- src/app/endpoints/streaming_query.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 1153f995..dadae9df 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -99,7 +99,7 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen if existing_agent_id and conversation_id: orphan_agent_id = agent.agent_id agent.agent_id = conversation_id - session_id = agent.session_id + session_id = agent.sessions[0] client.agents.delete(agent_id=orphan_agent_id) else: conversation_id = agent.agent_id diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index d10f4685..c946aedb 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -77,7 +77,7 @@ async def get_agent( if existing_agent_id and conversation_id: orphan_agent_id = agent.agent_id agent._agent_id = conversation_id - session_id = agent.session_id + session_id = agent.sessions[0] await client.agents.delete(agent_id=orphan_agent_id) else: conversation_id = agent.agent_id From 92cf14ee02662bb546c89be13ad685e291f78390 Mon Sep 17 00:00:00 2001 From: Maor Friedman Date: Wed, 30 Jul 2025 23:34:31 +0300 Subject: [PATCH 11/19] test session id list --- src/app/endpoints/query.py | 3 +++ src/app/endpoints/streaming_query.py | 7 +++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index dadae9df..33cd95a0 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -101,6 +101,9 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen agent.agent_id = conversation_id session_id = agent.sessions[0] client.agents.delete(agent_id=orphan_agent_id) + sessions_response = client.agents.session.list(agent_id=conversation_id) + logger.info(f"session response: {sessions_response}") + session_id = str(sessions_response.data[0]["id"]) else: conversation_id = agent.agent_id session_id = agent.create_session(get_suid()) diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index c946aedb..b06280b9 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -61,7 +61,8 @@ async def get_agent( existing_agent_id = None if conversation_id: with suppress(ValueError): - existing_agent_id = (await client.agents.retrieve(agent_id=conversation_id)).agent_id + agent_response = await client.agents.retrieve(agent_id=conversation_id) + existing_agent_id = agent_response.agent_id logger.debug("Creating new agent") agent = AsyncAgent( @@ -77,8 +78,10 @@ async def get_agent( if existing_agent_id and conversation_id: orphan_agent_id = agent.agent_id agent._agent_id = conversation_id - session_id = agent.sessions[0] await client.agents.delete(agent_id=orphan_agent_id) + sessions_response = await client.agents.session.list(agent_id=conversation_id) + logger.info(f"session response: {sessions_response}") + session_id = str(sessions_response.data[0]["id"]) else: conversation_id = agent.agent_id session_id = await agent.create_session(get_suid()) From 50b652aa77d62f7c289bf28af03a5b585d97c237 Mon Sep 17 00:00:00 2001 From: Maor Friedman Date: Wed, 30 Jul 2025 23:45:42 +0300 Subject: [PATCH 12/19] session id updates --- src/app/endpoints/query.py | 3 +-- src/app/endpoints/streaming_query.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 33cd95a0..5fc04706 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -99,11 +99,10 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen if existing_agent_id and conversation_id: orphan_agent_id = agent.agent_id agent.agent_id = conversation_id - session_id = agent.sessions[0] client.agents.delete(agent_id=orphan_agent_id) sessions_response = client.agents.session.list(agent_id=conversation_id) logger.info(f"session response: {sessions_response}") - session_id = str(sessions_response.data[0]["id"]) + session_id = str(sessions_response.data[0]["session_id"]) else: conversation_id = agent.agent_id session_id = agent.create_session(get_suid()) diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index b06280b9..c5ef28e0 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -81,7 +81,7 @@ async def get_agent( await client.agents.delete(agent_id=orphan_agent_id) sessions_response = await client.agents.session.list(agent_id=conversation_id) logger.info(f"session response: {sessions_response}") - session_id = str(sessions_response.data[0]["id"]) + session_id = str(sessions_response.data[0]["session_id"]) else: conversation_id = agent.agent_id session_id = await agent.create_session(get_suid()) From 20b1df58fec0cc416f1e7f88e325fd1dc8ef5947 Mon Sep 17 00:00:00 2001 From: Maor Friedman Date: Mon, 4 Aug 2025 17:17:53 +0300 Subject: [PATCH 13/19] fix tests/lint/mypy (5 failures remain) --- src/app/endpoints/query.py | 8 +- src/app/endpoints/streaming_query.py | 6 +- .../unit/app/endpoints/test_conversations.py | 21 +--- tests/unit/app/endpoints/test_query.py | 113 +++++++++++------- .../app/endpoints/test_streaming_query.py | 92 ++++++++------ 5 files changed, 137 insertions(+), 103 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 5fc04706..774907d5 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -8,8 +8,6 @@ from pathlib import Path from typing import Any -from cachetools import TTLCache # type: ignore - from llama_stack_client.lib.agents.agent import Agent from llama_stack_client import APIConnectionError from llama_stack_client import LlamaStackClient # type: ignore @@ -83,7 +81,9 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen existing_agent_id = None if conversation_id: with suppress(ValueError): - existing_agent_id = client.agents.retrieve(agent_id=conversation_id).agent_id + existing_agent_id = client.agents.retrieve( + agent_id=conversation_id + ).agent_id logger.debug("Creating new agent") # TODO(lucasagomes): move to ReActAgent @@ -101,7 +101,7 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen agent.agent_id = conversation_id client.agents.delete(agent_id=orphan_agent_id) sessions_response = client.agents.session.list(agent_id=conversation_id) - logger.info(f"session response: {sessions_response}") + logger.info("session response: %s", sessions_response) session_id = str(sessions_response.data[0]["session_id"]) else: conversation_id = agent.agent_id diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index c5ef28e0..e4663327 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -7,8 +7,6 @@ import logging from typing import Any, AsyncIterator, Iterator -from cachetools import TTLCache # type: ignore - from llama_stack_client import APIConnectionError from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore from llama_stack_client import AsyncLlamaStackClient # type: ignore @@ -77,10 +75,10 @@ async def get_agent( if existing_agent_id and conversation_id: orphan_agent_id = agent.agent_id - agent._agent_id = conversation_id + agent._agent_id = conversation_id # type: ignore[assignment] # pylint: disable=protected-access await client.agents.delete(agent_id=orphan_agent_id) sessions_response = await client.agents.session.list(agent_id=conversation_id) - logger.info(f"session response: {sessions_response}") + logger.info("session response: %s", sessions_response) session_id = str(sessions_response.data[0]["session_id"]) else: conversation_id = agent.agent_id diff --git a/tests/unit/app/endpoints/test_conversations.py b/tests/unit/app/endpoints/test_conversations.py index 0f7e33f2..9fb91311 100644 --- a/tests/unit/app/endpoints/test_conversations.py +++ b/tests/unit/app/endpoints/test_conversations.py @@ -116,17 +116,12 @@ class TestSimplifySessionData: """Test cases for the simplify_session_data function.""" def test_simplify_session_data_with_model_dump( - self, mock_session_data, expected_chat_history, mocker + self, mock_session_data, expected_chat_history ): - """Test simplify_session_data with session data that has model_dump method.""" - # Create a mock object with model_dump method - mock_session_obj = mocker.Mock() - mock_session_obj.model_dump.return_value = mock_session_data - - result = simplify_session_data(mock_session_obj) + """Test simplify_session_data with session data.""" + result = simplify_session_data(mock_session_data) assert result == expected_chat_history - mock_session_obj.model_dump.assert_called_once() def test_simplify_session_data_empty_turns(self, mocker): """Test simplify_session_data with empty turns.""" @@ -136,10 +131,7 @@ def test_simplify_session_data_empty_turns(self, mocker): "turns": [], } - mock_session_obj = mocker.Mock() - mock_session_obj.model_dump.return_value = session_data - - result = simplify_session_data(mock_session_obj) + result = simplify_session_data(session_data) assert not result @@ -172,9 +164,8 @@ def test_simplify_session_data_filters_unwanted_fields(self, mocker): } mock_session_obj = mocker.Mock() - mock_session_obj.model_dump.return_value = session_data - result = simplify_session_data(mock_session_obj) + result = simplify_session_data(session_data) expected = [ { @@ -455,5 +446,5 @@ def test_successful_conversation_deletion(self, mocker, setup_configuration): assert response.success is True assert response.response == "Conversation deleted successfully" mock_client.agents.session.delete.assert_called_once_with( - agent_id=VALID_AGENT_ID, session_id=VALID_CONVERSATION_ID + agent_id=VALID_CONVERSATION_ID, session_id=VALID_CONVERSATION_ID ) diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 7d309f74..2e5ba07b 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -377,7 +377,8 @@ def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker): mock_config.mcp_servers = [] mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -391,7 +392,7 @@ def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker): # Assert that the metric for validation errors is NOT incremented mock_metric.inc.assert_not_called() assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -413,7 +414,8 @@ def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocker): mock_config.mcp_servers = [] mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -425,7 +427,7 @@ def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocker): ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -460,7 +462,8 @@ def __repr__(self): mock_config.mcp_servers = [] mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -472,7 +475,7 @@ def __repr__(self): ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -510,7 +513,8 @@ def __repr__(self): mock_config.mcp_servers = [] mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -522,7 +526,7 @@ def __repr__(self): ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -562,7 +566,8 @@ def __repr__(self): mock_config.mcp_servers = [] mocker.patch("app.endpoints.query.configuration", mock_config) mock_get_agent = mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -574,7 +579,7 @@ def __repr__(self): ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" # Verify get_agent was called with the correct parameters mock_get_agent.assert_called_once_with( @@ -616,7 +621,8 @@ def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker): ), ] mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?", attachments=attachments) @@ -628,7 +634,7 @@ def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker): ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -668,7 +674,8 @@ def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocker): ), ] mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?", attachments=attachments) @@ -680,7 +687,7 @@ def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocker): ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -721,7 +728,8 @@ def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): mock_config.mcp_servers = mcp_servers mocker.patch("app.endpoints.query.configuration", mock_config) mock_get_agent = mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -733,7 +741,7 @@ def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" # Verify get_agent was called with the correct parameters mock_get_agent.assert_called_once_with( @@ -786,7 +794,8 @@ def test_retrieve_response_with_mcp_servers_empty_token(prepare_agent_mocks, moc mock_config.mcp_servers = mcp_servers mocker.patch("app.endpoints.query.configuration", mock_config) mock_get_agent = mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -798,7 +807,7 @@ def test_retrieve_response_with_mcp_servers_empty_token(prepare_agent_mocks, moc ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" # Verify get_agent was called with the correct parameters mock_get_agent.assert_called_once_with( @@ -845,7 +854,8 @@ def test_retrieve_response_with_mcp_servers_and_mcp_headers( mock_config.mcp_servers = mcp_servers mocker.patch("app.endpoints.query.configuration", mock_config) mock_get_agent = mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -871,7 +881,7 @@ def test_retrieve_response_with_mcp_servers_and_mcp_headers( ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" # Verify get_agent was called with the correct parameters mock_get_agent.assert_called_once_with( @@ -936,7 +946,8 @@ def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker): mock_config.mcp_servers = [] mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -948,7 +959,7 @@ def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker): # Assert that the metric for validation errors is incremented mock_metric.inc.assert_called_once() - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -1070,14 +1081,22 @@ def test_query_endpoint_handler_on_connection_error(mocker): mock_metric.inc.assert_called_once() -def test_get_agent_cache_hit(prepare_agent_mocks): +def test_get_agent_cache_hit(prepare_agent_mocks, mocker): """Test get_agent function when agent exists in cache.""" mock_client, mock_agent = prepare_agent_mocks + mock_client.agents.session.list.return_value = mocker.Mock( + data=[{"session_id": "test_session_id"}] + ) # Set up cache with existing agent conversation_id = "test_conversation_id" - result_agent, result_conversation_id = get_agent( + # Mock Agent class + mock_agent_class = mocker.patch( + "app.endpoints.query.Agent", return_value=mock_agent + ) + + result_agent, result_conversation_id, result_session_id = get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1088,7 +1107,8 @@ def test_get_agent_cache_hit(prepare_agent_mocks): # Assert cached agent is returned assert result_agent == mock_agent - assert result_conversation_id == conversation_id + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "test_session_id" def test_get_agent_cache_miss_with_conversation_id( @@ -1096,6 +1116,9 @@ def test_get_agent_cache_miss_with_conversation_id( ): """Test get_agent function when conversation_id is provided but agent not in cache.""" mock_client, mock_agent = prepare_agent_mocks + mock_client.agents.retrieve.side_effect = ValueError( + "fake not finding existing agent" + ) mock_agent.create_session.return_value = "new_session_id" # Mock Agent class @@ -1118,7 +1141,7 @@ def test_get_agent_cache_miss_with_conversation_id( mocker.patch("app.endpoints.query.configuration", setup_configuration) # Call function with conversation_id but no cached agent - result_agent, result_conversation_id = get_agent( + result_agent, result_conversation_id, result_session_id = get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1129,7 +1152,8 @@ def test_get_agent_cache_miss_with_conversation_id( # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with correct parameters mock_agent_class.assert_called_once_with( @@ -1168,7 +1192,7 @@ def test_get_agent_no_conversation_id(setup_configuration, prepare_agent_mocks, mocker.patch("app.endpoints.query.configuration", setup_configuration) # Call function with None conversation_id - result_agent, result_conversation_id = get_agent( + result_agent, result_conversation_id, result_session_id = get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1179,7 +1203,8 @@ def test_get_agent_no_conversation_id(setup_configuration, prepare_agent_mocks, # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with correct parameters mock_agent_class.assert_called_once_with( @@ -1218,7 +1243,7 @@ def test_get_agent_empty_shields(setup_configuration, prepare_agent_mocks, mocke mocker.patch("app.endpoints.query.configuration", setup_configuration) # Call function with empty shields list - result_agent, result_conversation_id = get_agent( + result_agent, result_conversation_id, result_session_id = get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1229,7 +1254,8 @@ def test_get_agent_empty_shields(setup_configuration, prepare_agent_mocks, mocke # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with empty shields mock_agent_class.assert_called_once_with( @@ -1272,7 +1298,7 @@ def test_get_agent_multiple_mcp_servers( mocker.patch("app.endpoints.query.configuration", setup_configuration) # Call function - result_agent, result_conversation_id = get_agent( + result_agent, result_conversation_id, result_session_id = get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1283,7 +1309,8 @@ def test_get_agent_multiple_mcp_servers( # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with tools from both MCP servers mock_agent_class.assert_called_once_with( @@ -1469,7 +1496,8 @@ def test_retrieve_response_no_tools_bypasses_mcp_and_rag(prepare_agent_mocks, mo mock_config.mcp_servers = mcp_servers mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?", no_tools=True) @@ -1481,7 +1509,7 @@ def test_retrieve_response_no_tools_bypasses_mcp_and_rag(prepare_agent_mocks, mo ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" # Verify that agent.extra_headers is empty (no MCP headers) assert mock_agent.extra_headers == {} @@ -1517,7 +1545,8 @@ def test_retrieve_response_no_tools_false_preserves_functionality( mock_config.mcp_servers = mcp_servers mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?", no_tools=False) @@ -1529,7 +1558,7 @@ def test_retrieve_response_no_tools_false_preserves_functionality( ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" # Verify that agent.extra_headers contains MCP headers expected_extra_headers = { @@ -1579,7 +1608,7 @@ def test_get_agent_no_tools_no_parser(setup_configuration, prepare_agent_mocks, mocker.patch("app.endpoints.query.configuration", setup_configuration) # Call function with no_tools=True - result_agent, result_conversation_id = get_agent( + result_agent, result_conversation_id, result_session_id = get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1591,7 +1620,8 @@ def test_get_agent_no_tools_no_parser(setup_configuration, prepare_agent_mocks, # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with tool_parser=None mock_agent_class.assert_called_once_with( @@ -1637,7 +1667,7 @@ def test_get_agent_no_tools_false_preserves_parser( mocker.patch("app.endpoints.query.configuration", setup_configuration) # Call function with no_tools=False - result_agent, result_conversation_id = get_agent( + result_agent, result_conversation_id, result_session_id = get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1649,7 +1679,8 @@ def test_get_agent_no_tools_false_preserves_parser( # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with the proper tool_parser mock_agent_class.assert_called_once_with( diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 368d4990..0a9114a0 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -314,7 +314,7 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "test_conversation_id"), + return_value=(mock_agent, "test_conversation_id", "test_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -331,7 +331,7 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker assert conversation_id == "test_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_conversation_id", + session_id="test_session_id", documents=[], stream=True, # Should be True for streaming endpoint toolgroups=get_rag_toolgroups(["VectorDB-1"]), @@ -351,7 +351,7 @@ async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocke mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "test_conversation_id"), + return_value=(mock_agent, "test_conversation_id", "test_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -368,7 +368,7 @@ async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocke assert conversation_id == "test_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_conversation_id", + session_id="test_session_id", documents=[], stream=True, # Should be True for streaming endpoint toolgroups=None, @@ -401,7 +401,7 @@ def __repr__(self): mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "test_conversation_id"), + return_value=(mock_agent, "test_conversation_id", "test_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -416,7 +416,7 @@ def __repr__(self): assert conversation_id == "test_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_conversation_id", + session_id="test_session_id", documents=[], stream=True, # Should be True for streaming endpoint toolgroups=None, @@ -452,7 +452,7 @@ def __repr__(self): mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "test_conversation_id"), + return_value=(mock_agent, "test_conversation_id", "test_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -467,7 +467,7 @@ def __repr__(self): assert conversation_id == "test_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_conversation_id", + session_id="test_session_id", documents=[], stream=True, # Should be True for streaming endpoint toolgroups=None, @@ -505,7 +505,7 @@ def __repr__(self): mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mock_get_agent = mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "test_conversation_id"), + return_value=(mock_agent, "test_conversation_id", "test_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -532,7 +532,7 @@ def __repr__(self): mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_conversation_id", + session_id="test_session_id", documents=[], stream=True, # Should be True for streaming endpoint toolgroups=None, @@ -560,7 +560,7 @@ async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker ] mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "test_conversation_id"), + return_value=(mock_agent, "test_conversation_id", "test_session_id"), ) query_request = QueryRequest(query="What is OpenStack?", attachments=attachments) @@ -575,7 +575,7 @@ async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker assert conversation_id == "test_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_conversation_id", + session_id="test_session_id", stream=True, # Should be True for streaming endpoint documents=[ { @@ -613,7 +613,7 @@ async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocke ] mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "test_conversation_id"), + return_value=(mock_agent, "test_conversation_id", "test_session_id"), ) query_request = QueryRequest(query="What is OpenStack?", attachments=attachments) @@ -628,7 +628,7 @@ async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocke assert conversation_id == "test_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_conversation_id", + session_id="test_session_id", stream=True, # Should be True for streaming endpoint documents=[ { @@ -1014,7 +1014,7 @@ async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mock_get_agent = mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "test_conversation_id"), + return_value=(mock_agent, "test_conversation_id", "test_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -1057,7 +1057,7 @@ async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): # Check that create_turn was called with the correct parameters mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_conversation_id", + session_id="test_session_id", documents=[], stream=True, toolgroups=[mcp_server.name for mcp_server in mcp_servers], @@ -1082,7 +1082,7 @@ async def test_retrieve_response_with_mcp_servers_empty_token( mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mock_get_agent = mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "test_conversation_id"), + return_value=(mock_agent, "test_conversation_id", "test_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -1116,7 +1116,7 @@ async def test_retrieve_response_with_mcp_servers_empty_token( # Check that create_turn was called with the correct parameters mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_conversation_id", + session_id="test_session_id", documents=[], stream=True, toolgroups=[mcp_server.name for mcp_server in mcp_servers], @@ -1147,7 +1147,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker): mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mock_get_agent = mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "test_conversation_id"), + return_value=(mock_agent, "test_conversation_id", "test_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -1203,7 +1203,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker): # Check that create_turn was called with the correct parameters mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_conversation_id", + session_id="test_session_id", documents=[], stream=True, toolgroups=[mcp_server.name for mcp_server in mcp_servers], @@ -1211,7 +1211,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker): @pytest.mark.asyncio -async def test_get_agent_cache_hit(prepare_agent_mocks): +async def test_get_agent_cache_hit(prepare_agent_mocks, mocker): """Test get_agent function when agent exists in cache.""" mock_client, mock_agent = prepare_agent_mocks @@ -1219,7 +1219,12 @@ async def test_get_agent_cache_hit(prepare_agent_mocks): # Set up cache with existing agent conversation_id = "test_conversation_id" - result_agent, result_conversation_id = await get_agent( + # Mock Agent class + mock_agent_class = mocker.patch( + "app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent + ) + + result_agent, result_conversation_id, result_session_id = await get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1240,6 +1245,9 @@ async def test_get_agent_cache_miss_with_conversation_id( """Test get_agent function when conversation_id is provided but agent not in cache.""" mock_client, mock_agent = prepare_agent_mocks + mock_client.agents.retrieve.side_effect = ValueError( + "fake not finding existing agent" + ) mock_agent.create_session.return_value = "new_session_id" # Mock Agent class @@ -1264,7 +1272,7 @@ async def test_get_agent_cache_miss_with_conversation_id( mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) # Call function with conversation_id but no cached agent - result_agent, result_conversation_id = await get_agent( + result_agent, result_conversation_id, result_session_id = await get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1275,7 +1283,8 @@ async def test_get_agent_cache_miss_with_conversation_id( # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with correct parameters mock_agent_class.assert_called_once_with( @@ -1320,7 +1329,7 @@ async def test_get_agent_no_conversation_id( mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) # Call function with None conversation_id - result_agent, result_conversation_id = await get_agent( + result_agent, result_conversation_id, result_session_id = await get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1331,7 +1340,8 @@ async def test_get_agent_no_conversation_id( # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with correct parameters mock_agent_class.assert_called_once_with( @@ -1376,7 +1386,7 @@ async def test_get_agent_empty_shields( mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) # Call function with empty shields list - result_agent, result_conversation_id = await get_agent( + result_agent, result_conversation_id, result_session_id = await get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1387,7 +1397,8 @@ async def test_get_agent_empty_shields( # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with empty shields mock_agent_class.assert_called_once_with( @@ -1434,7 +1445,7 @@ async def test_get_agent_multiple_mcp_servers( mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) # Call function - result_agent, result_conversation_id = await get_agent( + result_agent, result_conversation_id, result_session_id = await get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1445,7 +1456,8 @@ async def test_get_agent_multiple_mcp_servers( # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with tools from both MCP servers mock_agent_class.assert_called_once_with( @@ -1546,7 +1558,7 @@ async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler(mocker): "app.endpoints.streaming_query.retrieve_user_id", return_value="user123" ) - _ = await streaming_query_endpoint_handler( + result_session_id = await streaming_query_endpoint_handler( None, QueryRequest(query="test query"), auth=("user123", "username", "auth_token_123"), @@ -1659,7 +1671,7 @@ async def test_retrieve_response_no_tools_bypasses_mcp_and_rag( mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "fake_session_id"), + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?", no_tools=True) @@ -1671,7 +1683,7 @@ async def test_retrieve_response_no_tools_bypasses_mcp_and_rag( ) assert response is not None - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" # Verify that agent.extra_headers is empty (no MCP headers) assert mock_agent.extra_headers == {} @@ -1709,7 +1721,7 @@ async def test_retrieve_response_no_tools_false_preserves_functionality( mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "fake_session_id"), + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?", no_tools=False) @@ -1721,7 +1733,7 @@ async def test_retrieve_response_no_tools_false_preserves_functionality( ) assert response is not None - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" # Verify that agent.extra_headers contains MCP headers expected_extra_headers = { @@ -1775,7 +1787,7 @@ async def test_get_agent_no_tools_no_parser( mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) # Call function with no_tools=True - result_agent, result_conversation_id = await get_agent( + result_agent, result_conversation_id, result_session_id = await get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1787,7 +1799,8 @@ async def test_get_agent_no_tools_no_parser( # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with tool_parser=None mock_agent_class.assert_called_once_with( @@ -1838,7 +1851,7 @@ async def test_get_agent_no_tools_false_preserves_parser( mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) # Call function with no_tools=False - result_agent, result_conversation_id = await get_agent( + result_agent, result_conversation_id, result_session_id = await get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1850,7 +1863,8 @@ async def test_get_agent_no_tools_false_preserves_parser( # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with the proper tool_parser mock_agent_class.assert_called_once_with( From baab0975fc06493c2d2023a2c64213d66a3e1e74 Mon Sep 17 00:00:00 2001 From: Maor Friedman Date: Mon, 4 Aug 2025 18:18:49 +0300 Subject: [PATCH 14/19] fix/remove conversations tests --- .../unit/app/endpoints/test_conversations.py | 59 ++----------------- 1 file changed, 6 insertions(+), 53 deletions(-) diff --git a/tests/unit/app/endpoints/test_conversations.py b/tests/unit/app/endpoints/test_conversations.py index 9fb91311..ea27f0f2 100644 --- a/tests/unit/app/endpoints/test_conversations.py +++ b/tests/unit/app/endpoints/test_conversations.py @@ -206,18 +206,6 @@ def test_invalid_conversation_id_format(self, mocker, setup_configuration): assert "Invalid conversation ID format" in exc_info.value.detail["response"] assert INVALID_CONVERSATION_ID in exc_info.value.detail["cause"] - def test_conversation_not_found_in_mapping(self, mocker, setup_configuration): - """Test the endpoint when conversation ID is not in the mapping.""" - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - - with pytest.raises(HTTPException) as exc_info: - get_conversation_endpoint_handler(VALID_CONVERSATION_ID, _auth=MOCK_AUTH) - - assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND - assert "conversation ID not found" in exc_info.value.detail["response"] - assert VALID_CONVERSATION_ID in exc_info.value.detail["cause"] - def test_llama_stack_connection_error(self, mocker, setup_configuration): """Test the endpoint when LlamaStack connection fails.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) @@ -225,9 +213,7 @@ def test_llama_stack_connection_error(self, mocker, setup_configuration): # Mock LlamaStackClientHolder to raise APIConnectionError mock_client = mocker.Mock() - mock_client.agents.session.retrieve.side_effect = APIConnectionError( - request=None - ) + mock_client.agents.session.list.side_effect = APIConnectionError(request=None) mock_client_holder = mocker.patch( "app.endpoints.conversations.LlamaStackClientHolder" ) @@ -240,29 +226,6 @@ def test_llama_stack_connection_error(self, mocker, setup_configuration): assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE assert "Unable to connect to Llama Stack" in exc_info.value.detail["response"] - def test_llama_stack_not_found_error(self, mocker, setup_configuration): - """Test the endpoint when LlamaStack returns NotFoundError.""" - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - - # Mock LlamaStackClientHolder to raise NotFoundError - mock_client = mocker.Mock() - mock_client.agents.session.retrieve.side_effect = NotFoundError( - message="Session not found", response=mocker.Mock(request=None), body=None - ) - mock_client_holder = mocker.patch( - "app.endpoints.conversations.LlamaStackClientHolder" - ) - mock_client_holder.return_value.get_client.return_value = mock_client - - with pytest.raises(HTTPException) as exc_info: - get_conversation_endpoint_handler(VALID_CONVERSATION_ID, _auth=MOCK_AUTH) - - assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND - assert "Conversation not found" in exc_info.value.detail["response"] - assert "could not be retrieved" in exc_info.value.detail["cause"] - assert VALID_CONVERSATION_ID in exc_info.value.detail["cause"] - def test_session_retrieve_exception(self, mocker, setup_configuration): """Test the endpoint when session retrieval raises an exception.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) @@ -300,7 +263,9 @@ def test_successful_conversation_retrieval( # Mock LlamaStackClientHolder mock_client = mocker.Mock() - mock_client.agents.session.retrieve.return_value = mock_session_obj + mock_client.agents.session.list.return_value = mocker.Mock( + data=[mock_session_data] + ) mock_client_holder = mocker.patch( "app.endpoints.conversations.LlamaStackClientHolder" ) @@ -313,8 +278,8 @@ def test_successful_conversation_retrieval( assert isinstance(response, ConversationResponse) assert response.conversation_id == VALID_CONVERSATION_ID assert response.chat_history == expected_chat_history - mock_client.agents.session.retrieve.assert_called_once_with( - agent_id=VALID_AGENT_ID, session_id=VALID_CONVERSATION_ID + mock_client.agents.session.list.assert_called_once_with( + agent_id=VALID_CONVERSATION_ID ) @@ -345,18 +310,6 @@ def test_invalid_conversation_id_format(self, mocker, setup_configuration): assert "Invalid conversation ID format" in exc_info.value.detail["response"] assert INVALID_CONVERSATION_ID in exc_info.value.detail["cause"] - def test_conversation_not_found_in_mapping(self, mocker, setup_configuration): - """Test the endpoint when conversation ID is not in the mapping.""" - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - - with pytest.raises(HTTPException) as exc_info: - delete_conversation_endpoint_handler(VALID_CONVERSATION_ID, _auth=MOCK_AUTH) - - assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND - assert "conversation ID not found" in exc_info.value.detail["response"] - assert VALID_CONVERSATION_ID in exc_info.value.detail["cause"] - def test_llama_stack_connection_error(self, mocker, setup_configuration): """Test the endpoint when LlamaStack connection fails.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) From 8b263676f0fff8e01b048e193c2728ef11d15b8e Mon Sep 17 00:00:00 2001 From: Maor Friedman Date: Mon, 4 Aug 2025 19:01:30 +0300 Subject: [PATCH 15/19] more tests fixes --- tests/unit/app/endpoints/test_conversations.py | 6 ++---- tests/unit/app/endpoints/test_query.py | 4 +--- tests/unit/app/endpoints/test_streaming_query.py | 8 +++----- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/tests/unit/app/endpoints/test_conversations.py b/tests/unit/app/endpoints/test_conversations.py index ea27f0f2..6a100289 100644 --- a/tests/unit/app/endpoints/test_conversations.py +++ b/tests/unit/app/endpoints/test_conversations.py @@ -123,7 +123,7 @@ def test_simplify_session_data_with_model_dump( assert result == expected_chat_history - def test_simplify_session_data_empty_turns(self, mocker): + def test_simplify_session_data_empty_turns(self): """Test simplify_session_data with empty turns.""" session_data = { "session_id": VALID_CONVERSATION_ID, @@ -135,7 +135,7 @@ def test_simplify_session_data_empty_turns(self, mocker): assert not result - def test_simplify_session_data_filters_unwanted_fields(self, mocker): + def test_simplify_session_data_filters_unwanted_fields(self): """Test that simplify_session_data properly filters out unwanted fields.""" session_data = { "session_id": VALID_CONVERSATION_ID, @@ -163,8 +163,6 @@ def test_simplify_session_data_filters_unwanted_fields(self, mocker): ], } - mock_session_obj = mocker.Mock() - result = simplify_session_data(session_data) expected = [ diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 2e5ba07b..b5cd18e5 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -1092,9 +1092,7 @@ def test_get_agent_cache_hit(prepare_agent_mocks, mocker): conversation_id = "test_conversation_id" # Mock Agent class - mock_agent_class = mocker.patch( - "app.endpoints.query.Agent", return_value=mock_agent - ) + mocker.patch("app.endpoints.query.Agent", return_value=mock_agent) result_agent, result_conversation_id, result_session_id = get_agent( client=mock_client, diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 0a9114a0..6975ce51 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -1220,11 +1220,9 @@ async def test_get_agent_cache_hit(prepare_agent_mocks, mocker): conversation_id = "test_conversation_id" # Mock Agent class - mock_agent_class = mocker.patch( - "app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent - ) + mocker.patch("app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent) - result_agent, result_conversation_id, result_session_id = await get_agent( + result_agent, result_conversation_id, _ = await get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1558,7 +1556,7 @@ async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler(mocker): "app.endpoints.streaming_query.retrieve_user_id", return_value="user123" ) - result_session_id = await streaming_query_endpoint_handler( + await streaming_query_endpoint_handler( None, QueryRequest(query="test query"), auth=("user123", "username", "auth_token_123"), From 0d5306ed2ad6714709f810363c1cd453dcb43032 Mon Sep 17 00:00:00 2001 From: Eran Cohen Date: Mon, 4 Aug 2025 21:49:57 +0300 Subject: [PATCH 16/19] Fixed UTs Signed-off-by: Eran Cohen --- .../unit/app/endpoints/test_conversations.py | 25 ++++++++++++++++++- tests/unit/app/endpoints/test_query.py | 18 +++++++------ .../app/endpoints/test_streaming_query.py | 15 +++++------ 3 files changed, 42 insertions(+), 16 deletions(-) diff --git a/tests/unit/app/endpoints/test_conversations.py b/tests/unit/app/endpoints/test_conversations.py index 6a100289..1ff530e8 100644 --- a/tests/unit/app/endpoints/test_conversations.py +++ b/tests/unit/app/endpoints/test_conversations.py @@ -14,7 +14,6 @@ MOCK_AUTH = ("mock_user_id", "mock_username", "mock_token") VALID_CONVERSATION_ID = "123e4567-e89b-12d3-a456-426614174000" -VALID_AGENT_ID = "agent_123" INVALID_CONVERSATION_ID = "invalid-id" @@ -224,6 +223,30 @@ def test_llama_stack_connection_error(self, mocker, setup_configuration): assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE assert "Unable to connect to Llama Stack" in exc_info.value.detail["response"] + def test_llama_stack_not_found_error(self, mocker, setup_configuration): + """Test the endpoint when LlamaStack returns NotFoundError.""" + mocker.patch("app.endpoints.conversations.configuration", setup_configuration) + mocker.patch("app.endpoints.conversations.check_suid", return_value=True) + + + # Mock LlamaStackClientHolder to raise NotFoundError + mock_client = mocker.Mock() + mock_client.agents.session.list.side_effect = NotFoundError( + message="Session not found", response=mocker.Mock(request=None), body=None + ) + mock_client_holder = mocker.patch( + "app.endpoints.conversations.LlamaStackClientHolder" + ) + mock_client_holder.return_value.get_client.return_value = mock_client + + with pytest.raises(HTTPException) as exc_info: + get_conversation_endpoint_handler(VALID_CONVERSATION_ID, _auth=MOCK_AUTH) + + assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND + assert "Conversation not found" in exc_info.value.detail["response"] + assert "could not be retrieved" in exc_info.value.detail["cause"] + assert VALID_CONVERSATION_ID in exc_info.value.detail["cause"] + def test_session_retrieve_exception(self, mocker, setup_configuration): """Test the endpoint when session retrieval raises an exception.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index b5cd18e5..81cf9865 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -1081,8 +1081,8 @@ def test_query_endpoint_handler_on_connection_error(mocker): mock_metric.inc.assert_called_once() -def test_get_agent_cache_hit(prepare_agent_mocks, mocker): - """Test get_agent function when agent exists in cache.""" +def test_get_agent_with_conversation_id(prepare_agent_mocks, mocker): + """Test get_agent function when agent exists in llama stack.""" mock_client, mock_agent = prepare_agent_mocks mock_client.agents.session.list.return_value = mocker.Mock( data=[{"session_id": "test_session_id"}] @@ -1103,16 +1103,17 @@ def test_get_agent_cache_hit(prepare_agent_mocks, mocker): conversation_id=conversation_id, ) - # Assert cached agent is returned + # Assert the same agent is returned assert result_agent == mock_agent assert result_conversation_id == result_agent.agent_id + assert conversation_id == result_agent.agent_id assert result_session_id == "test_session_id" -def test_get_agent_cache_miss_with_conversation_id( +def test_get_agent_with_conversation_id_and_no_agent_in_llama_stack( setup_configuration, prepare_agent_mocks, mocker ): - """Test get_agent function when conversation_id is provided but agent not in cache.""" + """Test get_agent function when conversation_id is provided.""" mock_client, mock_agent = prepare_agent_mocks mock_client.agents.retrieve.side_effect = ValueError( "fake not finding existing agent" @@ -1137,20 +1138,21 @@ def test_get_agent_cache_miss_with_conversation_id( return_value=[mock_mcp_server], ) mocker.patch("app.endpoints.query.configuration", setup_configuration) - - # Call function with conversation_id but no cached agent + conversation_id="non_existent_conversation_id" + # Call function with conversation_id result_agent, result_conversation_id, result_session_id = get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", available_input_shields=["shield1"], available_output_shields=["output_shield2"], - conversation_id="non_existent_conversation_id", + conversation_id=conversation_id, ) # Assert new agent is created assert result_agent == mock_agent assert result_conversation_id == result_agent.agent_id + assert conversation_id != result_agent.agent_id assert result_session_id == "new_session_id" # Verify Agent was created with correct parameters diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 6975ce51..b838c42e 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -1211,12 +1211,11 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker): @pytest.mark.asyncio -async def test_get_agent_cache_hit(prepare_agent_mocks, mocker): - """Test get_agent function when agent exists in cache.""" +async def test_get_agent_with_conversation_id(prepare_agent_mocks, mocker): + """Test get_agent function when agent exists in llama stack.""" mock_client, mock_agent = prepare_agent_mocks - # Set up cache with existing agent conversation_id = "test_conversation_id" # Mock Agent class @@ -1231,16 +1230,17 @@ async def test_get_agent_cache_hit(prepare_agent_mocks, mocker): conversation_id=conversation_id, ) - # Assert cached agent is returned + # Assert the same agent is returned assert result_agent == mock_agent assert result_conversation_id == conversation_id + assert conversation_id == mock_agent._agent_id @pytest.mark.asyncio -async def test_get_agent_cache_miss_with_conversation_id( +async def test_get_agent_with_conversation_id_and_no_agent_in_llama_stack( setup_configuration, prepare_agent_mocks, mocker ): - """Test get_agent function when conversation_id is provided but agent not in cache.""" + """Test get_agent function when conversation_id is provided but agent not in llama stack.""" mock_client, mock_agent = prepare_agent_mocks mock_client.agents.retrieve.side_effect = ValueError( @@ -1269,7 +1269,7 @@ async def test_get_agent_cache_miss_with_conversation_id( ) mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) - # Call function with conversation_id but no cached agent + # Call function with conversation_id result_agent, result_conversation_id, result_session_id = await get_agent( client=mock_client, model_id="test_model", @@ -1282,6 +1282,7 @@ async def test_get_agent_cache_miss_with_conversation_id( # Assert new agent is created assert result_agent == mock_agent assert result_conversation_id == result_agent.agent_id + assert "non_existent_conversation_id" != result_agent.agent_id assert result_session_id == "new_session_id" # Verify Agent was created with correct parameters From d8d183646b45e8eaad2990242aca6bd96bd8b96d Mon Sep 17 00:00:00 2001 From: Maor Friedman Date: Tue, 5 Aug 2025 08:30:59 +0300 Subject: [PATCH 17/19] format --- tests/unit/app/endpoints/test_conversations.py | 3 +-- tests/unit/app/endpoints/test_query.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/unit/app/endpoints/test_conversations.py b/tests/unit/app/endpoints/test_conversations.py index 1ff530e8..d1df17a0 100644 --- a/tests/unit/app/endpoints/test_conversations.py +++ b/tests/unit/app/endpoints/test_conversations.py @@ -228,7 +228,6 @@ def test_llama_stack_not_found_error(self, mocker, setup_configuration): mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Mock LlamaStackClientHolder to raise NotFoundError mock_client = mocker.Mock() mock_client.agents.session.list.side_effect = NotFoundError( @@ -246,7 +245,7 @@ def test_llama_stack_not_found_error(self, mocker, setup_configuration): assert "Conversation not found" in exc_info.value.detail["response"] assert "could not be retrieved" in exc_info.value.detail["cause"] assert VALID_CONVERSATION_ID in exc_info.value.detail["cause"] - + def test_session_retrieve_exception(self, mocker, setup_configuration): """Test the endpoint when session retrieval raises an exception.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 81cf9865..8443b12b 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -1138,7 +1138,7 @@ def test_get_agent_with_conversation_id_and_no_agent_in_llama_stack( return_value=[mock_mcp_server], ) mocker.patch("app.endpoints.query.configuration", setup_configuration) - conversation_id="non_existent_conversation_id" + conversation_id = "non_existent_conversation_id" # Call function with conversation_id result_agent, result_conversation_id, result_session_id = get_agent( client=mock_client, From 66daff6813e0aaec001b1c44b99b607d155a4b38 Mon Sep 17 00:00:00 2001 From: Maor Friedman Date: Tue, 5 Aug 2025 08:33:48 +0300 Subject: [PATCH 18/19] rename session_data --- src/app/endpoints/conversations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/app/endpoints/conversations.py b/src/app/endpoints/conversations.py index 38cfd25c..6032d01d 100644 --- a/src/app/endpoints/conversations.py +++ b/src/app/endpoints/conversations.py @@ -67,11 +67,11 @@ } -def simplify_session_data(session_dict: dict) -> list[dict[str, Any]]: +def simplify_session_data(session_data: dict) -> list[dict[str, Any]]: """Simplify session data to include only essential conversation information. Args: - session_dict: The full session data dict from llama-stack + session_data: The full session data dict from llama-stack Returns: Simplified session data with only input_messages and output_message per turn @@ -80,7 +80,7 @@ def simplify_session_data(session_dict: dict) -> list[dict[str, Any]]: chat_history = [] # Extract only essential data from each turn - for turn in session_dict.get("turns", []): + for turn in session_data.get("turns", []): # Clean up input messages cleaned_messages = [] for msg in turn.get("input_messages", []): From 179b0278c7786b6bd68dd7e105f58852d4684531 Mon Sep 17 00:00:00 2001 From: Maor Friedman Date: Tue, 5 Aug 2025 10:03:26 +0300 Subject: [PATCH 19/19] lint (one last time?) --- tests/unit/app/endpoints/test_streaming_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index b838c42e..0251e514 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -1233,7 +1233,7 @@ async def test_get_agent_with_conversation_id(prepare_agent_mocks, mocker): # Assert the same agent is returned assert result_agent == mock_agent assert result_conversation_id == conversation_id - assert conversation_id == mock_agent._agent_id + assert conversation_id == mock_agent._agent_id # pylint: disable=protected-access @pytest.mark.asyncio