diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index fa67be71..8701aa6b 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -80,6 +80,7 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen available_input_shields: list[str], available_output_shields: list[str], conversation_id: str | None, + no_tools: bool = False, ) -> tuple[Agent, str]: """Get existing agent or create a new one with session persistence.""" if conversation_id is not None: @@ -99,7 +100,7 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen instructions=system_prompt, input_shields=available_input_shields if available_input_shields else [], output_shields=available_output_shields if available_output_shields else [], - tool_parser=GraniteToolParser.get_parser(model_id), + tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id), enable_session_persistence=True, ) conversation_id = agent.create_session(get_suid()) @@ -288,36 +289,47 @@ def retrieve_response( # pylint: disable=too-many-locals available_input_shields, available_output_shields, query_request.conversation_id, + query_request.no_tools or False, ) - # preserve compatibility when mcp_headers is not provided - if mcp_headers is None: + # bypass tools and MCP servers if no_tools is True + if query_request.no_tools: mcp_headers = {} - mcp_headers = handle_mcp_headers_with_toolgroups(mcp_headers, configuration) - if not mcp_headers and token: - for mcp_server in configuration.mcp_servers: - mcp_headers[mcp_server.url] = { - "Authorization": f"Bearer {token}", - } - - agent.extra_headers = { - "X-LlamaStack-Provider-Data": json.dumps( - { - "mcp_headers": mcp_headers, - } - ), - } + agent.extra_headers = {} + toolgroups = None + else: + # preserve compatibility when mcp_headers is not provided + if mcp_headers is None: + mcp_headers = {} + mcp_headers = handle_mcp_headers_with_toolgroups(mcp_headers, configuration) + if not mcp_headers and token: + for mcp_server in configuration.mcp_servers: + mcp_headers[mcp_server.url] = { + "Authorization": f"Bearer {token}", + } + + agent.extra_headers = { + "X-LlamaStack-Provider-Data": json.dumps( + { + "mcp_headers": mcp_headers, + } + ), + } + + vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()] + toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [ + mcp_server.name for mcp_server in configuration.mcp_servers + ] + # Convert empty list to None for consistency with existing behavior + if not toolgroups: + toolgroups = None - vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()] - toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [ - mcp_server.name for mcp_server in configuration.mcp_servers - ] response = agent.create_turn( messages=[UserMessage(role="user", content=query_request.query)], session_id=conversation_id, documents=query_request.get_documents(), stream=False, - toolgroups=toolgroups or None, + toolgroups=toolgroups, ) # Check for validation errors in the response diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index eb94222d..d27388e2 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -58,6 +58,7 @@ async def get_agent( available_input_shields: list[str], available_output_shields: list[str], conversation_id: str | None, + no_tools: bool = False, ) -> tuple[AsyncAgent, str]: """Get existing agent or create a new one with session persistence.""" if conversation_id is not None: @@ -76,7 +77,7 @@ async def get_agent( instructions=system_prompt, input_shields=available_input_shields if available_input_shields else [], output_shields=available_output_shields if available_output_shields else [], - tool_parser=GraniteToolParser.get_parser(model_id), + tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id), enable_session_persistence=True, ) conversation_id = await agent.create_session(get_suid()) @@ -532,41 +533,53 @@ async def retrieve_response( available_input_shields, available_output_shields, query_request.conversation_id, + query_request.no_tools or False, ) - # preserve compatibility when mcp_headers is not provided - if mcp_headers is None: + # bypass tools and MCP servers if no_tools is True + if query_request.no_tools: mcp_headers = {} + agent.extra_headers = {} + toolgroups = None + else: + # preserve compatibility when mcp_headers is not provided + if mcp_headers is None: + mcp_headers = {} - mcp_headers = handle_mcp_headers_with_toolgroups(mcp_headers, configuration) + mcp_headers = handle_mcp_headers_with_toolgroups(mcp_headers, configuration) - if not mcp_headers and token: - for mcp_server in configuration.mcp_servers: - mcp_headers[mcp_server.url] = { - "Authorization": f"Bearer {token}", - } + if not mcp_headers and token: + for mcp_server in configuration.mcp_servers: + mcp_headers[mcp_server.url] = { + "Authorization": f"Bearer {token}", + } - agent.extra_headers = { - "X-LlamaStack-Provider-Data": json.dumps( - { - "mcp_headers": mcp_headers, - } - ), - } + agent.extra_headers = { + "X-LlamaStack-Provider-Data": json.dumps( + { + "mcp_headers": mcp_headers, + } + ), + } + + logger.debug("Session ID: %s", conversation_id) + vector_db_ids = [ + vector_db.identifier for vector_db in await client.vector_dbs.list() + ] + toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [ + mcp_server.name for mcp_server in configuration.mcp_servers + ] + # Convert empty list to None for consistency with existing behavior + if not toolgroups: + toolgroups = None logger.debug("Session ID: %s", conversation_id) - vector_db_ids = [ - vector_db.identifier for vector_db in await client.vector_dbs.list() - ] - toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [ - mcp_server.name for mcp_server in configuration.mcp_servers - ] response = await agent.create_turn( messages=[UserMessage(role="user", content=query_request.query)], session_id=conversation_id, documents=query_request.get_documents(), stream=True, - toolgroups=toolgroups or None, + toolgroups=toolgroups, ) return response, conversation_id diff --git a/src/models/requests.py b/src/models/requests.py index 3335861d..cea3b22c 100644 --- a/src/models/requests.py +++ b/src/models/requests.py @@ -69,6 +69,7 @@ class QueryRequest(BaseModel): model: The optional model. system_prompt: The optional system prompt. attachments: The optional attachments. + no_tools: Whether to bypass all tools and MCP servers (default: False). Example: ```python @@ -82,6 +83,7 @@ class QueryRequest(BaseModel): model: Optional[str] = None system_prompt: Optional[str] = None attachments: Optional[list[Attachment]] = None + no_tools: Optional[bool] = False # media_type is not used in 'lightspeed-stack' that only supports application/json. # the field is kept here to enable compatibility with 'road-core' clients. media_type: Optional[str] = None @@ -97,6 +99,7 @@ class QueryRequest(BaseModel): "provider": "openai", "model": "model-name", "system_prompt": "You are a helpful assistant", + "no_tools": False, "attachments": [ { "attachment_type": "log", diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 93952bca..03e32563 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -587,6 +587,7 @@ def __repr__(self): ["shield1", "input_shield2", "inout_shield4"], # available_input_shields ["output_shield3", "inout_shield4"], # available_output_shields None, # conversation_id + False, # no_tools ) mock_agent.create_turn.assert_called_once_with( @@ -745,6 +746,7 @@ def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): [], # available_input_shields [], # available_output_shields None, # conversation_id + False, # no_tools ) # Check that the agent's extra_headers property was set correctly @@ -809,6 +811,7 @@ def test_retrieve_response_with_mcp_servers_empty_token(prepare_agent_mocks, moc [], # available_input_shields [], # available_output_shields None, # conversation_id + False, # no_tools ) # Check that create_turn was called with the correct parameters @@ -881,6 +884,7 @@ def test_retrieve_response_with_mcp_servers_and_mcp_headers( [], # available_input_shields [], # available_output_shields None, # conversation_id + False, # no_tools ) expected_mcp_headers = { @@ -1382,3 +1386,301 @@ def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker): ) assert mock_retrieve_response.call_args[0][3] == "auth_token_123" + + +def test_query_endpoint_handler_no_tools_true(mocker): + """Test the query endpoint handler with no_tools=True.""" + mock_client = mocker.Mock() + mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client") + mock_lsc.return_value = mock_client + mock_client.models.list.return_value = [ + mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), + ] + + mock_config = mocker.Mock() + mock_config.user_data_collection_configuration.transcripts_disabled = True + mocker.patch("app.endpoints.query.configuration", mock_config) + + llm_response = "LLM answer without tools" + conversation_id = "fake_conversation_id" + query = "What is OpenStack?" + + mocker.patch( + "app.endpoints.query.retrieve_response", + return_value=(llm_response, conversation_id), + ) + mocker.patch( + "app.endpoints.query.select_model_and_provider_id", + return_value=("fake_model_id", "fake_provider_id"), + ) + mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) + + query_request = QueryRequest(query=query, no_tools=True) + + response = query_endpoint_handler(query_request, auth=MOCK_AUTH) + + # Assert the response is as expected + assert response.response == llm_response + assert response.conversation_id == conversation_id + + +def test_query_endpoint_handler_no_tools_false(mocker): + """Test the query endpoint handler with no_tools=False (default behavior).""" + mock_client = mocker.Mock() + mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client") + mock_lsc.return_value = mock_client + mock_client.models.list.return_value = [ + mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), + ] + + mock_config = mocker.Mock() + mock_config.user_data_collection_configuration.transcripts_disabled = True + mocker.patch("app.endpoints.query.configuration", mock_config) + + llm_response = "LLM answer with tools" + conversation_id = "fake_conversation_id" + query = "What is OpenStack?" + + mocker.patch( + "app.endpoints.query.retrieve_response", + return_value=(llm_response, conversation_id), + ) + mocker.patch( + "app.endpoints.query.select_model_and_provider_id", + return_value=("fake_model_id", "fake_provider_id"), + ) + mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) + + query_request = QueryRequest(query=query, no_tools=False) + + response = query_endpoint_handler(query_request, auth=MOCK_AUTH) + + # Assert the response is as expected + assert response.response == llm_response + assert response.conversation_id == conversation_id + + +def test_retrieve_response_no_tools_bypasses_mcp_and_rag(prepare_agent_mocks, mocker): + """Test that retrieve_response bypasses MCP servers and RAG when no_tools=True.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_turn.return_value.output_message.content = "LLM answer" + mock_client.shields.list.return_value = [] + mock_vector_db = mocker.Mock() + mock_vector_db.identifier = "VectorDB-1" + mock_client.vector_dbs.list.return_value = [mock_vector_db] + + # Mock configuration with MCP servers + mcp_servers = [ + ModelContextProtocolServer( + name="filesystem-server", url="http://localhost:3000" + ), + ] + mock_config = mocker.Mock() + mock_config.mcp_servers = mcp_servers + mocker.patch("app.endpoints.query.configuration", mock_config) + mocker.patch( + "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + ) + + query_request = QueryRequest(query="What is OpenStack?", no_tools=True) + model_id = "fake_model_id" + access_token = "test_token" + + response, conversation_id = retrieve_response( + mock_client, model_id, query_request, access_token + ) + + assert response == "LLM answer" + assert conversation_id == "fake_session_id" + + # Verify that agent.extra_headers is empty (no MCP headers) + assert mock_agent.extra_headers == {} + + # Verify that create_turn was called with toolgroups=None + mock_agent.create_turn.assert_called_once_with( + messages=[UserMessage(content="What is OpenStack?", role="user")], + session_id="fake_session_id", + documents=[], + stream=False, + toolgroups=None, + ) + + +def test_retrieve_response_no_tools_false_preserves_functionality( + prepare_agent_mocks, mocker +): + """Test that retrieve_response preserves normal functionality when no_tools=False.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_turn.return_value.output_message.content = "LLM answer" + mock_client.shields.list.return_value = [] + mock_vector_db = mocker.Mock() + mock_vector_db.identifier = "VectorDB-1" + mock_client.vector_dbs.list.return_value = [mock_vector_db] + + # Mock configuration with MCP servers + mcp_servers = [ + ModelContextProtocolServer( + name="filesystem-server", url="http://localhost:3000" + ), + ] + mock_config = mocker.Mock() + mock_config.mcp_servers = mcp_servers + mocker.patch("app.endpoints.query.configuration", mock_config) + mocker.patch( + "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + ) + + query_request = QueryRequest(query="What is OpenStack?", no_tools=False) + model_id = "fake_model_id" + access_token = "test_token" + + response, conversation_id = retrieve_response( + mock_client, model_id, query_request, access_token + ) + + assert response == "LLM answer" + assert conversation_id == "fake_session_id" + + # Verify that agent.extra_headers contains MCP headers + expected_extra_headers = { + "X-LlamaStack-Provider-Data": json.dumps( + { + "mcp_headers": { + "http://localhost:3000": {"Authorization": "Bearer test_token"}, + } + } + ) + } + assert mock_agent.extra_headers == expected_extra_headers + + # Verify that create_turn was called with RAG and MCP toolgroups + expected_toolgroups = get_rag_toolgroups(["VectorDB-1"]) + ["filesystem-server"] + mock_agent.create_turn.assert_called_once_with( + messages=[UserMessage(content="What is OpenStack?", role="user")], + session_id="fake_session_id", + documents=[], + stream=False, + toolgroups=expected_toolgroups, + ) + + +def test_get_agent_no_tools_no_parser(setup_configuration, prepare_agent_mocks, mocker): + """Test get_agent function sets tool_parser=None when no_tools=True.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "app.endpoints.query.Agent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("app.endpoints.query.get_suid", return_value="new_session_id") + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("app.endpoints.query.configuration", setup_configuration) + + # Call function with no_tools=True + result_agent, result_conversation_id = get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=None, + no_tools=True, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == "new_session_id" + + # Verify Agent was created with tool_parser=None + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + output_shields=["output_shield2"], + tool_parser=None, + enable_session_persistence=True, + ) + + +def test_get_agent_no_tools_false_preserves_parser( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function preserves tool_parser when no_tools=False.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "app.endpoints.query.Agent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("app.endpoints.query.get_suid", return_value="new_session_id") + + # Mock GraniteToolParser + mock_parser = mocker.Mock() + mock_granite_parser = mocker.patch("app.endpoints.query.GraniteToolParser") + mock_granite_parser.get_parser.return_value = mock_parser + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("app.endpoints.query.configuration", setup_configuration) + + # Call function with no_tools=False + result_agent, result_conversation_id = get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=None, + no_tools=False, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == "new_session_id" + + # Verify Agent was created with the proper tool_parser + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + output_shields=["output_shield2"], + tool_parser=mock_parser, + enable_session_persistence=True, + ) + + +def test_no_tools_parameter_backward_compatibility(): + """Test that default behavior is unchanged when no_tools parameter is not specified.""" + # This test ensures that existing code that doesn't specify no_tools continues to work + query_request = QueryRequest(query="What is OpenStack?") + + # Verify default value + assert query_request.no_tools is False + + # Test that QueryRequest can be created without no_tools parameter + query_request_minimal = QueryRequest(query="Simple query") + assert query_request_minimal.no_tools is False diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 893d2bed..8ff286ad 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -43,6 +43,7 @@ get_agent, _agent_cache, ) + from models.requests import QueryRequest, Attachment from models.config import ModelContextProtocolServer @@ -529,6 +530,7 @@ def __repr__(self): ["shield1", "input_shield2", "inout_shield4"], # available_input_shields ["output_shield3", "inout_shield4"], # available_output_shields None, # conversation_id + False, # no_tools ) mock_agent.create_turn.assert_called_once_with( @@ -1037,6 +1039,7 @@ async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): [], # available_input_shields [], # available_output_shields None, # conversation_id + False, # no_tools ) # Check that the agent's extra_headers property was set correctly @@ -1104,6 +1107,7 @@ async def test_retrieve_response_with_mcp_servers_empty_token( [], # available_input_shields [], # available_output_shields None, # conversation_id + False, # no_tools ) # Check that the agent's extra_headers property was set correctly (empty mcp_headers) @@ -1182,6 +1186,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker): [], # available_input_shields [], # available_output_shields None, # conversation_id + False, # no_tools ) expected_mcp_headers = { @@ -1559,3 +1564,311 @@ async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler(mocker): ) assert mock_retrieve_response.call_args[0][3] == "auth_token_123" + + +@pytest.mark.asyncio +async def test_streaming_query_endpoint_handler_no_tools_true(mocker): + """Test the streaming query endpoint handler with no_tools=True.""" + mock_client = mocker.AsyncMock() + mock_async_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") + mock_async_lsc.return_value = mock_client + mock_client.models.list.return_value = [ + mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), + ] + + mock_config = mocker.Mock() + mock_config.user_data_collection_configuration.transcripts_disabled = True + mocker.patch("app.endpoints.streaming_query.configuration", mock_config) + + # Mock the streaming response + mock_streaming_response = mocker.AsyncMock() + mock_streaming_response.__aiter__.return_value = iter([]) + + mocker.patch( + "app.endpoints.streaming_query.retrieve_response", + return_value=(mock_streaming_response, "test_conversation_id"), + ) + mocker.patch( + "app.endpoints.streaming_query.select_model_and_provider_id", + return_value=("fake_model_id", "fake_provider_id"), + ) + mocker.patch( + "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False + ) + + query_request = QueryRequest(query="What is OpenStack?", no_tools=True) + + response = await streaming_query_endpoint_handler( + None, query_request, auth=MOCK_AUTH + ) + + # Assert the response is a StreamingResponse + assert isinstance(response, StreamingResponse) + + +@pytest.mark.asyncio +async def test_streaming_query_endpoint_handler_no_tools_false(mocker): + """Test the streaming query endpoint handler with no_tools=False (default behavior).""" + mock_client = mocker.AsyncMock() + mock_async_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") + mock_async_lsc.return_value = mock_client + mock_client.models.list.return_value = [ + mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), + ] + + mock_config = mocker.Mock() + mock_config.user_data_collection_configuration.transcripts_disabled = True + mocker.patch("app.endpoints.streaming_query.configuration", mock_config) + + # Mock the streaming response + mock_streaming_response = mocker.AsyncMock() + mock_streaming_response.__aiter__.return_value = iter([]) + + mocker.patch( + "app.endpoints.streaming_query.retrieve_response", + return_value=(mock_streaming_response, "test_conversation_id"), + ) + mocker.patch( + "app.endpoints.streaming_query.select_model_and_provider_id", + return_value=("fake_model_id", "fake_provider_id"), + ) + mocker.patch( + "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False + ) + + query_request = QueryRequest(query="What is OpenStack?", no_tools=False) + + response = await streaming_query_endpoint_handler( + None, query_request, auth=MOCK_AUTH + ) + + # Assert the response is a StreamingResponse + assert isinstance(response, StreamingResponse) + + +@pytest.mark.asyncio +async def test_retrieve_response_no_tools_bypasses_mcp_and_rag( + prepare_agent_mocks, mocker +): + """Test that retrieve_response bypasses MCP servers and RAG when no_tools=True.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_turn.return_value.output_message.content = "LLM answer" + mock_client.shields.list.return_value = [] + mock_vector_db = mocker.Mock() + mock_vector_db.identifier = "VectorDB-1" + mock_client.vector_dbs.list.return_value = [mock_vector_db] + + # Mock configuration with MCP servers + mcp_servers = [ + ModelContextProtocolServer( + name="filesystem-server", url="http://localhost:3000" + ), + ] + mock_config = mocker.Mock() + mock_config.mcp_servers = mcp_servers + mocker.patch("app.endpoints.streaming_query.configuration", mock_config) + mocker.patch( + "app.endpoints.streaming_query.get_agent", + return_value=(mock_agent, "fake_session_id"), + ) + + query_request = QueryRequest(query="What is OpenStack?", no_tools=True) + model_id = "fake_model_id" + access_token = "test_token" + + response, conversation_id = await retrieve_response( + mock_client, model_id, query_request, access_token + ) + + assert response is not None + assert conversation_id == "fake_session_id" + + # Verify that agent.extra_headers is empty (no MCP headers) + assert mock_agent.extra_headers == {} + + # Verify that create_turn was called with toolgroups=None + mock_agent.create_turn.assert_called_once_with( + messages=[UserMessage(content="What is OpenStack?", role="user")], + session_id="fake_session_id", + documents=[], + stream=True, + toolgroups=None, + ) + + +@pytest.mark.asyncio +async def test_retrieve_response_no_tools_false_preserves_functionality( + prepare_agent_mocks, mocker +): + """Test that retrieve_response preserves normal functionality when no_tools=False.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_turn.return_value.output_message.content = "LLM answer" + mock_client.shields.list.return_value = [] + mock_vector_db = mocker.Mock() + mock_vector_db.identifier = "VectorDB-1" + mock_client.vector_dbs.list.return_value = [mock_vector_db] + + # Mock configuration with MCP servers + mcp_servers = [ + ModelContextProtocolServer( + name="filesystem-server", url="http://localhost:3000" + ), + ] + mock_config = mocker.Mock() + mock_config.mcp_servers = mcp_servers + mocker.patch("app.endpoints.streaming_query.configuration", mock_config) + mocker.patch( + "app.endpoints.streaming_query.get_agent", + return_value=(mock_agent, "fake_session_id"), + ) + + query_request = QueryRequest(query="What is OpenStack?", no_tools=False) + model_id = "fake_model_id" + access_token = "test_token" + + response, conversation_id = await retrieve_response( + mock_client, model_id, query_request, access_token + ) + + assert response is not None + assert conversation_id == "fake_session_id" + + # Verify that agent.extra_headers contains MCP headers + expected_extra_headers = { + "X-LlamaStack-Provider-Data": json.dumps( + { + "mcp_headers": { + "http://localhost:3000": {"Authorization": "Bearer test_token"}, + } + } + ) + } + assert mock_agent.extra_headers == expected_extra_headers + + expected_toolgroups = get_rag_toolgroups(["VectorDB-1"]) + ["filesystem-server"] + mock_agent.create_turn.assert_called_once_with( + messages=[UserMessage(content="What is OpenStack?", role="user")], + session_id="fake_session_id", + documents=[], + stream=True, + toolgroups=expected_toolgroups, + ) + + +@pytest.mark.asyncio +async def test_get_agent_no_tools_no_parser( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function sets tool_parser=None when no_tools=True.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch( + "app.endpoints.streaming_query.get_suid", return_value="new_session_id" + ) + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) + + # Call function with no_tools=True + result_agent, result_conversation_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=None, + no_tools=True, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == "new_session_id" + + # Verify Agent was created with tool_parser=None + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + output_shields=["output_shield2"], + tool_parser=None, + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def test_get_agent_no_tools_false_preserves_parser( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function preserves tool_parser when no_tools=False.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch( + "app.endpoints.streaming_query.get_suid", return_value="new_session_id" + ) + + # Mock GraniteToolParser + mock_parser = mocker.Mock() + mock_granite_parser = mocker.patch( + "app.endpoints.streaming_query.GraniteToolParser" + ) + mock_granite_parser.get_parser.return_value = mock_parser + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) + + # Call function with no_tools=False + result_agent, result_conversation_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=None, + no_tools=False, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == "new_session_id" + + # Verify Agent was created with the proper tool_parser + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + output_shields=["output_shield2"], + tool_parser=mock_parser, + enable_session_persistence=True, + )