88from pathlib import Path
99from typing import Any
1010
11- from llama_stack_client .lib .agents .agent import Agent
11+ from llama_stack_client .lib .agents .agent import AsyncAgent
1212from llama_stack_client import APIConnectionError
13- from llama_stack_client import LlamaStackClient # type: ignore
13+ from llama_stack_client import AsyncLlamaStackClient # type: ignore
1414from llama_stack_client .types import UserMessage , Shield # type: ignore
1515from llama_stack_client .types .agents .turn_create_params import (
1616 ToolgroupAgentToolGroupWithArgs ,
2020
2121from fastapi import APIRouter , HTTPException , status , Depends
2222
23- from client import LlamaStackClientHolder
23+ from client import AsyncLlamaStackClientHolder
2424from configuration import configuration
2525import metrics
2626from models .responses import QueryResponse , UnauthorizedResponse , ForbiddenResponse
@@ -68,26 +68,26 @@ def is_transcripts_enabled() -> bool:
6868 return configuration .user_data_collection_configuration .transcripts_enabled
6969
7070
71- def get_agent ( # pylint: disable=too-many-arguments,too-many-positional-arguments
72- client : LlamaStackClient ,
71+ async def get_agent ( # pylint: disable=too-many-arguments,too-many-positional-arguments
72+ client : AsyncLlamaStackClient ,
7373 model_id : str ,
7474 system_prompt : str ,
7575 available_input_shields : list [str ],
7676 available_output_shields : list [str ],
7777 conversation_id : str | None ,
7878 no_tools : bool = False ,
79- ) -> tuple [Agent , str , str ]:
79+ ) -> tuple [AsyncAgent , str , str ]:
8080 """Get existing agent or create a new one with session persistence."""
8181 existing_agent_id = None
8282 if conversation_id :
8383 with suppress (ValueError ):
84- existing_agent_id = client . agents . retrieve (
85- agent_id = conversation_id
84+ existing_agent_id = (
85+ await client . agents . retrieve ( agent_id = conversation_id )
8686 ).agent_id
8787
8888 logger .debug ("Creating new agent" )
8989 # TODO(lucasagomes): move to ReActAgent
90- agent = Agent (
90+ agent = AsyncAgent (
9191 client ,
9292 model = model_id ,
9393 instructions = system_prompt ,
@@ -99,19 +99,19 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen
9999 if existing_agent_id and conversation_id :
100100 orphan_agent_id = agent .agent_id
101101 agent .agent_id = conversation_id
102- client .agents .delete (agent_id = orphan_agent_id )
103- sessions_response = client .agents .session .list (agent_id = conversation_id )
102+ await client .agents .delete (agent_id = orphan_agent_id )
103+ sessions_response = await client .agents .session .list (agent_id = conversation_id )
104104 logger .info ("session response: %s" , sessions_response )
105105 session_id = str (sessions_response .data [0 ]["session_id" ])
106106 else :
107107 conversation_id = agent .agent_id
108- session_id = agent .create_session (get_suid ())
108+ session_id = await agent .create_session (get_suid ())
109109
110110 return agent , conversation_id , session_id
111111
112112
113113@router .post ("/query" , responses = query_response )
114- def query_endpoint_handler (
114+ async def query_endpoint_handler (
115115 query_request : QueryRequest ,
116116 auth : Any = Depends (auth_dependency ),
117117 mcp_headers : dict [str , dict [str , str ]] = Depends (mcp_headers_dependency ),
@@ -126,11 +126,11 @@ def query_endpoint_handler(
126126
127127 try :
128128 # try to get Llama Stack client
129- client = LlamaStackClientHolder ().get_client ()
129+ client = AsyncLlamaStackClientHolder ().get_client ()
130130 model_id , provider_id = select_model_and_provider_id (
131- client .models .list (), query_request
131+ await client .models .list (), query_request
132132 )
133- response , conversation_id = retrieve_response (
133+ response , conversation_id = await retrieve_response (
134134 client ,
135135 model_id ,
136136 query_request ,
@@ -250,19 +250,21 @@ def is_input_shield(shield: Shield) -> bool:
250250 return _is_inout_shield (shield ) or not is_output_shield (shield )
251251
252252
253- def retrieve_response ( # pylint: disable=too-many-locals
254- client : LlamaStackClient ,
253+ async def retrieve_response ( # pylint: disable=too-many-locals
254+ client : AsyncLlamaStackClient ,
255255 model_id : str ,
256256 query_request : QueryRequest ,
257257 token : str ,
258258 mcp_headers : dict [str , dict [str , str ]] | None = None ,
259259) -> tuple [str , str ]:
260260 """Retrieve response from LLMs and agents."""
261261 available_input_shields = [
262- shield .identifier for shield in filter (is_input_shield , client .shields .list ())
262+ shield .identifier
263+ for shield in filter (is_input_shield , await client .shields .list ())
263264 ]
264265 available_output_shields = [
265- shield .identifier for shield in filter (is_output_shield , client .shields .list ())
266+ shield .identifier
267+ for shield in filter (is_output_shield , await client .shields .list ())
266268 ]
267269 if not available_input_shields and not available_output_shields :
268270 logger .info ("No available shields. Disabling safety" )
@@ -281,7 +283,7 @@ def retrieve_response( # pylint: disable=too-many-locals
281283 if query_request .attachments :
282284 validate_attachments_metadata (query_request .attachments )
283285
284- agent , conversation_id , session_id = get_agent (
286+ agent , conversation_id , session_id = await get_agent (
285287 client ,
286288 model_id ,
287289 system_prompt ,
@@ -315,15 +317,17 @@ def retrieve_response( # pylint: disable=too-many-locals
315317 ),
316318 }
317319
318- vector_db_ids = [vector_db .identifier for vector_db in client .vector_dbs .list ()]
320+ vector_db_ids = [
321+ vector_db .identifier for vector_db in await client .vector_dbs .list ()
322+ ]
319323 toolgroups = (get_rag_toolgroups (vector_db_ids ) or []) + [
320324 mcp_server .name for mcp_server in configuration .mcp_servers
321325 ]
322326 # Convert empty list to None for consistency with existing behavior
323327 if not toolgroups :
324328 toolgroups = None
325329
326- response = agent .create_turn (
330+ response = await agent .create_turn (
327331 messages = [UserMessage (role = "user" , content = query_request .query )],
328332 session_id = session_id ,
329333 documents = query_request .get_documents (),
0 commit comments