Skip to content

Commit 5152a7d

Browse files
committed
Refactor: drop the conversation-to-agent ID mapping and use the conversation ID as the agent ID
1 parent 4b94249 commit 5152a7d

File tree

6 files changed

+18
-112
lines changed

6 files changed

+18
-112
lines changed

src/app/endpoints/conversations.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
router = APIRouter(tags=["conversations"])
1919
auth_dependency = get_auth_dependency()
2020

21-
conversation_id_to_agent_id: dict[str, str] = {}
22-
2321
conversation_responses: dict[int | str, dict[str, Any]] = {
2422
200: {
2523
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
@@ -126,17 +124,7 @@ def get_conversation_endpoint_handler(
126124
},
127125
)
128126

129-
agent_id = conversation_id_to_agent_id.get(conversation_id)
130-
if not agent_id:
131-
logger.error("Agent ID not found for conversation %s", conversation_id)
132-
raise HTTPException(
133-
status_code=status.HTTP_404_NOT_FOUND,
134-
detail={
135-
"response": "conversation ID not found",
136-
"cause": f"conversation ID {conversation_id} not found!",
137-
},
138-
)
139-
127+
agent_id = conversation_id
140128
logger.info("Retrieving conversation %s", conversation_id)
141129

142130
try:
@@ -206,16 +194,7 @@ def delete_conversation_endpoint_handler(
206194
"cause": f"Conversation ID {conversation_id} is not a valid UUID",
207195
},
208196
)
209-
agent_id = conversation_id_to_agent_id.get(conversation_id)
210-
if not agent_id:
211-
logger.error("Agent ID not found for conversation %s", conversation_id)
212-
raise HTTPException(
213-
status_code=status.HTTP_404_NOT_FOUND,
214-
detail={
215-
"response": "conversation ID not found",
216-
"cause": f"conversation ID {conversation_id} not found!",
217-
},
218-
)
197+
agent_id = conversation_id
219198
logger.info("Deleting conversation %s", conversation_id)
220199

221200
try:

src/app/endpoints/query.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Handler for REST API call to provide answer to query."""
22

33
from datetime import datetime, UTC
4+
from functools import lru_cache
45
import json
56
import logging
67
import os
@@ -23,7 +24,6 @@
2324

2425
from client import LlamaStackClientHolder
2526
from configuration import configuration
26-
from app.endpoints.conversations import conversation_id_to_agent_id
2727
import metrics
2828
from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
2929
from models.requests import QueryRequest, Attachment
@@ -39,9 +39,6 @@
3939
router = APIRouter(tags=["query"])
4040
auth_dependency = get_auth_dependency()
4141

42-
# Global agent registry to persist agents across requests
43-
_agent_cache: TTLCache[str, Agent] = TTLCache(maxsize=1000, ttl=3600)
44-
4542
query_response: dict[int | str, dict[str, Any]] = {
4643
200: {
4744
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
@@ -73,6 +70,7 @@ def is_transcripts_enabled() -> bool:
7370
return configuration.user_data_collection_configuration.transcripts_enabled
7471

7572

73+
@lru_cache
7674
def get_agent( # pylint: disable=too-many-arguments,too-many-positional-arguments
7775
client: LlamaStackClient,
7876
model_id: str,
@@ -82,15 +80,6 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen
8280
conversation_id: str | None,
8381
) -> tuple[Agent, str]:
8482
"""Get existing agent or create a new one with session persistence."""
85-
if conversation_id is not None:
86-
agent = _agent_cache.get(conversation_id)
87-
if agent:
88-
logger.debug(
89-
"Reusing existing agent with conversation_id: %s", conversation_id
90-
)
91-
return agent, conversation_id
92-
logger.debug("No existing agent found for conversation_id: %s", conversation_id)
93-
9483
logger.debug("Creating new agent")
9584
# TODO(lucasagomes): move to ReActAgent
9685
agent = Agent(
@@ -102,10 +91,11 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen
10291
tool_parser=GraniteToolParser.get_parser(model_id),
10392
enable_session_persistence=True,
10493
)
105-
conversation_id = agent.create_session(get_suid())
106-
logger.debug("Created new agent and conversation_id: %s", conversation_id)
107-
_agent_cache[conversation_id] = agent
108-
conversation_id_to_agent_id[conversation_id] = agent.agent_id
94+
if conversation_id:
95+
agent.agent_id = conversation_id
96+
else:
97+
agent.create_session(get_suid())
98+
conversation_id = agent.agent_id
10999

110100
return agent, conversation_id
111101

src/app/endpoints/streaming_query.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Handler for REST API call to provide answer to streaming query."""
22

3+
from functools import lru_cache
34
import json
45
import logging
56
import re
@@ -31,7 +32,6 @@
3132
from utils.suid import get_suid
3233
from utils.types import GraniteToolParser
3334

34-
from app.endpoints.conversations import conversation_id_to_agent_id
3535
from app.endpoints.query import (
3636
get_rag_toolgroups,
3737
is_input_shield,
@@ -46,11 +46,9 @@
4646
router = APIRouter(tags=["streaming_query"])
4747
auth_dependency = get_auth_dependency()
4848

49-
# Global agent registry to persist agents across requests
50-
_agent_cache: TTLCache[str, AsyncAgent] = TTLCache(maxsize=1000, ttl=3600)
51-
5249

5350
# # pylint: disable=R0913,R0917
51+
@lru_cache
5452
async def get_agent(
5553
client: AsyncLlamaStackClient,
5654
model_id: str,
@@ -60,15 +58,6 @@ async def get_agent(
6058
conversation_id: str | None,
6159
) -> tuple[AsyncAgent, str]:
6260
"""Get existing agent or create a new one with session persistence."""
63-
if conversation_id is not None:
64-
agent = _agent_cache.get(conversation_id)
65-
if agent:
66-
logger.debug(
67-
"Reusing existing agent with conversation_id: %s", conversation_id
68-
)
69-
return agent, conversation_id
70-
logger.debug("No existing agent found for conversation_id: %s", conversation_id)
71-
7261
logger.debug("Creating new agent")
7362
agent = AsyncAgent(
7463
client, # type: ignore[arg-type]
@@ -79,10 +68,13 @@ async def get_agent(
7968
tool_parser=GraniteToolParser.get_parser(model_id),
8069
enable_session_persistence=True,
8170
)
82-
conversation_id = await agent.create_session(get_suid())
83-
logger.debug("Created new agent and conversation_id: %s", conversation_id)
84-
_agent_cache[conversation_id] = agent
85-
conversation_id_to_agent_id[conversation_id] = agent.agent_id
71+
72+
if conversation_id:
73+
agent._agent_id = conversation_id
74+
else:
75+
conversation_id = agent.agent_id
76+
await agent.create_session(get_suid())
77+
8678
return agent, conversation_id
8779

8880

tests/unit/app/endpoints/test_conversations.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from app.endpoints.conversations import (
88
get_conversation_endpoint_handler,
99
delete_conversation_endpoint_handler,
10-
conversation_id_to_agent_id,
1110
simplify_session_data,
1211
)
1312
from models.responses import ConversationResponse, ConversationDeleteResponse
@@ -48,16 +47,6 @@ def setup_configuration_fixture():
4847
return cfg
4948

5049

51-
@pytest.fixture(autouse=True)
52-
def setup_conversation_mapping():
53-
"""Set up and clean up the conversation ID to agent ID mapping."""
54-
# Clear the mapping before each test
55-
conversation_id_to_agent_id.clear()
56-
yield
57-
# Clean up after each test
58-
conversation_id_to_agent_id.clear()
59-
60-
6150
@pytest.fixture(name="mock_session_data")
6251
def mock_session_data_fixture():
6352
"""Create mock session data for testing."""
@@ -243,9 +232,6 @@ def test_llama_stack_connection_error(self, mocker, setup_configuration):
243232
mocker.patch("app.endpoints.conversations.configuration", setup_configuration)
244233
mocker.patch("app.endpoints.conversations.check_suid", return_value=True)
245234

246-
# Set up conversation mapping
247-
conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID
248-
249235
# Mock LlamaStackClientHolder to raise APIConnectionError
250236
mock_client = mocker.Mock()
251237
mock_client.agents.session.retrieve.side_effect = APIConnectionError(
@@ -268,9 +254,6 @@ def test_llama_stack_not_found_error(self, mocker, setup_configuration):
268254
mocker.patch("app.endpoints.conversations.configuration", setup_configuration)
269255
mocker.patch("app.endpoints.conversations.check_suid", return_value=True)
270256

271-
# Set up conversation mapping
272-
conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID
273-
274257
# Mock LlamaStackClientHolder to raise NotFoundError
275258
mock_client = mocker.Mock()
276259
mock_client.agents.session.retrieve.side_effect = NotFoundError(
@@ -294,9 +277,6 @@ def test_session_retrieve_exception(self, mocker, setup_configuration):
294277
mocker.patch("app.endpoints.conversations.configuration", setup_configuration)
295278
mocker.patch("app.endpoints.conversations.check_suid", return_value=True)
296279

297-
# Set up conversation mapping
298-
conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID
299-
300280
# Mock LlamaStackClientHolder to raise a general exception
301281
mock_client = mocker.Mock()
302282
mock_client.agents.session.retrieve.side_effect = Exception(
@@ -323,9 +303,6 @@ def test_successful_conversation_retrieval(
323303
mocker.patch("app.endpoints.conversations.configuration", setup_configuration)
324304
mocker.patch("app.endpoints.conversations.check_suid", return_value=True)
325305

326-
# Set up conversation mapping
327-
conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID
328-
329306
# Mock session data with model_dump method
330307
mock_session_obj = mocker.Mock()
331308
mock_session_obj.model_dump.return_value = mock_session_data
@@ -394,9 +371,6 @@ def test_llama_stack_connection_error(self, mocker, setup_configuration):
394371
mocker.patch("app.endpoints.conversations.configuration", setup_configuration)
395372
mocker.patch("app.endpoints.conversations.check_suid", return_value=True)
396373

397-
# Set up conversation mapping
398-
conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID
399-
400374
# Mock LlamaStackClientHolder to raise APIConnectionError
401375
mock_client = mocker.Mock()
402376
mock_client.agents.session.delete.side_effect = APIConnectionError(request=None)
@@ -416,9 +390,6 @@ def test_llama_stack_not_found_error(self, mocker, setup_configuration):
416390
mocker.patch("app.endpoints.conversations.configuration", setup_configuration)
417391
mocker.patch("app.endpoints.conversations.check_suid", return_value=True)
418392

419-
# Set up conversation mapping
420-
conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID
421-
422393
# Mock LlamaStackClientHolder to raise NotFoundError
423394
mock_client = mocker.Mock()
424395
mock_client.agents.session.delete.side_effect = NotFoundError(
@@ -442,9 +413,6 @@ def test_session_deletion_exception(self, mocker, setup_configuration):
442413
mocker.patch("app.endpoints.conversations.configuration", setup_configuration)
443414
mocker.patch("app.endpoints.conversations.check_suid", return_value=True)
444415

445-
# Set up conversation mapping
446-
conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID
447-
448416
# Mock LlamaStackClientHolder to raise a general exception
449417
mock_client = mocker.Mock()
450418
mock_client.agents.session.delete.side_effect = Exception(
@@ -470,9 +438,6 @@ def test_successful_conversation_deletion(self, mocker, setup_configuration):
470438
mocker.patch("app.endpoints.conversations.configuration", setup_configuration)
471439
mocker.patch("app.endpoints.conversations.check_suid", return_value=True)
472440

473-
# Set up conversation mapping
474-
conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID
475-
476441
# Mock LlamaStackClientHolder
477442
mock_client = mocker.Mock()
478443
mock_client.agents.session.delete.return_value = None # Successful deletion

tests/unit/app/endpoints/test_query.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
store_transcript,
2121
get_rag_toolgroups,
2222
get_agent,
23-
_agent_cache,
2423
)
2524

2625
from models.requests import QueryRequest, Attachment
@@ -65,8 +64,6 @@ def prepare_agent_mocks_fixture(mocker):
6564
mock_agent = mocker.Mock()
6665
mock_agent.create_turn.return_value.steps = []
6766
yield mock_client, mock_agent
68-
# cleanup agent cache after tests
69-
_agent_cache.clear()
7067

7168

7269
def test_query_endpoint_handler_configuration_not_loaded(mocker):
@@ -1065,7 +1062,6 @@ def test_get_agent_cache_hit(prepare_agent_mocks):
10651062

10661063
# Set up cache with existing agent
10671064
conversation_id = "test_conversation_id"
1068-
_agent_cache[conversation_id] = mock_agent
10691065

10701066
result_agent, result_conversation_id = get_agent(
10711067
client=mock_client,
@@ -1132,9 +1128,6 @@ def test_get_agent_cache_miss_with_conversation_id(
11321128
enable_session_persistence=True,
11331129
)
11341130

1135-
# Verify agent was stored in cache
1136-
assert _agent_cache["new_session_id"] == mock_agent
1137-
11381131

11391132
def test_get_agent_no_conversation_id(setup_configuration, prepare_agent_mocks, mocker):
11401133
"""Test get_agent function when conversation_id is None."""
@@ -1185,9 +1178,6 @@ def test_get_agent_no_conversation_id(setup_configuration, prepare_agent_mocks,
11851178
enable_session_persistence=True,
11861179
)
11871180

1188-
# Verify agent was stored in cache
1189-
assert _agent_cache["new_session_id"] == mock_agent
1190-
11911181

11921182
def test_get_agent_empty_shields(setup_configuration, prepare_agent_mocks, mocker):
11931183
"""Test get_agent function with empty shields list."""

tests/unit/app/endpoints/test_streaming_query.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
retrieve_response,
4242
stream_build_event,
4343
get_agent,
44-
_agent_cache,
4544
)
4645
from models.requests import QueryRequest, Attachment
4746
from models.config import ModelContextProtocolServer
@@ -109,8 +108,6 @@ def prepare_agent_mocks_fixture(mocker):
109108
mock_client = mocker.AsyncMock()
110109
mock_agent = mocker.AsyncMock()
111110
yield mock_client, mock_agent
112-
# cleanup agent cache after tests
113-
_agent_cache.clear()
114111

115112

116113
@pytest.mark.asyncio
@@ -1213,7 +1210,6 @@ async def test_get_agent_cache_hit(prepare_agent_mocks):
12131210

12141211
# Set up cache with existing agent
12151212
conversation_id = "test_conversation_id"
1216-
_agent_cache[conversation_id] = mock_agent
12171213

12181214
result_agent, result_conversation_id = await get_agent(
12191215
client=mock_client,
@@ -1284,9 +1280,6 @@ async def test_get_agent_cache_miss_with_conversation_id(
12841280
enable_session_persistence=True,
12851281
)
12861282

1287-
# Verify agent was stored in cache
1288-
assert _agent_cache["new_session_id"] == mock_agent
1289-
12901283

12911284
@pytest.mark.asyncio
12921285
async def test_get_agent_no_conversation_id(
@@ -1343,9 +1336,6 @@ async def test_get_agent_no_conversation_id(
13431336
enable_session_persistence=True,
13441337
)
13451338

1346-
# Verify agent was stored in cache
1347-
assert _agent_cache["new_session_id"] == mock_agent
1348-
13491339

13501340
@pytest.mark.asyncio
13511341
async def test_get_agent_empty_shields(

0 commit comments

Comments
 (0)