Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 6 additions & 30 deletions src/app/endpoints/conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
router = APIRouter(tags=["conversations"])
auth_dependency = get_auth_dependency()

conversation_id_to_agent_id: dict[str, str] = {}

conversation_responses: dict[int | str, dict[str, Any]] = {
200: {
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
Expand Down Expand Up @@ -69,21 +67,20 @@
}


def simplify_session_data(session_data: Any) -> list[dict[str, Any]]:
def simplify_session_data(session_data: dict) -> list[dict[str, Any]]:
"""Simplify session data to include only essential conversation information.

Args:
session_data: The full session data from llama-stack
session_data: The full session data dict from llama-stack

Returns:
Simplified session data with only input_messages and output_message per turn
"""
session_dict = session_data.model_dump()
# Create simplified structure
chat_history = []

# Extract only essential data from each turn
for turn in session_dict.get("turns", []):
for turn in session_data.get("turns", []):
# Clean up input messages
cleaned_messages = []
for msg in turn.get("input_messages", []):
Expand Down Expand Up @@ -131,25 +128,13 @@ def get_conversation_endpoint_handler(
},
)

agent_id = conversation_id_to_agent_id.get(conversation_id)
if not agent_id:
logger.error("Agent ID not found for conversation %s", conversation_id)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"response": "conversation ID not found",
"cause": f"conversation ID {conversation_id} not found!",
},
)

agent_id = conversation_id
logger.info("Retrieving conversation %s", conversation_id)

try:
client = LlamaStackClientHolder().get_client()

session_data = client.agents.session.retrieve(
agent_id=agent_id, session_id=conversation_id
)
session_data = client.agents.session.list(agent_id=agent_id).data[0]

logger.info("Successfully retrieved conversation %s", conversation_id)

Expand Down Expand Up @@ -211,16 +196,7 @@ def delete_conversation_endpoint_handler(
"cause": f"Conversation ID {conversation_id} is not a valid UUID",
},
)
agent_id = conversation_id_to_agent_id.get(conversation_id)
if not agent_id:
logger.error("Agent ID not found for conversation %s", conversation_id)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"response": "conversation ID not found",
"cause": f"conversation ID {conversation_id} not found!",
},
)
agent_id = conversation_id
logger.info("Deleting conversation %s", conversation_id)

try:
Expand Down
43 changes: 21 additions & 22 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
"""Handler for REST API call to provide answer to query."""

from contextlib import suppress
from datetime import datetime, UTC
import json
import logging
import os
from pathlib import Path
from typing import Any

from cachetools import TTLCache # type: ignore

from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client import APIConnectionError
from llama_stack_client import LlamaStackClient # type: ignore
Expand All @@ -23,7 +22,6 @@

from client import LlamaStackClientHolder
from configuration import configuration
from app.endpoints.conversations import conversation_id_to_agent_id
import metrics
from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
from models.requests import QueryRequest, Attachment
Expand All @@ -39,9 +37,6 @@
router = APIRouter(tags=["query"])
auth_dependency = get_auth_dependency()

# Global agent registry to persist agents across requests
_agent_cache: TTLCache[str, Agent] = TTLCache(maxsize=1000, ttl=3600)

query_response: dict[int | str, dict[str, Any]] = {
200: {
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
Expand Down Expand Up @@ -81,16 +76,14 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen
available_output_shields: list[str],
conversation_id: str | None,
no_tools: bool = False,
) -> tuple[Agent, str]:
) -> tuple[Agent, str, str]:
"""Get existing agent or create a new one with session persistence."""
if conversation_id is not None:
agent = _agent_cache.get(conversation_id)
if agent:
logger.debug(
"Reusing existing agent with conversation_id: %s", conversation_id
)
return agent, conversation_id
logger.debug("No existing agent found for conversation_id: %s", conversation_id)
existing_agent_id = None
if conversation_id:
with suppress(ValueError):
existing_agent_id = client.agents.retrieve(
agent_id=conversation_id
).agent_id

logger.debug("Creating new agent")
# TODO(lucasagomes): move to ReActAgent
Expand All @@ -103,12 +96,18 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen
tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id),
enable_session_persistence=True,
)
conversation_id = agent.create_session(get_suid())
logger.debug("Created new agent and conversation_id: %s", conversation_id)
_agent_cache[conversation_id] = agent
conversation_id_to_agent_id[conversation_id] = agent.agent_id
if existing_agent_id and conversation_id:
orphan_agent_id = agent.agent_id
agent.agent_id = conversation_id
client.agents.delete(agent_id=orphan_agent_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this delete happen before we change the ID of the newly created agent to conversation_id? Otherwise, this delete could also affect the new agent, as its ID will now match here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note the argument agent_id=orphan_agent_id: this deletes the newly created agent from the llama-stack server DB.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@umago IDK if this could help explain: #317 (comment)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I was on PTO yesterday and am reading through it now. It does make sense. Thank you.

sessions_response = client.agents.session.list(agent_id=conversation_id)
logger.info("session response: %s", sessions_response)
session_id = str(sessions_response.data[0]["session_id"])
else:
conversation_id = agent.agent_id
session_id = agent.create_session(get_suid())

return agent, conversation_id
return agent, conversation_id, session_id


@router.post("/query", responses=query_response)
Expand Down Expand Up @@ -282,7 +281,7 @@ def retrieve_response( # pylint: disable=too-many-locals
if query_request.attachments:
validate_attachments_metadata(query_request.attachments)

agent, conversation_id = get_agent(
agent, conversation_id, session_id = get_agent(
client,
model_id,
system_prompt,
Expand Down Expand Up @@ -326,7 +325,7 @@ def retrieve_response( # pylint: disable=too-many-locals

response = agent.create_turn(
messages=[UserMessage(role="user", content=query_request.query)],
session_id=conversation_id,
session_id=session_id,
documents=query_request.get_documents(),
stream=False,
toolgroups=toolgroups,
Expand Down
44 changes: 22 additions & 22 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
"""Handler for REST API call to provide answer to streaming query."""

import ast
from contextlib import suppress
import json
import re
import logging
from typing import Any, AsyncIterator, Iterator

from cachetools import TTLCache # type: ignore

from llama_stack_client import APIConnectionError
from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore
from llama_stack_client import AsyncLlamaStackClient # type: ignore
Expand All @@ -31,7 +30,6 @@
from utils.suid import get_suid
from utils.types import GraniteToolParser

from app.endpoints.conversations import conversation_id_to_agent_id
from app.endpoints.query import (
get_rag_toolgroups,
is_input_shield,
Expand All @@ -46,9 +44,6 @@
router = APIRouter(tags=["streaming_query"])
auth_dependency = get_auth_dependency()

# Global agent registry to persist agents across requests
_agent_cache: TTLCache[str, AsyncAgent] = TTLCache(maxsize=1000, ttl=3600)


# # pylint: disable=R0913,R0917
async def get_agent(
Expand All @@ -59,16 +54,13 @@ async def get_agent(
available_output_shields: list[str],
conversation_id: str | None,
no_tools: bool = False,
) -> tuple[AsyncAgent, str]:
) -> tuple[AsyncAgent, str, str]:
"""Get existing agent or create a new one with session persistence."""
if conversation_id is not None:
agent = _agent_cache.get(conversation_id)
if agent:
logger.debug(
"Reusing existing agent with conversation_id: %s", conversation_id
)
return agent, conversation_id
logger.debug("No existing agent found for conversation_id: %s", conversation_id)
existing_agent_id = None
if conversation_id:
with suppress(ValueError):
agent_response = await client.agents.retrieve(agent_id=conversation_id)
existing_agent_id = agent_response.agent_id

logger.debug("Creating new agent")
agent = AsyncAgent(
Expand All @@ -80,11 +72,19 @@ async def get_agent(
tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id),
enable_session_persistence=True,
)
conversation_id = await agent.create_session(get_suid())
logger.debug("Created new agent and conversation_id: %s", conversation_id)
_agent_cache[conversation_id] = agent
conversation_id_to_agent_id[conversation_id] = agent.agent_id
return agent, conversation_id

if existing_agent_id and conversation_id:
orphan_agent_id = agent.agent_id
agent._agent_id = conversation_id # type: ignore[assignment] # pylint: disable=protected-access
await client.agents.delete(agent_id=orphan_agent_id)
sessions_response = await client.agents.session.list(agent_id=conversation_id)
logger.info("session response: %s", sessions_response)
session_id = str(sessions_response.data[0]["session_id"])
else:
conversation_id = agent.agent_id
session_id = await agent.create_session(get_suid())

return agent, conversation_id, session_id


METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n")
Expand Down Expand Up @@ -526,7 +526,7 @@ async def retrieve_response(
if query_request.attachments:
validate_attachments_metadata(query_request.attachments)

agent, conversation_id = await get_agent(
agent, conversation_id, session_id = await get_agent(
client,
model_id,
system_prompt,
Expand Down Expand Up @@ -576,7 +576,7 @@ async def retrieve_response(
logger.debug("Session ID: %s", conversation_id)
response = await agent.create_turn(
messages=[UserMessage(role="user", content=query_request.query)],
session_id=conversation_id,
session_id=session_id,
documents=query_request.get_documents(),
stream=True,
toolgroups=toolgroups,
Expand Down
Loading
Loading