-
Notifications
You must be signed in to change notification settings - Fork 52
refactor conversation to agent and conversation mapping #317
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e1d5c40
20ead26
26e2de5
a11fcf4
89e8678
3c368c0
f360f03
f472814
3318bcd
abb1384
92cf14e
50b652a
20b1df5
baab097
8b26367
0d5306e
d8d1836
66daff6
179b027
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,14 +1,13 @@ | ||
| """Handler for REST API call to provide answer to query.""" | ||
|
|
||
| from contextlib import suppress | ||
| from datetime import datetime, UTC | ||
| import json | ||
| import logging | ||
| import os | ||
| from pathlib import Path | ||
| from typing import Any | ||
|
|
||
| from cachetools import TTLCache # type: ignore | ||
|
|
||
| from llama_stack_client.lib.agents.agent import Agent | ||
| from llama_stack_client import APIConnectionError | ||
| from llama_stack_client import LlamaStackClient # type: ignore | ||
|
|
@@ -23,7 +22,6 @@ | |
|
|
||
| from client import LlamaStackClientHolder | ||
| from configuration import configuration | ||
| from app.endpoints.conversations import conversation_id_to_agent_id | ||
| import metrics | ||
| from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse | ||
| from models.requests import QueryRequest, Attachment | ||
|
|
@@ -39,9 +37,6 @@ | |
| router = APIRouter(tags=["query"]) | ||
| auth_dependency = get_auth_dependency() | ||
|
|
||
| # Global agent registry to persist agents across requests | ||
| _agent_cache: TTLCache[str, Agent] = TTLCache(maxsize=1000, ttl=3600) | ||
|
|
||
| query_response: dict[int | str, dict[str, Any]] = { | ||
| 200: { | ||
| "conversation_id": "123e4567-e89b-12d3-a456-426614174000", | ||
|
|
@@ -81,16 +76,14 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen | |
| available_output_shields: list[str], | ||
| conversation_id: str | None, | ||
| no_tools: bool = False, | ||
| ) -> tuple[Agent, str]: | ||
| ) -> tuple[Agent, str, str]: | ||
| """Get existing agent or create a new one with session persistence.""" | ||
| if conversation_id is not None: | ||
| agent = _agent_cache.get(conversation_id) | ||
| if agent: | ||
| logger.debug( | ||
| "Reusing existing agent with conversation_id: %s", conversation_id | ||
| ) | ||
| return agent, conversation_id | ||
| logger.debug("No existing agent found for conversation_id: %s", conversation_id) | ||
| existing_agent_id = None | ||
| if conversation_id: | ||
| with suppress(ValueError): | ||
| existing_agent_id = client.agents.retrieve( | ||
| agent_id=conversation_id | ||
| ).agent_id | ||
|
|
||
| logger.debug("Creating new agent") | ||
| # TODO(lucasagomes): move to ReActAgent | ||
|
|
@@ -103,12 +96,18 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen | |
| tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id), | ||
| enable_session_persistence=True, | ||
| ) | ||
| conversation_id = agent.create_session(get_suid()) | ||
| logger.debug("Created new agent and conversation_id: %s", conversation_id) | ||
| _agent_cache[conversation_id] = agent | ||
| conversation_id_to_agent_id[conversation_id] = agent.agent_id | ||
| if existing_agent_id and conversation_id: | ||
| orphan_agent_id = agent.agent_id | ||
| agent.agent_id = conversation_id | ||
| client.agents.delete(agent_id=orphan_agent_id) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't this delete happens before we change the ID of the newly created agent to
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @umago IDK if this could help explain: #317 (comment)
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! I was on PTO yesterday, reading thru it now. It does makes sense. Thank you |
||
| sessions_response = client.agents.session.list(agent_id=conversation_id) | ||
| logger.info("session response: %s", sessions_response) | ||
| session_id = str(sessions_response.data[0]["session_id"]) | ||
| else: | ||
| conversation_id = agent.agent_id | ||
maorfr marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| session_id = agent.create_session(get_suid()) | ||
maorfr marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| return agent, conversation_id | ||
| return agent, conversation_id, session_id | ||
|
|
||
|
|
||
| @router.post("/query", responses=query_response) | ||
|
|
@@ -282,7 +281,7 @@ def retrieve_response( # pylint: disable=too-many-locals | |
| if query_request.attachments: | ||
| validate_attachments_metadata(query_request.attachments) | ||
|
|
||
| agent, conversation_id = get_agent( | ||
| agent, conversation_id, session_id = get_agent( | ||
| client, | ||
| model_id, | ||
| system_prompt, | ||
|
|
@@ -326,7 +325,7 @@ def retrieve_response( # pylint: disable=too-many-locals | |
|
|
||
| response = agent.create_turn( | ||
| messages=[UserMessage(role="user", content=query_request.query)], | ||
| session_id=conversation_id, | ||
| session_id=session_id, | ||
| documents=query_request.get_documents(), | ||
| stream=False, | ||
| toolgroups=toolgroups, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.