Skip to content

Commit 4bbda24

Browse files
authored
Merge pull request #287 from tisnik/lcore-381-new-unit-tests
LCORE-381: new unit tests for client.py and models endpoints
2 parents 6459aac + 6494660 commit 4bbda24

File tree

2 files changed

+88
-3
lines changed

2 files changed

+88
-3
lines changed

tests/unit/app/endpoints/test_models.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""Unit tests for the /models REST API endpoint."""
22

3-
from unittest.mock import Mock
4-
53
import pytest
64

75
from fastapi import HTTPException, Request, status
86

7+
from llama_stack_client import APIConnectionError
8+
99
from app.endpoints.models import models_endpoint_handler
1010
from configuration import AppConfig
1111

@@ -142,7 +142,7 @@ def test_models_endpoint_handler_unable_to_retrieve_models_list(mocker):
142142
cfg.init_from_dict(config_dict)
143143

144144
# Mock the LlamaStack client
145-
mock_client = Mock()
145+
mock_client = mocker.Mock()
146146
mock_client.models.list.return_value = []
147147
mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client")
148148
mock_lsc.return_value = mock_client
@@ -157,3 +157,50 @@ def test_models_endpoint_handler_unable_to_retrieve_models_list(mocker):
157157
)
158158
response = models_endpoint_handler(request)
159159
assert response is not None
160+
161+
162+
def test_models_endpoint_llama_stack_connection_error(mocker):
    """Test the /models endpoint when the Llama Stack connection fails.

    The mocked client raises APIConnectionError from models.list(); the
    handler is expected to translate that into an HTTP 500 with a
    human-readable response detail.
    """
    # Minimal app configuration for the test (remote llama-stack client).
    config_dict = {
        "name": "foo",
        "service": {
            "host": "localhost",
            "port": 8080,
            "auth_enabled": False,
            "workers": 1,
            "color_log": True,
            "access_log": True,
        },
        "llama_stack": {
            "api_key": "xyzzy",
            "url": "http://x.y.com:1234",
            "use_as_library_client": False,
        },
        "user_data_collection": {
            "feedback_enabled": False,
        },
        "customization": None,
    }

    # Mock LlamaStackClientHolder so that models.list() raises
    # APIConnectionError when called by the handler.
    mock_client = mocker.Mock()
    mock_client.models.list.side_effect = APIConnectionError(request=None)
    mock_client_holder = mocker.patch("app.endpoints.models.LlamaStackClientHolder")
    mock_client_holder.return_value.get_client.return_value = mock_client

    cfg = AppConfig()
    cfg.init_from_dict(config_dict)

    request = Request(
        scope={
            "type": "http",
            "headers": [(b"authorization", b"Bearer invalid-token")],
        }
    )

    with pytest.raises(HTTPException) as e:
        models_endpoint_handler(request)
    # pytest.raises yields an ExceptionInfo wrapper; the actual raised
    # HTTPException is e.value (e.status_code/e.detail would be an
    # AttributeError and the asserts would never check anything).
    assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert e.value.detail["response"] == "Unable to connect to Llama Stack"

tests/unit/test_client.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,32 @@
66
from models.config import LlamaStackConfiguration
77

88

9+
def test_client_get_client_method() -> None:
    """Test that get_client raises RuntimeError for an uninitialized client."""

    client = LlamaStackClientHolder()

    # get_client() must fail loudly until load(..) has been called.
    with pytest.raises(
        RuntimeError,
        match="LlamaStackClient has not been initialised. Ensure 'load\\(..\\)' has been called.",
    ):
        client.get_client()
19+
20+
21+
def test_async_client_get_client_method() -> None:
    """Test that get_client raises RuntimeError for an uninitialized async client."""
    client = AsyncLlamaStackClientHolder()

    # get_client() must fail loudly until load(..) has been called.
    with pytest.raises(
        RuntimeError,
        match=(
            "AsyncLlamaStackClient has not been initialised. "
            "Ensure 'load\\(..\\)' has been called."
        ),
    ):
        client.get_client()
33+
34+
935
def test_get_llama_stack_library_client() -> None:
1036
"""Test if Llama Stack can be initialized in library client mode."""
1137
cfg = LlamaStackConfiguration(
@@ -18,6 +44,9 @@ def test_get_llama_stack_library_client() -> None:
1844
client.load(cfg)
1945
assert client is not None
2046

47+
ls_client = client.get_client()
48+
assert ls_client is not None
49+
2150

2251
def test_get_llama_stack_remote_client() -> None:
2352
"""Test if Llama Stack can be initialized in remove client (server) mode."""
@@ -31,6 +60,9 @@ def test_get_llama_stack_remote_client() -> None:
3160
client.load(cfg)
3261
assert client is not None
3362

63+
ls_client = client.get_client()
64+
assert ls_client is not None
65+
3466

3567
def test_get_llama_stack_wrong_configuration() -> None:
3668
"""Test if configuration is checked before Llama Stack is initialized."""
@@ -61,6 +93,9 @@ async def test_get_async_llama_stack_library_client() -> None:
6193
await client.load(cfg)
6294
assert client is not None
6395

96+
ls_client = client.get_client()
97+
assert ls_client is not None
98+
6499

65100
async def test_get_async_llama_stack_remote_client() -> None:
66101
"""Test the initialization of asynchronous Llama Stack client in server mode."""
@@ -74,6 +109,9 @@ async def test_get_async_llama_stack_remote_client() -> None:
74109
await client.load(cfg)
75110
assert client is not None
76111

112+
ls_client = client.get_client()
113+
assert ls_client is not None
114+
77115

78116
async def test_get_async_llama_stack_wrong_configuration() -> None:
79117
"""Test if configuration is checked before Llama Stack is initialized."""

0 commit comments

Comments
 (0)