
Commit cfa516a

(feat) ensure compatibility with OLS streaming query
1 parent 528e0b9 commit cfa516a

File tree

3 files changed: +183 -68 lines changed

src/app/endpoints/streaming_query.py

Lines changed: 158 additions & 64 deletions
@@ -29,6 +29,7 @@
 from metrics.utils import update_llm_token_count_from_turn
 from models.config import Action
 from models.requests import QueryRequest
+from models.responses import StreamedChunk
 from models.database.conversations import UserConversation
 from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
 from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
@@ -48,6 +49,11 @@
     evaluate_model_hints,
 )

+# Constants for OLS-compatible event types
+LLM_TOKEN_EVENT = "token"
+LLM_TOOL_CALL_EVENT = "tool_call"
+LLM_TOOL_RESULT_EVENT = "tool_result"
+
 logger = logging.getLogger("app.endpoints.handlers")
 router = APIRouter(tags=["streaming_query"])
 auth_dependency = get_auth_dependency()
@@ -94,41 +100,73 @@ def stream_start_event(conversation_id: str) -> str:
     )


-def stream_end_event(metadata_map: dict) -> str:
+def stream_end_event(
+    ref_docs: list[dict],
+    truncated: bool,
+    media_type: str,
+    token_counter: dict | None = None,
+    available_quotas: dict[str, int] | None = None,
+) -> str:
     """
     Yield the end of the data stream.

     Format and return the end event for a streaming response,
-    including referenced document metadata and placeholder token
-    counts.
+    including referenced document metadata and token counts.

     Parameters:
-        metadata_map (dict): A mapping containing metadata about
-        referenced documents.
+        ref_docs: Referenced documents.
+        truncated: Indicates if the history was truncated.
+        media_type: Media type of the response (e.g. text or JSON).
+        token_counter: Token counter for the whole stream.
+        available_quotas: Quotas available for configured quota limiters.

     Returns:
         str: A Server-Sent Events (SSE) formatted string
         representing the end of the data stream.
     """
+    if media_type == "application/json":
+        return format_stream_data(
+            {
+                "event": "end",
+                "data": {
+                    "referenced_documents": ref_docs,
+                    "truncated": truncated,
+                    "input_tokens": token_counter.get("input_tokens", 0) if token_counter else 0,
+                    "output_tokens": token_counter.get("output_tokens", 0) if token_counter else 0,
+                },
+                "available_quotas": available_quotas or {},
+            }
+        )
+    ref_docs_string = "\n".join(
+        f'{item["doc_title"]}: {item["doc_url"]}' for item in ref_docs
+    )
+    return f"\n\n---\n\n{ref_docs_string}" if ref_docs_string else ""
+
+
+def stream_event(data: dict, event_type: str, media_type: str) -> str:
+    """Build an item to yield based on media type.
+
+    Args:
+        data: The data to yield.
+        event_type: The type of event (e.g. token, tool request, tool execution).
+        media_type: Media type of the response (e.g. text or JSON).
+
+    Returns:
+        str: The formatted string or JSON to yield.
+    """
+    if media_type == "text/plain":
+        if event_type == LLM_TOKEN_EVENT:
+            return data["token"]
+        if event_type == LLM_TOOL_CALL_EVENT:
+            return f"\nTool call: {json.dumps(data)}\n"
+        if event_type == LLM_TOOL_RESULT_EVENT:
+            return f"\nTool result: {json.dumps(data)}\n"
+        logger.error("Unknown event type: %s", event_type)
+        return ""
     return format_stream_data(
         {
-            "event": "end",
-            "data": {
-                "referenced_documents": [
-                    {
-                        "doc_url": v["docs_url"],
-                        "doc_title": v["title"],
-                    }
-                    for v in filter(
-                        lambda v: ("docs_url" in v) and ("title" in v),
-                        metadata_map.values(),
-                    )
-                ],
-                "truncated": None,  # TODO(jboos): implement truncated
-                "input_tokens": 0,  # TODO(jboos): implement input tokens
-                "output_tokens": 0,  # TODO(jboos): implement output tokens
-            },
-            "available_quotas": {},  # TODO(jboos): implement available quotas
+            "event": event_type,
+            "data": data,
         }
     )

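For reference: a minimal sketch of the two shapes stream_end_event now produces, assuming format_stream_data renders its argument as one SSE "data: <json>" frame (that helper is defined elsewhere in this module; the stand-in below is illustrative):

    # Sketch only: stand-in for the module's format_stream_data helper,
    # assumed to emit one SSE frame per event.
    import json

    def format_stream_data(d: dict) -> str:
        return f"data: {json.dumps(d)}\n\n"

    ref_docs = [{"doc_title": "Operator docs", "doc_url": "https://example.com/ops"}]

    # application/json -> one structured "end" event
    print(format_stream_data({"event": "end", "data": {"referenced_documents": ref_docs}}))

    # text/plain -> a plain footer listing the references
    print("\n\n---\n\n" + "\n".join(f'{d["doc_title"]}: {d["doc_url"]}' for d in ref_docs))
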
@@ -203,6 +241,35 @@ def _handle_error_event(chunk: Any, chunk_id: int) -> Iterator[str]:
     )


+def generic_llm_error(error: Exception, media_type: str) -> str:
+    """Return error representation for generic LLM errors.
+
+    Args:
+        error: The exception raised during processing.
+        media_type: Media type of the response (e.g. text or JSON).
+
+    Returns:
+        str: The error message formatted for the media type.
+    """
+    logger.error("Error while obtaining answer for user question")
+    logger.exception(error)
+
+    response = "Error while obtaining answer for user question"
+    cause = str(error)
+
+    if media_type == "text/plain":
+        return f"{response}: {cause}"
+    return format_stream_data(
+        {
+            "event": "error",
+            "data": {
+                "response": response,
+                "cause": cause,
+            },
+        }
+    )
+
+
 # -----------------------------------
 # Turn handling
 # -----------------------------------
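For reference: a sketch of how generic_llm_error reaches clients in each mode (same assumed SSE framing as above; the exception is illustrative):

    import json

    response = "Error while obtaining answer for user question"
    cause = str(TimeoutError("model backend unreachable"))  # hypothetical failure

    # text/plain -> a single line: "<response>: <cause>"
    print(f"{response}: {cause}")

    # application/json -> one SSE "error" event
    print(f'data: {json.dumps({"event": "error", "data": {"response": response, "cause": cause}})}')
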
@@ -223,7 +290,7 @@ def _handle_turn_start_event(chunk_id: int) -> Iterator[str]:
     """
     yield format_stream_data(
         {
-            "event": "token",
+            "event": LLM_TOKEN_EVENT,
             "data": {
                 "id": chunk_id,
                 "token": "",
@@ -322,10 +389,9 @@ def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]:
     if chunk.event.payload.event_type == "step_start":
         yield format_stream_data(
             {
-                "event": "token",
+                "event": LLM_TOKEN_EVENT,
                 "data": {
                     "id": chunk_id,
-                    "role": chunk.event.payload.step_type,
                     "token": "",
                 },
             }
@@ -336,21 +402,19 @@ def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]:
         if isinstance(chunk.event.payload.delta.tool_call, str):
             yield format_stream_data(
                 {
-                    "event": "tool_call",
+                    "event": LLM_TOOL_CALL_EVENT,
                     "data": {
                         "id": chunk_id,
-                        "role": chunk.event.payload.step_type,
                         "token": chunk.event.payload.delta.tool_call,
                     },
                 }
             )
         elif isinstance(chunk.event.payload.delta.tool_call, ToolCall):
             yield format_stream_data(
                 {
-                    "event": "tool_call",
+                    "event": LLM_TOOL_CALL_EVENT,
                     "data": {
                         "id": chunk_id,
-                        "role": chunk.event.payload.step_type,
                         "token": chunk.event.payload.delta.tool_call.tool_name,
                     },
                 }
@@ -359,10 +423,9 @@ def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]:
         elif chunk.event.payload.delta.type == "text":
             yield format_stream_data(
                 {
-                    "event": "token",
+                    "event": LLM_TOKEN_EVENT,
                     "data": {
                         "id": chunk_id,
-                        "role": chunk.event.payload.step_type,
                         "token": chunk.event.payload.delta.text,
                     },
                 }
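For reference: with the "role" field dropped and event names taken from the new constants, a short answer in JSON mode arrives as a run of token frames like these (SSE framing assumed, tokens illustrative); the first empty token comes from step_start:

    data: {"event": "token", "data": {"id": 0, "token": ""}}
    data: {"event": "token", "data": {"id": 1, "token": "Hi"}}
    data: {"event": "token", "data": {"id": 2, "token": "!"}}
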
@@ -377,7 +440,7 @@ def _handle_tool_execution_event(
     chunk: Any, chunk_id: int, metadata_map: dict
 ) -> Iterator[str]:
     """
-    Yield tool call event.
+    Yield tool call and tool result events.

     Processes tool execution events from a streaming chunk and
     yields formatted Server-Sent Events (SSE) strings.
@@ -399,23 +462,22 @@ def _handle_tool_execution_event(
     if chunk.event.payload.event_type == "step_start":
         yield format_stream_data(
             {
-                "event": "tool_call",
+                "event": LLM_TOOL_CALL_EVENT,
                 "data": {
                     "id": chunk_id,
-                    "role": chunk.event.payload.step_type,
                     "token": "",
                 },
             }
         )

     elif chunk.event.payload.event_type == "step_complete":
+        # First yield tool calls
         for t in chunk.event.payload.step_details.tool_calls:
             yield format_stream_data(
                 {
-                    "event": "tool_call",
+                    "event": LLM_TOOL_CALL_EVENT,
                     "data": {
                         "id": chunk_id,
-                        "role": chunk.event.payload.step_type,
                         "token": {
                             "tool_name": t.tool_name,
                             "arguments": t.arguments,
@@ -424,15 +486,15 @@ def _handle_tool_execution_event(
                 }
             )

+        # Then yield tool results
         for r in chunk.event.payload.step_details.tool_responses:
             if r.tool_name == "query_from_memory":
                 inserted_context = interleaved_content_as_str(r.content)
                 yield format_stream_data(
                     {
-                        "event": "tool_call",
+                        "event": LLM_TOOL_RESULT_EVENT,
                         "data": {
                             "id": chunk_id,
-                            "role": chunk.event.payload.step_type,
                             "token": {
                                 "tool_name": r.tool_name,
                                 "response": f"Fetched {len(inserted_context)} bytes from memory",
@@ -463,10 +525,9 @@ def _handle_tool_execution_event(

                 yield format_stream_data(
                     {
-                        "event": "tool_call",
+                        "event": LLM_TOOL_RESULT_EVENT,
                         "data": {
                             "id": chunk_id,
-                            "role": chunk.event.payload.step_type,
                             "token": {
                                 "tool_name": r.tool_name,
                                 "summary": summary,
@@ -478,10 +539,9 @@ def _handle_tool_execution_event(
             else:
                 yield format_stream_data(
                     {
-                        "event": "tool_call",
+                        "event": LLM_TOOL_RESULT_EVENT,
                         "data": {
                             "id": chunk_id,
-                            "role": chunk.event.payload.step_type,
                             "token": {
                                 "tool_name": r.tool_name,
                                 "response": interleaved_content_as_str(r.content),
@@ -614,32 +674,61 @@ async def response_generator(
         summary = TurnSummary(
             llm_response="No response from the model", tool_calls=[]
         )
-
+
+        # Determine media type for OLS compatibility
+        media_type = query_request.media_type or "application/json"
+
         # Send start event
         yield stream_start_event(conversation_id)

-        async for chunk in turn_response:
-            p = chunk.event.payload
-            if p.event_type == "turn_complete":
-                summary.llm_response = interleaved_content_as_str(
-                    p.turn.output_message.content
-                )
-                system_prompt = get_system_prompt(query_request, configuration)
-                try:
-                    update_llm_token_count_from_turn(
-                        p.turn, model_id, provider_id, system_prompt
+        # Track referenced documents and token counts
+        ref_docs = []
+        token_counter = {"input_tokens": 0, "output_tokens": 0}
+        truncated = False
+
+        try:
+            async for chunk in turn_response:
+                p = chunk.event.payload
+                if p.event_type == "turn_complete":
+                    summary.llm_response = interleaved_content_as_str(
+                        p.turn.output_message.content
                     )
-                except Exception:  # pylint: disable=broad-except
-                    logger.exception("Failed to update token usage metrics")
-            elif p.event_type == "step_complete":
-                if p.step_details.step_type == "tool_execution":
-                    summary.append_tool_calls_from_llama(p.step_details)
-
-            for event in stream_build_event(chunk, chunk_id, metadata_map):
-                chunk_id += 1
-                yield event
+                    system_prompt = get_system_prompt(query_request, configuration)
+                    try:
+                        update_llm_token_count_from_turn(
+                            p.turn, model_id, provider_id, system_prompt
+                        )
+                        # Extract token counts from the turn if available
+                        if hasattr(p.turn, "usage") and p.turn.usage:
+                            token_counter["input_tokens"] = getattr(p.turn.usage, "prompt_tokens", 0)
+                            token_counter["output_tokens"] = getattr(p.turn.usage, "completion_tokens", 0)
+                    except Exception:  # pylint: disable=broad-except
+                        logger.exception("Failed to update token usage metrics")
+                elif p.event_type == "step_complete":
+                    if p.step_details.step_type == "tool_execution":
+                        summary.append_tool_calls_from_llama(p.step_details)
+
+                for event in stream_build_event(chunk, chunk_id, metadata_map):
+                    chunk_id += 1
+                    yield event
+        except Exception as e:  # pylint: disable=broad-except
+            # Handle streaming errors
+            yield generic_llm_error(e, media_type)
+            return
+
+        # Build referenced documents from metadata
+        ref_docs = [
+            {
+                "doc_url": v["docs_url"],
+                "doc_title": v["title"],
+            }
+            for v in filter(
+                lambda v: ("docs_url" in v) and ("title" in v),
+                metadata_map.values(),
+            )
+        ]

-        yield stream_end_event(metadata_map)
+        yield stream_end_event(ref_docs, truncated, media_type, token_counter, {})

         if not is_transcripts_enabled():
             logger.debug("Transcript collection is disabled in the configuration")
@@ -669,7 +758,12 @@ async def response_generator(
         # Update metrics for the LLM call
         metrics.llm_calls_total.labels(provider_id, model_id).inc()

-        return StreamingResponse(response_generator(response))
+        # Determine media type for OLS compatibility
+        media_type = query_request.media_type or "application/json"
+        return StreamingResponse(
+            response_generator(response),
+            media_type=media_type
+        )
     # connection to Llama Stack server
     except APIConnectionError as e:
         # Update metrics for the LLM call failure
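For reference: a client-side sketch for exercising either media type (the endpoint path and request field names are assumptions, not taken from this diff):

    import asyncio
    import httpx

    async def stream_answer(question: str, media_type: str = "application/json") -> None:
        payload = {"query": question, "media_type": media_type}  # field names assumed
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST", "http://localhost:8080/v1/streaming_query", json=payload
            ) as resp:
                # JSON mode yields SSE "data: ..." frames; text mode yields raw text
                async for line in resp.aiter_lines():
                    print(line)

    # asyncio.run(stream_answer("How do I expose a service?", media_type="text/plain"))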
