
Commit 9de1a02

Adding unit test coverage for streaming_query_v2 (cursor generated)
1 parent 9509ae5 commit 9de1a02

2 files changed: +214 −11 lines


tests/unit/app/endpoints/test_query_v2.py

Lines changed: 3 additions & 11 deletions
@@ -70,9 +70,7 @@ async def test_retrieve_response_no_tools_bypasses_tools(mocker):
     mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
 
     qr = QueryRequest(query="hello", no_tools=True)
-    summary, conv_id = await retrieve_response(
-        mock_client, "model-x", qr, token="tkn"
-    )
+    summary, conv_id = await retrieve_response(mock_client, "model-x", qr, token="tkn")
 
     assert conv_id == "resp-1"
     assert summary.llm_response == ""
@@ -148,9 +146,7 @@ async def test_retrieve_response_parses_output_and_tool_calls(mocker):
     mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
 
     qr = QueryRequest(query="hello")
-    summary, conv_id = await retrieve_response(
-        mock_client, "model-z", qr, token="tkn"
-    )
+    summary, conv_id = await retrieve_response(mock_client, "model-z", qr, token="tkn")
 
     assert conv_id == "resp-3"
     assert summary.llm_response == "Hello world!"
@@ -181,9 +177,7 @@ async def test_retrieve_response_validates_attachments(mocker):
     ]
 
     qr = QueryRequest(query="hello", attachments=attachments)
-    _summary, _cid = await retrieve_response(
-        mock_client, "model-a", qr, token="tkn"
-    )
+    _summary, _cid = await retrieve_response(mock_client, "model-a", qr, token="tkn")
 
     validate_spy.assert_called_once()
 
@@ -257,5 +251,3 @@ def _raise(*_args, **_kwargs):
     assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
     assert "Unable to connect to Llama Stack" in str(exc.value.detail)
     fail_metric.inc.assert_called_once()
-
-
tests/unit/app/endpoints/test_streaming_query_v2.py

Lines changed: 211 additions & 0 deletions
@@ -0,0 +1,211 @@
+# pylint: disable=redefined-outer-name, import-error
+"""Unit tests for the /streaming_query (v2) endpoint using Responses API."""
+
+from types import SimpleNamespace
+import pytest
+from fastapi import HTTPException, status, Request
+from fastapi.responses import StreamingResponse
+
+from llama_stack_client import APIConnectionError
+
+from models.requests import QueryRequest
+from models.config import ModelContextProtocolServer
+
+from app.endpoints.streaming_query_v2 import (
+    retrieve_response,
+    streaming_query_endpoint_handler_v2,
+)
+
+
+@pytest.fixture
+def dummy_request() -> Request:
+    req = Request(scope={"type": "http"})
+    # Provide a permissive authorized_actions set to satisfy RBAC check
+    from models.config import Action  # import here to avoid global import errors
+
+    req.state.authorized_actions = set(Action)
+    return req
+
+
+@pytest.mark.asyncio
+async def test_retrieve_response_builds_rag_and_mcp_tools(mocker):
+    mock_client = mocker.Mock()
+    mock_client.vector_dbs.list = mocker.AsyncMock(
+        return_value=[mocker.Mock(identifier="db1")]
+    )
+    mock_client.responses.create = mocker.AsyncMock(return_value=mocker.Mock())
+
+    mocker.patch(
+        "app.endpoints.streaming_query_v2.get_system_prompt", return_value="PROMPT"
+    )
+
+    mock_cfg = mocker.Mock()
+    mock_cfg.mcp_servers = [
+        ModelContextProtocolServer(name="fs", url="http://localhost:3000"),
+    ]
+    mocker.patch("app.endpoints.streaming_query_v2.configuration", mock_cfg)
+
+    qr = QueryRequest(query="hello")
+    await retrieve_response(mock_client, "model-z", qr, token="tok")
+
+    kwargs = mock_client.responses.create.call_args.kwargs
+    assert kwargs["stream"] is True
+    tools = kwargs["tools"]
+    assert isinstance(tools, list)
+    types = {t.get("type") for t in tools}
+    assert types == {"file_search", "mcp"}
+
+
+@pytest.mark.asyncio
+async def test_retrieve_response_no_tools_passes_none(mocker):
+    mock_client = mocker.Mock()
+    mock_client.vector_dbs.list = mocker.AsyncMock(return_value=[])
+    mock_client.responses.create = mocker.AsyncMock(return_value=mocker.Mock())
+
+    mocker.patch(
+        "app.endpoints.streaming_query_v2.get_system_prompt", return_value="PROMPT"
+    )
+    mocker.patch(
+        "app.endpoints.streaming_query_v2.configuration", mocker.Mock(mcp_servers=[])
+    )
+
+    qr = QueryRequest(query="hello", no_tools=True)
+    await retrieve_response(mock_client, "model-z", qr, token="tok")
+
+    kwargs = mock_client.responses.create.call_args.kwargs
+    assert kwargs["tools"] is None
+    assert kwargs["stream"] is True
+
+
+@pytest.mark.asyncio
+async def test_streaming_query_endpoint_handler_v2_success_yields_events(
+    mocker, dummy_request
+):
+    # Skip real config checks
+    mocker.patch("app.endpoints.streaming_query_v2.check_configuration_loaded")
+
+    # Model selection plumbing
+    mock_client = mocker.Mock()
+    mock_client.models.list = mocker.AsyncMock(return_value=[mocker.Mock()])
+    mocker.patch(
+        "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client
+    )
+    mocker.patch(
+        "app.endpoints.streaming_query_v2.evaluate_model_hints",
+        return_value=(None, None),
+    )
+    mocker.patch(
+        "app.endpoints.streaming_query_v2.select_model_and_provider_id",
+        return_value=("llama/m", "m", "p"),
+    )
+
+    # Replace SSE helpers for deterministic output
+    mocker.patch(
+        "app.endpoints.streaming_query_v2.stream_start_event",
+        lambda conv_id: f"START:{conv_id}\n",
+    )
+    mocker.patch(
+        "app.endpoints.streaming_query_v2.format_stream_data",
+        lambda obj: f"EV:{obj['event']}:{obj['data'].get('token','')}\n",
+    )
+    mocker.patch(
+        "app.endpoints.streaming_query_v2.stream_end_event", lambda _m: "END\n"
+    )
+
+    # Conversation persistence and transcripts disabled
+    persist_spy = mocker.patch(
+        "app.endpoints.streaming_query_v2.persist_user_conversation_details",
+        return_value=None,
+    )
+    mocker.patch(
+        "app.endpoints.streaming_query_v2.is_transcripts_enabled", return_value=False
+    )
+
+    # Build a fake async stream of chunks
+    async def fake_stream():
+        yield SimpleNamespace(
+            type="response.created", response=SimpleNamespace(id="conv-xyz")
+        )
+        yield SimpleNamespace(type="response.content_part.added")
+        yield SimpleNamespace(type="response.output_text.delta", delta="Hello ")
+        yield SimpleNamespace(type="response.output_text.delta", delta="world")
+        yield SimpleNamespace(
+            type="response.output_item.added",
+            item=SimpleNamespace(
+                type="function_call", id="item1", name="search", call_id="call1"
+            ),
+        )
+        yield SimpleNamespace(
+            type="response.function_call_arguments.delta", delta='{"q":"x"}'
+        )
+        yield SimpleNamespace(
+            type="response.function_call_arguments.done",
+            item_id="item1",
+            arguments='{"q":"x"}',
+        )
+        yield SimpleNamespace(type="response.output_text.done", text="Hello world")
+        yield SimpleNamespace(type="response.completed")
+
+    mocker.patch(
+        "app.endpoints.streaming_query_v2.retrieve_response",
+        return_value=(fake_stream(), ""),
+    )
+
+    metric = mocker.patch("metrics.llm_calls_total")
+
+    resp = await streaming_query_endpoint_handler_v2(
+        request=dummy_request,
+        query_request=QueryRequest(query="hi"),
+        auth=("user123", "", False, "token-abc"),
+        mcp_headers={},
+    )
+
+    assert isinstance(resp, StreamingResponse)
+    metric.labels("p", "m").inc.assert_called_once()
+
+    # Collect emitted events
+    events: list[str] = []
+    async for chunk in resp.body_iterator:
+        s = chunk.decode() if isinstance(chunk, (bytes, bytearray)) else str(chunk)
+        events.append(s)
+
+    # Validate event sequence and content
+    assert events[0] == "START:conv-xyz\n"
+    # content_part.added triggers empty token
+    assert events[1] == "EV:token:\n"
+    assert events[2] == "EV:token:Hello \n"
+    assert events[3] == "EV:token:world\n"
+    # tool call delta
+    assert events[4].startswith("EV:tool_call:")
+    # turn complete and end
+    assert "EV:turn_complete:Hello world\n" in events
+    assert events[-1] == "END\n"
+
+    # Verify conversation persistence was invoked with the created id
+    persist_spy.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_streaming_query_endpoint_handler_v2_api_connection_error(
+    mocker, dummy_request
+):
+    mocker.patch("app.endpoints.streaming_query_v2.check_configuration_loaded")
+
+    def _raise(*_a, **_k):
+        raise APIConnectionError(request=None)
+
+    mocker.patch("client.AsyncLlamaStackClientHolder.get_client", side_effect=_raise)
+
+    fail_metric = mocker.patch("metrics.llm_calls_failures_total")
+
+    with pytest.raises(HTTPException) as exc:
+        await streaming_query_endpoint_handler_v2(
+            request=dummy_request,
+            query_request=QueryRequest(query="hi"),
+            auth=("user123", "", False, "tok"),
+            mcp_headers={},
+        )
+
+    assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
+    assert "Unable to connect to Llama Stack" in str(exc.value.detail)
+    fail_metric.inc.assert_called_once()
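
Note: the suite relies on the pytest-asyncio plugin (for the @pytest.mark.asyncio marker) and pytest-mock (for the mocker fixture). A minimal way to run just these tests locally, assuming the new file lands at the path shown above and both plugins are installed:

    pytest tests/unit/app/endpoints/test_streaming_query_v2.py -v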
