@@ -29,6 +29,7 @@
 from metrics.utils import update_llm_token_count_from_turn
 from models.config import Action
 from models.requests import QueryRequest
+from models.responses import StreamedChunk
 from models.database.conversations import UserConversation
 from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
 from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
@@ -48,6 +49,11 @@
     evaluate_model_hints,
 )

+# Constants for OLS-compatible event types
+LLM_TOKEN_EVENT = "token"
+LLM_TOOL_CALL_EVENT = "tool_call"
+LLM_TOOL_RESULT_EVENT = "tool_result"
+
 logger = logging.getLogger("app.endpoints.handlers")
 router = APIRouter(tags=["streaming_query"])
 auth_dependency = get_auth_dependency()
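
Note: every helper in this diff serializes events through format_stream_data, which is defined earlier in the module and not shown here. A minimal sketch of its assumed shape, inferred from how it is called below (the real implementation may differ):

def format_stream_data(d: dict) -> str:
    # Frame a dict as a single Server-Sent Events message (assumption).
    return f"data: {json.dumps(d)}\n\n"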
@@ -94,41 +100,73 @@ def stream_start_event(conversation_id: str) -> str:
     )


-def stream_end_event(metadata_map: dict) -> str:
+def stream_end_event(
+    ref_docs: list[dict],
+    truncated: bool,
+    media_type: str,
+    token_counter: dict | None = None,
+    available_quotas: dict[str, int] | None = None,
+) -> str:
     """
     Yield the end of the data stream.

     Format and return the end event for a streaming response,
-    including referenced document metadata and placeholder token
-    counts.
+    including referenced document metadata and token counts.

     Parameters:
-        metadata_map (dict): A mapping containing metadata about
-        referenced documents.
+        ref_docs: Referenced documents.
+        truncated: Indicates if the history was truncated.
+        media_type: Media type of the response (e.g. text or JSON).
+        token_counter: Token counter for the whole stream.
+        available_quotas: Quotas available for configured quota limiters.

     Returns:
         str: A Server-Sent Events (SSE) formatted string
         representing the end of the data stream.
     """
+    if media_type == "application/json":
+        return format_stream_data(
+            {
+                "event": "end",
+                "data": {
+                    "referenced_documents": ref_docs,
+                    "truncated": truncated,
+                    "input_tokens": token_counter.get("input_tokens", 0) if token_counter else 0,
+                    "output_tokens": token_counter.get("output_tokens", 0) if token_counter else 0,
+                },
+                "available_quotas": available_quotas or {},
+            }
+        )
+    ref_docs_string = "\n".join(
+        f'{item["doc_title"]}: {item["doc_url"]}' for item in ref_docs
+    )
+    return f"\n\n---\n\n{ref_docs_string}" if ref_docs_string else ""
+
+
+def stream_event(data: dict, event_type: str, media_type: str) -> str:
+    """Build an item to yield based on media type.
+
+    Args:
+        data: The data to yield.
+        event_type: The type of event (e.g. token, tool call, tool result).
+        media_type: Media type of the response (e.g. text or JSON).
+
+    Returns:
+        str: The formatted string or JSON to yield.
+    """
+    if media_type == "text/plain":
+        if event_type == LLM_TOKEN_EVENT:
+            return data["token"]
+        if event_type == LLM_TOOL_CALL_EVENT:
+            return f"\nTool call: {json.dumps(data)}\n"
+        if event_type == LLM_TOOL_RESULT_EVENT:
+            return f"\nTool result: {json.dumps(data)}\n"
+        logger.error("Unknown event type: %s", event_type)
+        return ""
     return format_stream_data(
         {
-            "event": "end",
-            "data": {
-                "referenced_documents": [
-                    {
-                        "doc_url": v["docs_url"],
-                        "doc_title": v["title"],
-                    }
-                    for v in filter(
-                        lambda v: ("docs_url" in v) and ("title" in v),
-                        metadata_map.values(),
-                    )
-                ],
-                "truncated": None,  # TODO(jboos): implement truncated
-                "input_tokens": 0,  # TODO(jboos): implement input tokens
-                "output_tokens": 0,  # TODO(jboos): implement output tokens
-            },
-            "available_quotas": {},  # TODO(jboos): implement available quotas
+            "event": event_type,
+            "data": data,
         }
     )

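To make the two output modes concrete, here is an illustrative use of the helpers above; the sample values are invented, and the SSE framing assumes the format_stream_data sketch given earlier:

token_data = {"id": 0, "token": "Hello"}

stream_event(token_data, LLM_TOKEN_EVENT, "text/plain")
# -> 'Hello'

stream_event(token_data, LLM_TOKEN_EVENT, "application/json")
# -> 'data: {"event": "token", "data": {"id": 0, "token": "Hello"}}\n\n'

stream_end_event([], truncated=False, media_type="text/plain")
# -> ''   (no referenced documents, so no plain-text footer)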
@@ -203,6 +241,35 @@ def _handle_error_event(chunk: Any, chunk_id: int) -> Iterator[str]:
     )


+def generic_llm_error(error: Exception, media_type: str) -> str:
+    """Return error representation for generic LLM errors.
+
+    Args:
+        error: The exception raised during processing.
+        media_type: Media type of the response (e.g. text or JSON).
+
+    Returns:
+        str: The error message formatted for the media type.
+    """
+    logger.error("Error while obtaining answer for user question")
+    logger.exception(error)
+
+    response = "Error while obtaining answer for user question"
+    cause = str(error)
+
+    if media_type == "text/plain":
+        return f"{response}: {cause}"
+    return format_stream_data(
+        {
+            "event": "error",
+            "data": {
+                "response": response,
+                "cause": cause,
+            },
+        }
+    )
+
+
 # -----------------------------------
 # Turn handling
 # -----------------------------------
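
An illustrative rendering of the same failure in both modes (the exception and its message are invented):

err = TimeoutError("model backend timed out")

generic_llm_error(err, "text/plain")
# -> 'Error while obtaining answer for user question: model backend timed out'

generic_llm_error(err, "application/json")
# -> 'data: {"event": "error", "data": {"response": "Error while obtaining answer for user question", "cause": "model backend timed out"}}\n\n'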
@@ -223,7 +290,7 @@ def _handle_turn_start_event(chunk_id: int) -> Iterator[str]:
     """
     yield format_stream_data(
         {
-            "event": "token",
+            "event": LLM_TOKEN_EVENT,
             "data": {
                 "id": chunk_id,
                 "token": "",
@@ -322,10 +389,9 @@ def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]:
     if chunk.event.payload.event_type == "step_start":
         yield format_stream_data(
             {
-                "event": "token",
+                "event": LLM_TOKEN_EVENT,
                 "data": {
                     "id": chunk_id,
-                    "role": chunk.event.payload.step_type,
                     "token": "",
                 },
             }
@@ -336,21 +402,19 @@ def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]:
             if isinstance(chunk.event.payload.delta.tool_call, str):
                 yield format_stream_data(
                     {
-                        "event": "tool_call",
+                        "event": LLM_TOOL_CALL_EVENT,
                         "data": {
                             "id": chunk_id,
-                            "role": chunk.event.payload.step_type,
                             "token": chunk.event.payload.delta.tool_call,
                         },
                     }
                 )
             elif isinstance(chunk.event.payload.delta.tool_call, ToolCall):
                 yield format_stream_data(
                     {
-                        "event": "tool_call",
+                        "event": LLM_TOOL_CALL_EVENT,
                         "data": {
                             "id": chunk_id,
-                            "role": chunk.event.payload.step_type,
                             "token": chunk.event.payload.delta.tool_call.tool_name,
                         },
                     }
@@ -359,10 +423,9 @@ def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]:
         elif chunk.event.payload.delta.type == "text":
             yield format_stream_data(
                 {
-                    "event": "token",
+                    "event": LLM_TOKEN_EVENT,
                     "data": {
                         "id": chunk_id,
-                        "role": chunk.event.payload.step_type,
                         "token": chunk.event.payload.delta.text,
                     },
                 }
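
On the wire (JSON mode), each text delta handled above becomes one SSE frame; for example, a delta of "Hello" at chunk 3 would be emitted roughly as (illustrative, assuming the format_stream_data framing sketched earlier):

data: {"event": "token", "data": {"id": 3, "token": "Hello"}}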
@@ -377,7 +440,7 @@ def _handle_tool_execution_event(
     chunk: Any, chunk_id: int, metadata_map: dict
 ) -> Iterator[str]:
     """
-    Yield tool call event.
+    Yield tool call and tool result events.

     Processes tool execution events from a streaming chunk and
     yields formatted Server-Sent Events (SSE) strings.
@@ -399,23 +462,22 @@ def _handle_tool_execution_event(
     if chunk.event.payload.event_type == "step_start":
         yield format_stream_data(
             {
-                "event": "tool_call",
+                "event": LLM_TOOL_CALL_EVENT,
                 "data": {
                     "id": chunk_id,
-                    "role": chunk.event.payload.step_type,
                     "token": "",
                 },
             }
         )

     elif chunk.event.payload.event_type == "step_complete":
+        # First yield tool calls
         for t in chunk.event.payload.step_details.tool_calls:
             yield format_stream_data(
                 {
-                    "event": "tool_call",
+                    "event": LLM_TOOL_CALL_EVENT,
                     "data": {
                         "id": chunk_id,
-                        "role": chunk.event.payload.step_type,
                         "token": {
                             "tool_name": t.tool_name,
                             "arguments": t.arguments,
@@ -424,15 +486,15 @@ def _handle_tool_execution_event(
                 }
             )

+        # Then yield tool results
         for r in chunk.event.payload.step_details.tool_responses:
             if r.tool_name == "query_from_memory":
                 inserted_context = interleaved_content_as_str(r.content)
                 yield format_stream_data(
                     {
-                        "event": "tool_call",
+                        "event": LLM_TOOL_RESULT_EVENT,
                         "data": {
                             "id": chunk_id,
-                            "role": chunk.event.payload.step_type,
                             "token": {
                                 "tool_name": r.tool_name,
                                 "response": f"Fetched {len(inserted_context)} bytes from memory",
@@ -463,10 +525,9 @@ def _handle_tool_execution_event(

                 yield format_stream_data(
                     {
-                        "event": "tool_call",
+                        "event": LLM_TOOL_RESULT_EVENT,
                         "data": {
                             "id": chunk_id,
-                            "role": chunk.event.payload.step_type,
                             "token": {
                                 "tool_name": r.tool_name,
                                 "summary": summary,
@@ -478,10 +539,9 @@ def _handle_tool_execution_event(
             else:
                 yield format_stream_data(
                     {
-                        "event": "tool_call",
+                        "event": LLM_TOOL_RESULT_EVENT,
                         "data": {
                             "id": chunk_id,
-                            "role": chunk.event.payload.step_type,
                             "token": {
                                 "tool_name": r.tool_name,
                                 "response": interleaved_content_as_str(r.content),
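
A completed tool-execution step therefore produces paired frames, first the call and then its result; illustrative frames with invented values:

data: {"event": "tool_call", "data": {"id": 7, "token": {"tool_name": "knowledge_search", "arguments": {"query": "operators"}}}}
data: {"event": "tool_result", "data": {"id": 7, "token": {"tool_name": "knowledge_search", "summary": "Found 3 documents"}}}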
@@ -614,32 +674,61 @@ async def response_generator(
         summary = TurnSummary(
             llm_response="No response from the model", tool_calls=[]
         )
-
+
+        # Determine media type for OLS compatibility
+        media_type = query_request.media_type or "application/json"
+
         # Send start event
         yield stream_start_event(conversation_id)

-        async for chunk in turn_response:
-            p = chunk.event.payload
-            if p.event_type == "turn_complete":
-                summary.llm_response = interleaved_content_as_str(
-                    p.turn.output_message.content
-                )
-                system_prompt = get_system_prompt(query_request, configuration)
-                try:
-                    update_llm_token_count_from_turn(
-                        p.turn, model_id, provider_id, system_prompt
+        # Track referenced documents and token counts
+        ref_docs = []
+        token_counter = {"input_tokens": 0, "output_tokens": 0}
+        truncated = False
+
+        try:
+            async for chunk in turn_response:
+                p = chunk.event.payload
+                if p.event_type == "turn_complete":
+                    summary.llm_response = interleaved_content_as_str(
+                        p.turn.output_message.content
                     )
-                except Exception:  # pylint: disable=broad-except
-                    logger.exception("Failed to update token usage metrics")
-            elif p.event_type == "step_complete":
-                if p.step_details.step_type == "tool_execution":
-                    summary.append_tool_calls_from_llama(p.step_details)
-
-            for event in stream_build_event(chunk, chunk_id, metadata_map):
-                chunk_id += 1
-                yield event
+                    system_prompt = get_system_prompt(query_request, configuration)
+                    try:
+                        update_llm_token_count_from_turn(
+                            p.turn, model_id, provider_id, system_prompt
+                        )
+                        # Extract token counts from the turn if available
+                        if hasattr(p.turn, 'usage') and p.turn.usage:
+                            token_counter["input_tokens"] = getattr(p.turn.usage, 'prompt_tokens', 0)
+                            token_counter["output_tokens"] = getattr(p.turn.usage, 'completion_tokens', 0)
+                    except Exception:  # pylint: disable=broad-except
+                        logger.exception("Failed to update token usage metrics")
+                elif p.event_type == "step_complete":
+                    if p.step_details.step_type == "tool_execution":
+                        summary.append_tool_calls_from_llama(p.step_details)
+
+                for event in stream_build_event(chunk, chunk_id, metadata_map):
+                    chunk_id += 1
+                    yield event
+        except Exception as e:
+            # Handle streaming errors
+            yield generic_llm_error(e, media_type)
+            return
+
+        # Build referenced documents from metadata
+        ref_docs = [
+            {
+                "doc_url": v["docs_url"],
+                "doc_title": v["title"],
+            }
+            for v in filter(
+                lambda v: ("docs_url" in v) and ("title" in v),
+                metadata_map.values(),
+            )
+        ]

-        yield stream_end_event(metadata_map)
+        yield stream_end_event(ref_docs, truncated, media_type, token_counter, {})

         if not is_transcripts_enabled():
             logger.debug("Transcript collection is disabled in the configuration")
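
With the bookkeeping above, a JSON-mode stream would terminate with an end frame of roughly this shape (illustrative values; available_quotas is currently always passed as {}):

data: {"event": "end", "data": {"referenced_documents": [{"doc_url": "https://docs.example.com/page", "doc_title": "Example"}], "truncated": false, "input_tokens": 123, "output_tokens": 456}, "available_quotas": {}}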
@@ -669,7 +758,12 @@ async def response_generator(
         # Update metrics for the LLM call
         metrics.llm_calls_total.labels(provider_id, model_id).inc()

-        return StreamingResponse(response_generator(response))
+        # Determine media type for OLS compatibility
+        media_type = query_request.media_type or "application/json"
+        return StreamingResponse(
+            response_generator(response),
+            media_type=media_type
+        )
     # connection to Llama Stack server
     except APIConnectionError as e:
         # Update metrics for the LLM call failure
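
For completeness, a hedged client-side sketch of consuming the endpoint; httpx, the /v1/streaming_query path, the port, and the media_type request field are assumptions, not taken from this diff:

import asyncio
import httpx

async def consume(prompt: str) -> None:
    # media_type selects SSE/JSON frames vs. plain text (assumed request field)
    payload = {"query": prompt, "media_type": "application/json"}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:8080/v1/streaming_query", json=payload
        ) as resp:
            async for line in resp.aiter_lines():
                if line.startswith("data: "):
                    print(line[len("data: "):])  # one event per SSE frame

asyncio.run(consume("What is Kubernetes?"))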