From 835e6c60ad2cf2d0c2a94e37672d621fa5e3084c Mon Sep 17 00:00:00 2001
From: Shabana Baig <43451943+s-akhtar-baig@users.noreply.github.com>
Date: Mon, 3 Nov 2025 14:22:27 -0500
Subject: [PATCH] Implement the 'max_tool_calls' parameter for the Responses API

Test max_tool_calls with builtin and mcp tools

Update input prompt for more consistent tool calling

Resolve merge conflicts

Update integration test

Handle review comments
---
 client-sdks/stainless/openapi.yml           |  15 ++
 docs/static/llama-stack-spec.yaml           |  15 ++
 docs/static/stainless-llama-stack-spec.yaml |  15 ++
 src/llama_stack/apis/agents/agents.py       |   2 +
 .../apis/agents/openai_responses.py         |   2 +
 .../inline/agents/meta_reference/agents.py  |   2 +
 .../responses/openai_responses.py           |   7 +
 .../meta_reference/responses/streaming.py   |  18 +-
 .../agents/test_openai_responses.py         | 166 ++++++++++++++++++
 9 files changed, 240 insertions(+), 2 deletions(-)

diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml
index d8159be624..4e503a121e 100644
--- a/client-sdks/stainless/openapi.yml
+++ b/client-sdks/stainless/openapi.yml
@@ -6882,6 +6882,11 @@ components:
           type: string
           description: >-
             (Optional) System message inserted into the model's context
+        max_tool_calls:
+          type: integer
+          description: >-
+            (Optional) Max number of total calls to built-in tools that can be processed
+            in a response
         input:
           type: array
           items:
@@ -7240,6 +7245,11 @@ components:
             (Optional) Additional fields to include in the response.
         max_infer_iters:
           type: integer
+        max_tool_calls:
+          type: integer
+          description: >-
+            (Optional) Max number of total calls to built-in tools that can be processed
+            in a response.
       additionalProperties: false
       required:
         - input
@@ -7321,6 +7331,11 @@ components:
           type: string
           description: >-
             (Optional) System message inserted into the model's context
+        max_tool_calls:
+          type: integer
+          description: >-
+            (Optional) Max number of total calls to built-in tools that can be processed
+            in a response
       additionalProperties: false
       required:
         - created_at
diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml
index ea7fd6eecc..60bd96f3c5 100644
--- a/docs/static/llama-stack-spec.yaml
+++ b/docs/static/llama-stack-spec.yaml
@@ -6166,6 +6166,11 @@ components:
           type: string
           description: >-
             (Optional) System message inserted into the model's context
+        max_tool_calls:
+          type: integer
+          description: >-
+            (Optional) Max number of total calls to built-in tools that can be processed
+            in a response
         input:
           type: array
           items:
@@ -6524,6 +6529,11 @@ components:
             (Optional) Additional fields to include in the response.
         max_infer_iters:
          type: integer
+        max_tool_calls:
+          type: integer
+          description: >-
+            (Optional) Max number of total calls to built-in tools that can be processed
+            in a response.
       additionalProperties: false
       required:
         - input
@@ -6605,6 +6615,11 @@ components:
           type: string
           description: >-
             (Optional) System message inserted into the model's context
+        max_tool_calls:
+          type: integer
+          description: >-
+            (Optional) Max number of total calls to built-in tools that can be processed
+            in a response
       additionalProperties: false
       required:
         - created_at
diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml
index d8159be624..4e503a121e 100644
--- a/docs/static/stainless-llama-stack-spec.yaml
+++ b/docs/static/stainless-llama-stack-spec.yaml
@@ -6882,6 +6882,11 @@ components:
           type: string
           description: >-
             (Optional) System message inserted into the model's context
+        max_tool_calls:
+          type: integer
+          description: >-
+            (Optional) Max number of total calls to built-in tools that can be processed
+            in a response
         input:
           type: array
           items:
@@ -7240,6 +7245,11 @@ components:
             (Optional) Additional fields to include in the response.
         max_infer_iters:
           type: integer
+        max_tool_calls:
+          type: integer
+          description: >-
+            (Optional) Max number of total calls to built-in tools that can be processed
+            in a response.
       additionalProperties: false
       required:
         - input
@@ -7321,6 +7331,11 @@ components:
           type: string
           description: >-
             (Optional) System message inserted into the model's context
+        max_tool_calls:
+          type: integer
+          description: >-
+            (Optional) Max number of total calls to built-in tools that can be processed
+            in a response
       additionalProperties: false
       required:
         - created_at
diff --git a/src/llama_stack/apis/agents/agents.py b/src/llama_stack/apis/agents/agents.py
index cadef2edc0..09687ef330 100644
--- a/src/llama_stack/apis/agents/agents.py
+++ b/src/llama_stack/apis/agents/agents.py
@@ -87,6 +87,7 @@ async def create_openai_response(
                 "List of guardrails to apply during response generation. Guardrails provide safety and content moderation."
             ),
         ] = None,
+        max_tool_calls: int | None = None,
     ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
         """Create a model response.

@@ -97,6 +98,7 @@ async def create_openai_response(
         :param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
         :param include: (Optional) Additional fields to include in the response.
         :param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications.
+        :param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response.
         :returns: An OpenAIResponseObject.
         """
         ...
diff --git a/src/llama_stack/apis/agents/openai_responses.py b/src/llama_stack/apis/agents/openai_responses.py
index a38d1cba67..16657ab32c 100644
--- a/src/llama_stack/apis/agents/openai_responses.py
+++ b/src/llama_stack/apis/agents/openai_responses.py
@@ -594,6 +594,7 @@ class OpenAIResponseObject(BaseModel):
     :param truncation: (Optional) Truncation strategy applied to the response
     :param usage: (Optional) Token usage information for the response
     :param instructions: (Optional) System message inserted into the model's context
+    :param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response
     """

     created_at: int
@@ -615,6 +616,7 @@ class OpenAIResponseObject(BaseModel):
     truncation: str | None = None
     usage: OpenAIResponseUsage | None = None
     instructions: str | None = None
+    max_tool_calls: int | None = None


 @json_schema_type
diff --git a/src/llama_stack/providers/inline/agents/meta_reference/agents.py b/src/llama_stack/providers/inline/agents/meta_reference/agents.py
index 7141d58bcc..880e0b6802 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -102,6 +102,7 @@ async def create_openai_response(
         include: list[str] | None = None,
         max_infer_iters: int | None = 10,
         guardrails: list[ResponseGuardrail] | None = None,
+        max_tool_calls: int | None = None,
     ) -> OpenAIResponseObject:
         assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
         result = await self.openai_responses_impl.create_openai_response(
@@ -119,6 +120,7 @@ async def create_openai_response(
             include,
             max_infer_iters,
             guardrails,
+            max_tool_calls,
         )
         return result  # type: ignore[no-any-return]

diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
index 933cfe963a..ed7f959c04 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
@@ -255,6 +255,7 @@ async def create_openai_response(
         include: list[str] | None = None,
         max_infer_iters: int | None = 10,
         guardrails: list[str | ResponseGuardrailSpec] | None = None,
+        max_tool_calls: int | None = None,
     ):
         stream = bool(stream)
         text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
@@ -270,6 +271,9 @@ async def create_openai_response(
             if not conversation.startswith("conv_"):
                 raise InvalidConversationIdError(conversation)

+        if max_tool_calls is not None and max_tool_calls < 1:
+            raise ValueError(f"Invalid {max_tool_calls=}; should be >= 1")
+
         stream_gen = self._create_streaming_response(
             input=input,
             conversation=conversation,
@@ -282,6 +286,7 @@ async def create_openai_response(
             tools=tools,
             max_infer_iters=max_infer_iters,
             guardrail_ids=guardrail_ids,
+            max_tool_calls=max_tool_calls,
         )

         if stream:
@@ -331,6 +336,7 @@ async def _create_streaming_response(
         tools: list[OpenAIResponseInputTool] | None = None,
         max_infer_iters: int | None = 10,
         guardrail_ids: list[str] | None = None,
+        max_tool_calls: int | None = None,
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
         # These should never be None when called from create_openai_response (which sets defaults)
         # but we assert here to help mypy understand the types
@@ -373,6 +379,7 @@ async def _create_streaming_response(
             safety_api=self.safety_api,
             guardrail_ids=guardrail_ids,
             instructions=instructions,
+            max_tool_calls=max_tool_calls,
         )

         # Stream the response
diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
index ef56034204..c16bc8df30 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
@@ -115,6 +115,7 @@ def __init__(
         safety_api,
         guardrail_ids: list[str] | None = None,
         prompt: OpenAIResponsePrompt | None = None,
+        max_tool_calls: int | None = None,
     ):
         self.inference_api = inference_api
         self.ctx = ctx
@@ -126,6 +127,10 @@ def __init__(
         self.safety_api = safety_api
         self.guardrail_ids = guardrail_ids or []
         self.prompt = prompt
+        # System message that is inserted into the model's context
+        self.instructions = instructions
+        # Max number of total calls to built-in tools that can be processed in a response
+        self.max_tool_calls = max_tool_calls
         self.sequence_number = 0
         # Store MCP tool mapping that gets built during tool processing
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
@@ -139,8 +144,8 @@ def __init__(
         self.accumulated_usage: OpenAIResponseUsage | None = None
         # Track if we've sent a refusal response
         self.violation_detected = False
-        # system message that is inserted into the model's context
-        self.instructions = instructions
+        # Track total calls made to built-in tools
+        self.accumulated_builtin_tool_calls = 0

     async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
         """Create a refusal response to replace streaming content."""
@@ -186,6 +191,7 @@ def _snapshot_response(
             usage=self.accumulated_usage,
             instructions=self.instructions,
             prompt=self.prompt,
+            max_tool_calls=self.max_tool_calls,
         )

     async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
@@ -894,6 +900,11 @@ async def _coordinate_tool_execution(
         """Coordinate execution of both function and non-function tool calls."""
         # Execute non-function tool calls
         for tool_call in non_function_tool_calls:
+            # Stop executing built-in and MCP tool calls once the max_tool_calls limit has been reached
+            if self.max_tool_calls is not None and self.accumulated_builtin_tool_calls >= self.max_tool_calls:
+                logger.info(f"Ignoring remaining built-in and MCP tool calls because the limit of {self.max_tool_calls=} has been reached.")
+                break
+
             # Find the item_id for this tool call
             matching_item_id = None
             for index, item_id in completion_result_data.tool_call_item_ids.items():
@@ -974,6 +985,9 @@ async def _coordinate_tool_execution(

             if tool_response_message:
                 next_turn_messages.append(tool_response_message)
+            # Track the number of calls made to built-in and MCP tools
+            self.accumulated_builtin_tool_calls += 1
+
         # Execute function tool calls (client-side)
         for tool_call in function_tool_calls:
             # Find the item_id for this tool call from our tracking dictionary
diff --git a/tests/integration/agents/test_openai_responses.py b/tests/integration/agents/test_openai_responses.py
index d413d52016..057cee774a 100644
--- a/tests/integration/agents/test_openai_responses.py
+++ b/tests/integration/agents/test_openai_responses.py
@@ -516,3 +516,169 @@ def test_response_with_instructions(openai_client, client_with_models, text_model_id):

     # Verify instructions from previous response was not carried over to the next response
     assert response_with_instructions2.instructions == instructions2
+
+
+@pytest.mark.skip(reason="Tool calling is not reliable.")
+def test_max_tool_calls_with_function_tools(openai_client, client_with_models, text_model_id):
+    """Test handling of max_tool_calls with function tools in responses."""
+    if isinstance(client_with_models, LlamaStackAsLibraryClient):
+        pytest.skip("OpenAI responses are not supported when testing with library client yet.")
+
+    client = openai_client
+    max_tool_calls = 1
+
+    tools = [
+        {
+            "type": "function",
+            "name": "get_weather",
+            "description": "Get weather information for a specified location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city name (e.g., 'New York', 'London')",
+                    },
+                },
+            },
+        },
+        {
+            "type": "function",
+            "name": "get_time",
+            "description": "Get current time for a specified location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city name (e.g., 'New York', 'London')",
+                    },
+                },
+            },
+        },
+    ]
+
+    # First create a response that triggers function tools
+    response = client.responses.create(
+        model=text_model_id,
+        input="Can you tell me the weather in Paris and the current time?",
+        tools=tools,
+        stream=False,
+        max_tool_calls=max_tool_calls,
+    )
+
+    # Verify we got two function calls and that max_tool_calls does not affect function tools
+    assert len(response.output) == 2
+    assert response.output[0].type == "function_call"
+    assert response.output[0].name == "get_weather"
+    assert response.output[0].status == "completed"
+    assert response.output[1].type == "function_call"
+    assert response.output[1].name == "get_time"
+    assert response.output[1].status == "completed"
+
+    # Verify we have a valid max_tool_calls field
+    assert response.max_tool_calls == max_tool_calls
+
+
+def test_max_tool_calls_invalid(openai_client, client_with_models, text_model_id):
+    """Test handling of invalid max_tool_calls in responses."""
+    if isinstance(client_with_models, LlamaStackAsLibraryClient):
+        pytest.skip("OpenAI responses are not supported when testing with library client yet.")
+
+    client = openai_client
+
+    input = "Search for today's top technology news."
+    invalid_max_tool_calls = 0
+    tools = [
+        {"type": "web_search"},
+    ]
+
+    # Create a response with an invalid max_tool_calls value, i.e. 0
+    # Handle ValueError from LLS and BadRequestError from the OpenAI client
+    with pytest.raises((ValueError, BadRequestError)) as excinfo:
+        client.responses.create(
+            model=text_model_id,
+            input=input,
+            tools=tools,
+            stream=False,
+            max_tool_calls=invalid_max_tool_calls,
+        )
+
+    error_message = str(excinfo.value)
+    assert f"Invalid max_tool_calls={invalid_max_tool_calls}; should be >= 1" in error_message, (
+        f"Expected error message about invalid max_tool_calls, got: {error_message}"
+    )
+
+
+def test_max_tool_calls_with_builtin_tools(openai_client, client_with_models, text_model_id):
+    """Test handling of max_tool_calls with built-in tools in responses."""
+    if isinstance(client_with_models, LlamaStackAsLibraryClient):
+        pytest.skip("OpenAI responses are not supported when testing with library client yet.")
+
+    client = openai_client
+
+    input = "Search for today's top technology and a positive news story. You MUST make exactly two separate web search calls."
+    max_tool_calls = [1, 5]
+    tools = [
+        {"type": "web_search"},
+    ]
+
+    # First create a response that triggers web_search tools without max_tool_calls
+    response = client.responses.create(
+        model=text_model_id,
+        input=input,
+        tools=tools,
+        stream=False,
+    )
+
+    # Verify we got two web search calls followed by a message
+    assert len(response.output) == 3
+    assert response.output[0].type == "web_search_call"
+    assert response.output[0].status == "completed"
+    assert response.output[1].type == "web_search_call"
+    assert response.output[1].status == "completed"
+    assert response.output[2].type == "message"
+    assert response.output[2].status == "completed"
+    assert response.output[2].role == "assistant"
+
+    # Next create a response that triggers web_search tools with max_tool_calls set to 1
+    response_2 = client.responses.create(
+        model=text_model_id,
+        input=input,
+        tools=tools,
+        stream=False,
+        max_tool_calls=max_tool_calls[0],
+    )
+
+    # Verify we got one web search tool call followed by a message
+    assert len(response_2.output) == 2
+    assert response_2.output[0].type == "web_search_call"
+    assert response_2.output[0].status == "completed"
+    assert response_2.output[1].type == "message"
+    assert response_2.output[1].status == "completed"
+    assert response_2.output[1].role == "assistant"
+
+    # Verify we have a valid max_tool_calls field
+    assert response_2.max_tool_calls == max_tool_calls[0]
+
+    # Finally create a response that triggers web_search tools with max_tool_calls set to 5
+    response_3 = client.responses.create(
+        model=text_model_id,
+        input=input,
+        tools=tools,
+        stream=False,
+        max_tool_calls=max_tool_calls[1],
+    )
+
+    # Verify we got two web search calls followed by a message
+    assert len(response_3.output) == 3
+    assert response_3.output[0].type == "web_search_call"
+    assert response_3.output[0].status == "completed"
+    assert response_3.output[1].type == "web_search_call"
+    assert response_3.output[1].status == "completed"
+    assert response_3.output[2].type == "message"
+    assert response_3.output[2].status == "completed"
+    assert response_3.output[2].role == "assistant"
+
+    # Verify we have a valid max_tool_calls field
+    assert response_3.max_tool_calls == max_tool_calls[1]
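
Usage note (not part of the patch): the integration tests above drive the new parameter through the OpenAI Python client pointed at a Llama Stack server. The snippet below is a minimal sketch of the same call; the base_url, api_key, and model name are placeholders for whatever deployment you run, and the capped behaviour shown in the comments assumes this patch is applied on the server.

    from openai import OpenAI

    # Hypothetical local Llama Stack endpoint and model; adjust for your deployment.
    client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

    response = client.responses.create(
        model="llama3.2:3b",
        input="Search for today's top technology news.",
        tools=[{"type": "web_search"}],
        max_tool_calls=1,  # cap on built-in/MCP tool calls processed in this response
    )

    # With the patch applied, at most one web_search_call item should appear in the output;
    # function (client-side) tool calls are not affected by the cap.
    print([item.type for item in response.output])
    print(response.max_tool_calls)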