Commit 9a91944

feat: Convert lightspeed-core to async architecture
- Migrate endpoints from sync to async handlers
- Remove legacy sync client infrastructure
- Update unit tests for async compatibility

This resolves blocking behavior in all endpoints except streaming_query, which was already async, enabling proper concurrent request handling.

Signed-off-by: Eran Cohen <[email protected]>
1 parent a3b530d commit 9a91944

File tree: 18 files changed (+313, -381 lines)
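
For context, the concurrency gain described in the commit message comes from how FastAPI schedules handlers: a plain `def` handler runs in a limited threadpool, while an `async def` handler runs on the event loop and yields whenever it awaits I/O. A minimal, self-contained sketch of that difference (not code from this repository; both routes are invented for the demo):

    # Illustration only: the sync-vs-async handler difference the commit relies on.
    import asyncio
    import time

    from fastapi import FastAPI

    app = FastAPI()


    @app.get("/blocking")
    def blocking_handler() -> dict[str, str]:
        # A plain `def` handler is run in FastAPI's threadpool; a slow call here
        # ties up one of a limited number of worker threads.
        time.sleep(1)
        return {"status": "done"}


    @app.get("/non-blocking")
    async def non_blocking_handler() -> dict[str, str]:
        # An `async def` handler runs on the event loop and yields control while
        # awaiting, so many requests can be in flight concurrently.
        await asyncio.sleep(1)
        return {"status": "done"}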

scripts/generate_openapi_schema.py

Lines changed: 4 additions & 2 deletions

@@ -10,13 +10,15 @@
 # it is needed to read proper configuration in order to start the app to generate schema
 from configuration import configuration
 
-from client import LlamaStackClientHolder
+from client import AsyncLlamaStackClientHolder
 
 cfg_file = "lightspeed-stack.yaml"
 configuration.load_configuration(cfg_file)
 
 # Llama Stack client needs to be loaded before REST API is fully initialized
-LlamaStackClientHolder().load(configuration.configuration.llama_stack)
+import asyncio  # noqa: E402
+
+asyncio.run(AsyncLlamaStackClientHolder().load(configuration.configuration.llama_stack))
 
 from app.main import app  # noqa: E402 pylint: disable=C0413
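
The schema generator stays a synchronous, module-level script, so the async holder's load() coroutine has to be driven to completion with asyncio.run() before the app import. A small sketch of that pattern, with prepare_client standing in for the real AsyncLlamaStackClientHolder().load(...) call:

    # Sketch: driving an async initialiser from synchronous module-level code.
    # `prepare_client` is a made-up stand-in, not the project's API.
    import asyncio


    async def prepare_client(url: str) -> str:
        # Pretend to do async setup and return a handle.
        await asyncio.sleep(0)
        return f"client-for-{url}"


    # asyncio.run() creates an event loop, runs the coroutine to completion,
    # and closes the loop; after this line the handle is usable from sync code.
    client_handle = asyncio.run(prepare_client("http://localhost:8321"))
    print(client_handle)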

src/app/endpoints/authorized.py

Lines changed: 5 additions & 4 deletions

@@ -1,10 +1,9 @@
 """Handler for REST API call to authorized endpoint."""
 
-import asyncio
 import logging
 from typing import Any
 
-from fastapi import APIRouter, Request
+from fastapi import APIRouter, Depends
 
 from auth import get_auth_dependency
 from models.responses import AuthorizedResponse, UnauthorizedResponse, ForbiddenResponse
@@ -31,8 +30,10 @@
 
 
 @router.post("/authorized", responses=authorized_responses)
-def authorized_endpoint_handler(_request: Request) -> AuthorizedResponse:
+async def authorized_endpoint_handler(
+    auth: Any = Depends(auth_dependency),
+) -> AuthorizedResponse:
     """Handle request to the /authorized endpoint."""
     # Ignore the user token, we should not return it in the response
-    user_id, user_name, _ = asyncio.run(auth_dependency(_request))
+    user_id, user_name, _ = auth
     return AuthorizedResponse(user_id=user_id, username=user_name)
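
The handler no longer calls asyncio.run() on the auth dependency inside the request path; FastAPI awaits the dependency itself via Depends and hands the result to the async handler. A self-contained sketch of that shape (fake_auth_dependency is invented for illustration; the project builds its dependency with get_auth_dependency):

    from typing import Any

    from fastapi import Depends, FastAPI

    app = FastAPI()


    async def fake_auth_dependency() -> tuple[str, str, str]:
        # An async dependency may await I/O (token introspection, access review, ...).
        return ("user-123", "alice", "secret-token")


    @app.post("/authorized")
    async def authorized_endpoint_handler(
        auth: Any = Depends(fake_auth_dependency),
    ) -> dict[str, str]:
        # FastAPI awaits the dependency before calling the handler, so no
        # asyncio.run() is needed inside the request.
        user_id, user_name, _ = auth
        return {"user_id": user_id, "username": user_name}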

src/app/endpoints/conversations.py

Lines changed: 9 additions & 7 deletions

@@ -7,7 +7,7 @@
 
 from fastapi import APIRouter, HTTPException, status, Depends
 
-from client import LlamaStackClientHolder
+from client import AsyncLlamaStackClientHolder
 from configuration import configuration
 from models.responses import ConversationResponse, ConversationDeleteResponse
 from auth import get_auth_dependency
@@ -110,7 +110,7 @@ def simplify_session_data(session_data: dict) -> list[dict[str, Any]]:
 
 
 @router.get("/conversations/{conversation_id}", responses=conversation_responses)
-def get_conversation_endpoint_handler(
+async def get_conversation_endpoint_handler(
     conversation_id: str,
     _auth: Any = Depends(auth_dependency),
 ) -> ConversationResponse:
@@ -132,9 +132,9 @@ def get_conversation_endpoint_handler(
     logger.info("Retrieving conversation %s", conversation_id)
 
     try:
-        client = LlamaStackClientHolder().get_client()
+        client = AsyncLlamaStackClientHolder().get_client()
 
-        session_data = client.agents.session.list(agent_id=agent_id).data[0]
+        session_data = (await client.agents.session.list(agent_id=agent_id)).data[0]
 
         logger.info("Successfully retrieved conversation %s", conversation_id)
 
@@ -179,7 +179,7 @@ def get_conversation_endpoint_handler(
 @router.delete(
     "/conversations/{conversation_id}", responses=conversation_delete_responses
 )
-def delete_conversation_endpoint_handler(
+async def delete_conversation_endpoint_handler(
     conversation_id: str,
     _auth: Any = Depends(auth_dependency),
 ) -> ConversationDeleteResponse:
@@ -201,10 +201,12 @@ def delete_conversation_endpoint_handler(
 
     try:
         # Get Llama Stack client
-        client = LlamaStackClientHolder().get_client()
+        client = AsyncLlamaStackClientHolder().get_client()
         # Delete session using the conversation_id as session_id
         # In this implementation, conversation_id and session_id are the same
-        client.agents.session.delete(agent_id=agent_id, session_id=conversation_id)
+        await client.agents.session.delete(
+            agent_id=agent_id, session_id=conversation_id
+        )
 
         logger.info("Successfully deleted conversation %s", conversation_id)
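
Both conversation handlers now await the async client. A generic sketch of the await-plus-error-mapping shape (the except block below is illustrative, not copied from this file, whose own error handling lies outside the hunks shown):

    from fastapi import HTTPException, status
    from llama_stack_client import APIConnectionError, AsyncLlamaStackClient


    async def delete_session(
        client: AsyncLlamaStackClient, agent_id: str, session_id: str
    ) -> None:
        try:
            # The async client returns a coroutine; without `await` nothing is sent.
            await client.agents.session.delete(agent_id=agent_id, session_id=session_id)
        except APIConnectionError as e:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Unable to connect to Llama Stack",
            ) from e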

src/app/endpoints/models.py

Lines changed: 4 additions & 4 deletions

@@ -6,7 +6,7 @@
 from llama_stack_client import APIConnectionError
 from fastapi import APIRouter, HTTPException, Request, status
 
-from client import LlamaStackClientHolder
+from client import AsyncLlamaStackClientHolder
 from configuration import configuration
 from models.responses import ModelsResponse
 from utils.endpoints import check_configuration_loaded
@@ -43,7 +43,7 @@
 
 
 @router.get("/models", responses=models_responses)
-def models_endpoint_handler(_request: Request) -> ModelsResponse:
+async def models_endpoint_handler(_request: Request) -> ModelsResponse:
     """Handle requests to the /models endpoint."""
     check_configuration_loaded(configuration)
 
@@ -52,9 +52,9 @@ def models_endpoint_handler(_request: Request) -> ModelsResponse:
 
     try:
         # try to get Llama Stack client
-        client = LlamaStackClientHolder().get_client()
+        client = AsyncLlamaStackClientHolder().get_client()
         # retrieve models
-        models = client.models.list()
+        models = await client.models.list()
         m = [dict(m) for m in models]
         return ModelsResponse(models=m)

src/app/endpoints/query.py

Lines changed: 29 additions & 25 deletions

@@ -8,9 +8,9 @@
 from pathlib import Path
 from typing import Any
 
-from llama_stack_client.lib.agents.agent import Agent
+from llama_stack_client.lib.agents.agent import AsyncAgent
 from llama_stack_client import APIConnectionError
-from llama_stack_client import LlamaStackClient  # type: ignore
+from llama_stack_client import AsyncLlamaStackClient  # type: ignore
 from llama_stack_client.types import UserMessage, Shield  # type: ignore
 from llama_stack_client.types.agents.turn_create_params import (
     ToolgroupAgentToolGroupWithArgs,
@@ -20,7 +20,7 @@
 
 from fastapi import APIRouter, HTTPException, status, Depends
 
-from client import LlamaStackClientHolder
+from client import AsyncLlamaStackClientHolder
 from configuration import configuration
 import metrics
 from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
@@ -68,27 +68,27 @@ def is_transcripts_enabled() -> bool:
     return configuration.user_data_collection_configuration.transcripts_enabled
 
 
-def get_agent(  # pylint: disable=too-many-arguments,too-many-positional-arguments
-    client: LlamaStackClient,
+async def get_agent(  # pylint: disable=too-many-arguments,too-many-positional-arguments
+    client: AsyncLlamaStackClient,
     model_id: str,
     system_prompt: str,
     available_input_shields: list[str],
     available_output_shields: list[str],
     conversation_id: str | None,
     no_tools: bool = False,
-) -> tuple[Agent, str, str]:
+) -> tuple[AsyncAgent, str, str]:
     """Get existing agent or create a new one with session persistence."""
     existing_agent_id = None
     if conversation_id:
         with suppress(ValueError):
-            existing_agent_id = client.agents.retrieve(
-                agent_id=conversation_id
+            existing_agent_id = (
+                await client.agents.retrieve(agent_id=conversation_id)
             ).agent_id
 
     logger.debug("Creating new agent")
     # TODO(lucasagomes): move to ReActAgent
-    agent = Agent(
-        client,
+    agent = AsyncAgent(
+        client,  # type: ignore[arg-type]
         model=model_id,
         instructions=system_prompt,
         input_shields=available_input_shields if available_input_shields else [],
@@ -98,20 +98,20 @@ def get_agent(  # pylint: disable=too-many-arguments,too-many-positional-argumen
     )
     if existing_agent_id and conversation_id:
         orphan_agent_id = agent.agent_id
-        agent.agent_id = conversation_id
-        client.agents.delete(agent_id=orphan_agent_id)
-        sessions_response = client.agents.session.list(agent_id=conversation_id)
+        agent.agent_id = conversation_id  # type: ignore[misc]
+        await client.agents.delete(agent_id=orphan_agent_id)
+        sessions_response = await client.agents.session.list(agent_id=conversation_id)
         logger.info("session response: %s", sessions_response)
         session_id = str(sessions_response.data[0]["session_id"])
     else:
         conversation_id = agent.agent_id
-        session_id = agent.create_session(get_suid())
+        session_id = await agent.create_session(get_suid())
 
     return agent, conversation_id, session_id
 
 
 @router.post("/query", responses=query_response)
-def query_endpoint_handler(
+async def query_endpoint_handler(
     query_request: QueryRequest,
     auth: Any = Depends(auth_dependency),
     mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
@@ -126,11 +126,11 @@ def query_endpoint_handler(
 
     try:
         # try to get Llama Stack client
-        client = LlamaStackClientHolder().get_client()
+        client = AsyncLlamaStackClientHolder().get_client()
         model_id, provider_id = select_model_and_provider_id(
-            client.models.list(), query_request
+            await client.models.list(), query_request
         )
-        response, conversation_id = retrieve_response(
+        response, conversation_id = await retrieve_response(
             client,
             model_id,
             query_request,
@@ -250,19 +250,21 @@ def is_input_shield(shield: Shield) -> bool:
     return _is_inout_shield(shield) or not is_output_shield(shield)
 
 
-def retrieve_response(  # pylint: disable=too-many-locals
-    client: LlamaStackClient,
+async def retrieve_response(  # pylint: disable=too-many-locals
+    client: AsyncLlamaStackClient,
     model_id: str,
     query_request: QueryRequest,
     token: str,
     mcp_headers: dict[str, dict[str, str]] | None = None,
 ) -> tuple[str, str]:
     """Retrieve response from LLMs and agents."""
     available_input_shields = [
-        shield.identifier for shield in filter(is_input_shield, client.shields.list())
+        shield.identifier
+        for shield in filter(is_input_shield, await client.shields.list())
     ]
     available_output_shields = [
-        shield.identifier for shield in filter(is_output_shield, client.shields.list())
+        shield.identifier
+        for shield in filter(is_output_shield, await client.shields.list())
     ]
     if not available_input_shields and not available_output_shields:
         logger.info("No available shields. Disabling safety")
@@ -281,7 +283,7 @@ def retrieve_response(  # pylint: disable=too-many-locals
     if query_request.attachments:
         validate_attachments_metadata(query_request.attachments)
 
-    agent, conversation_id, session_id = get_agent(
+    agent, conversation_id, session_id = await get_agent(
         client,
         model_id,
         system_prompt,
@@ -315,15 +317,17 @@ def retrieve_response(  # pylint: disable=too-many-locals
         ),
     }
 
-    vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
+    vector_db_ids = [
+        vector_db.identifier for vector_db in await client.vector_dbs.list()
+    ]
     toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
         mcp_server.name for mcp_server in configuration.mcp_servers
     ]
     # Convert empty list to None for consistency with existing behavior
     if not toolgroups:
         toolgroups = None
 
-    response = agent.create_turn(
+    response = await agent.create_turn(
         messages=[UserMessage(role="user", content=query_request.query)],
         session_id=session_id,
         documents=query_request.get_documents(),
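
get_agent() and retrieve_response() now drive AsyncAgent end to end: retrieve or create the agent, create a session, then create a turn, awaiting each step. A condensed, hypothetical walk-through of that flow; the model id, prompt, session name, and stream flag are placeholders, and shields, toolgroups, and error handling are omitted:

    from llama_stack_client import AsyncLlamaStackClient
    from llama_stack_client.lib.agents.agent import AsyncAgent
    from llama_stack_client.types import UserMessage


    async def one_turn(client: AsyncLlamaStackClient, question: str):
        agent = AsyncAgent(
            client,  # type: ignore[arg-type]
            model="my-model-id",  # placeholder model identifier
            instructions="You are a helpful assistant.",  # placeholder prompt
        )
        # Every AsyncAgent method that talks to the server must be awaited.
        session_id = await agent.create_session("example-session")
        response = await agent.create_turn(
            messages=[UserMessage(role="user", content=question)],
            session_id=session_id,
            stream=False,  # assumed here so a single, complete turn is returned
        )
        return response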

src/client.py

Lines changed: 1 addition & 38 deletions

@@ -6,52 +6,15 @@
 
 from llama_stack.distribution.library_client import (
     AsyncLlamaStackAsLibraryClient,  # type: ignore
-    LlamaStackAsLibraryClient,  # type: ignore
 )
-from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient  # type: ignore
+from llama_stack_client import AsyncLlamaStackClient  # type: ignore
 from models.config import LlamaStackConfiguration
 from utils.types import Singleton
 
 
 logger = logging.getLogger(__name__)
 
 
-class LlamaStackClientHolder(metaclass=Singleton):
-    """Container for an initialised LlamaStackClient."""
-
-    _lsc: Optional[LlamaStackClient] = None
-
-    def load(self, llama_stack_config: LlamaStackConfiguration) -> None:
-        """Retrieve Llama stack client according to configuration."""
-        if llama_stack_config.use_as_library_client is True:
-            if llama_stack_config.library_client_config_path is not None:
-                logger.info("Using Llama stack as library client")
-                client = LlamaStackAsLibraryClient(
-                    llama_stack_config.library_client_config_path
-                )
-                client.initialize()
-                self._lsc = client
-            else:
-                msg = "Configuration problem: library_client_config_path option is not set"
-                logger.error(msg)
-                # tisnik: use custom exception there - with cause etc.
-                raise ValueError(msg)
-
-        else:
-            logger.info("Using Llama stack running as a service")
-            self._lsc = LlamaStackClient(
-                base_url=llama_stack_config.url, api_key=llama_stack_config.api_key
-            )
-
-    def get_client(self) -> LlamaStackClient:
-        """Return an initialised LlamaStackClient."""
-        if not self._lsc:
-            raise RuntimeError(
-                "LlamaStackClient has not been initialised. Ensure 'load(..)' has been called."
-            )
-        return self._lsc
-
-
 class AsyncLlamaStackClientHolder(metaclass=Singleton):
     """Container for an initialised AsyncLlamaStackClient."""
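
Only the async holder remains. A self-contained illustration of the holder pattern it uses, a Singleton metaclass plus a load()/get_client() pair; the metaclass here is assumed to mirror utils.types.Singleton, and the "client" is a plain string so the sketch runs anywhere:

    import asyncio
    from typing import Any, Optional


    class Singleton(type):
        """Return one shared instance per class (assumed stand-in for utils.types.Singleton)."""

        _instances: dict[type, Any] = {}

        def __call__(cls, *args: Any, **kwargs: Any) -> Any:
            if cls not in cls._instances:
                cls._instances[cls] = super().__call__(*args, **kwargs)
            return cls._instances[cls]


    class AsyncClientHolder(metaclass=Singleton):
        """Container for a lazily initialised async client (stand-in for AsyncLlamaStackClientHolder)."""

        _client: Optional[str] = None

        async def load(self, base_url: str) -> None:
            await asyncio.sleep(0)  # pretend to do async setup
            self._client = f"async-client({base_url})"

        def get_client(self) -> str:
            if not self._client:
                raise RuntimeError("Client has not been initialised. Call 'load(...)' first.")
            return self._client


    async def main() -> None:
        await AsyncClientHolder().load("http://localhost:8321")
        # The same instance is returned everywhere, so get_client() sees the load.
        print(AsyncClientHolder().get_client())


    asyncio.run(main())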

src/lightspeed_stack.py

Lines changed: 1 addition & 3 deletions

@@ -12,7 +12,7 @@
 from runners.uvicorn import start_uvicorn
 from runners.data_collector import start_data_collector
 from configuration import configuration
-from client import LlamaStackClientHolder, AsyncLlamaStackClientHolder
+from client import AsyncLlamaStackClientHolder
 
 FORMAT = "%(message)s"
 logging.basicConfig(
@@ -69,8 +69,6 @@ def main() -> None:
     logger.info(
         "Llama stack configuration: %s", configuration.llama_stack_configuration
     )
-    logger.info("Creating LlamaStackClient")
-    LlamaStackClientHolder().load(configuration.configuration.llama_stack)
     logger.info("Creating AsyncLlamaStackClient")
     asyncio.run(
         AsyncLlamaStackClientHolder().load(configuration.configuration.llama_stack)

src/metrics/utils.py

Lines changed: 2 additions & 6 deletions

@@ -1,7 +1,7 @@
 """Utility functions for metrics handling."""
 
 from configuration import configuration
-from client import LlamaStackClientHolder, AsyncLlamaStackClientHolder
+from client import AsyncLlamaStackClientHolder
 from log import get_logger
 import metrics
 from utils.common import run_once_async
@@ -13,11 +13,7 @@
 async def setup_model_metrics() -> None:
     """Perform setup of all metrics related to LLM model and provider."""
     logger.info("Setting up model metrics")
-    model_list = []
-    if configuration.llama_stack_configuration.use_as_library_client:
-        model_list = await AsyncLlamaStackClientHolder().get_client().models.list()
-    else:
-        model_list = LlamaStackClientHolder().get_client().models.list()
+    model_list = await AsyncLlamaStackClientHolder().get_client().models.list()
 
     models = [
         model
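
setup_model_metrics() keeps its run_once_async import from utils.common. Purely as illustration, a guess at what such a once-only guard for a coroutine function can look like; the repository's actual helper may differ:

    import asyncio
    from collections.abc import Awaitable, Callable
    from functools import wraps
    from typing import Any


    def run_once_async(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
        """Run the wrapped coroutine function on the first call only; later calls are no-ops."""
        has_run = False

        @wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            nonlocal has_run
            if has_run:
                return None
            has_run = True
            return await func(*args, **kwargs)

        return wrapper


    @run_once_async
    async def setup() -> None:
        print("expensive setup runs once")


    async def main() -> None:
        await setup()
        await setup()  # second call is skipped


    asyncio.run(main())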

src/models/config.py

Lines changed: 1 addition & 1 deletion

@@ -92,7 +92,7 @@ def check_llama_stack_model(self) -> Self:
         if self.library_client_config_path is None:
             # pylint: disable=line-too-long
             raise ValueError(
-                "LLama stack library client mode is enabled but a configuration file path is not specified"  # noqa: C0301
+                "LLama stack library client mode is enabled but a configuration file path is not specified"  # noqa: E501
             )
         # the configuration file must exists and be regular readable file
         checks.file_check(
