diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py
index d27388e2..bdf49b98 100644
--- a/src/app/endpoints/streaming_query.py
+++ b/src/app/endpoints/streaming_query.py
@@ -14,7 +14,6 @@
 from llama_stack_client.types import UserMessage  # type: ignore
 from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
-from llama_stack_client.types.shared import ToolCall
 from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
 
 from fastapi import APIRouter, HTTPException, Request, Depends, status
@@ -256,55 +255,21 @@ def _handle_shield_event(chunk: Any, chunk_id: int) -> Iterator[str]:
 # -----------------------------------
 # Inference handling
 # -----------------------------------
 def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]:
-    if chunk.event.payload.event_type == "step_start":
+    if (
+        chunk.event.payload.event_type == "step_progress"
+        and chunk.event.payload.delta.type == "text"
+    ):
         yield format_stream_data(
             {
                 "event": "token",
                 "data": {
                     "id": chunk_id,
                     "role": chunk.event.payload.step_type,
-                    "token": "",
+                    "token": chunk.event.payload.delta.text,
                 },
             }
         )
-    elif chunk.event.payload.event_type == "step_progress":
-        if chunk.event.payload.delta.type == "tool_call":
-            if isinstance(chunk.event.payload.delta.tool_call, str):
-                yield format_stream_data(
-                    {
-                        "event": "tool_call",
-                        "data": {
-                            "id": chunk_id,
-                            "role": chunk.event.payload.step_type,
-                            "token": chunk.event.payload.delta.tool_call,
-                        },
-                    }
-                )
-            elif isinstance(chunk.event.payload.delta.tool_call, ToolCall):
-                yield format_stream_data(
-                    {
-                        "event": "tool_call",
-                        "data": {
-                            "id": chunk_id,
-                            "role": chunk.event.payload.step_type,
-                            "token": chunk.event.payload.delta.tool_call.tool_name,
-                        },
-                    }
-                )
-
-        elif chunk.event.payload.delta.type == "text":
-            yield format_stream_data(
-                {
-                    "event": "token",
-                    "data": {
-                        "id": chunk_id,
-                        "role": chunk.event.payload.step_type,
-                        "token": chunk.event.payload.delta.text,
-                    },
-                }
-            )
-
 
 # -----------------------------------
 # Tool Execution handling
@@ -313,19 +278,7 @@ def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]:
 def _handle_tool_execution_event(
     chunk: Any, chunk_id: int, metadata_map: dict
 ) -> Iterator[str]:
-    if chunk.event.payload.event_type == "step_start":
-        yield format_stream_data(
-            {
-                "event": "tool_call",
-                "data": {
-                    "id": chunk_id,
-                    "role": chunk.event.payload.step_type,
-                    "token": "",
-                },
-            }
-        )
-
-    elif chunk.event.payload.event_type == "step_complete":
+    if chunk.event.payload.event_type == "step_complete":
         for t in chunk.event.payload.step_details.tool_calls:
             yield format_stream_data(
                 {
diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py
index 8ff286ad..9d7c30ca 100644
--- a/tests/unit/app/endpoints/test_streaming_query.py
+++ b/tests/unit/app/endpoints/test_streaming_query.py
@@ -860,14 +860,9 @@ def test_stream_build_event_step_progress_tool_call_str():
         )
     )
 
-    result = next(stream_build_event(chunk, 0, {}))
+    results = list(stream_build_event(chunk, 0, {}))
 
-    assert result is not None
-    assert "data: " in result
-    assert '"event": "tool_call"' in result
-    assert '"token": "tool-called"' in result
-    assert '"role": "inference"' in result
-    assert '"id": 0' in result
+    assert len(results) == 0
 
 
 def test_stream_build_event_step_progress_tool_call_tool_call():
@@ -892,14 +887,9 @@ def test_stream_build_event_step_progress_tool_call_tool_call():
         )
     )
 
-    result = next(stream_build_event(chunk, 0, {}))
+    results = list(stream_build_event(chunk, 0, {}))
 
-    assert result is not None
-    assert "data: " in result
-    assert '"event": "tool_call"' in result
-    assert '"token": "my-tool"' in result
-    assert '"role": "inference"' in result
-    assert '"id": 0' in result
+    assert len(results) == 0
 
 
 def test_stream_build_event_step_complete():