client-sdks/stainless/openapi.yml (15 additions, 0 deletions)
@@ -6626,6 +6626,11 @@ components:
type: string
description: >-
(Optional) System message inserted into the model's context
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response
input:
type: array
items:
@@ -6984,6 +6989,11 @@ components:
(Optional) Additional fields to include in the response.
max_infer_iters:
type: integer
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response.
additionalProperties: false
required:
- input
@@ -7065,6 +7075,11 @@ components:
type: string
description: >-
(Optional) System message inserted into the model's context
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response
additionalProperties: false
required:
- created_at
docs/static/llama-stack-spec.yaml (15 additions, 0 deletions)
@@ -5910,6 +5910,11 @@ components:
type: string
description: >-
(Optional) System message inserted into the model's context
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response
input:
type: array
items:
@@ -6268,6 +6273,11 @@ components:
(Optional) Additional fields to include in the response.
max_infer_iters:
type: integer
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response.
additionalProperties: false
required:
- input
@@ -6349,6 +6359,11 @@ components:
type: string
description: >-
(Optional) System message inserted into the model's context
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response
additionalProperties: false
required:
- created_at
docs/static/stainless-llama-stack-spec.yaml (15 additions, 0 deletions)
@@ -6626,6 +6626,11 @@ components:
type: string
description: >-
(Optional) System message inserted into the model's context
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response
input:
type: array
items:
Expand Down Expand Up @@ -6984,6 +6989,11 @@ components:
(Optional) Additional fields to include in the response.
max_infer_iters:
type: integer
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response.
additionalProperties: false
required:
- input
@@ -7065,6 +7075,11 @@ components:
type: string
description: >-
(Optional) System message inserted into the model's context
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response
additionalProperties: false
required:
- created_at
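The three spec files above advertise the same new request field. As a quick illustration, here is a hedged sketch of a raw request that exercises it; the base URL, endpoint path, model name, and tool choice are assumptions for illustration, not taken from this PR:

```python
# Hedged sketch of a raw call to the Responses API described above.
# URL, endpoint path, and model name are illustrative assumptions.
import httpx

payload = {
    "model": "llama3.2:3b",
    "input": "Find two recent papers on speculative decoding.",
    "tools": [{"type": "web_search"}],
    "max_tool_calls": 3,  # cap on total built-in/MCP tool calls in this response
}
resp = httpx.post("http://localhost:8321/v1/responses", json=payload, timeout=60.0)
print(resp.json().get("max_tool_calls"))  # echoed back on the response object
```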
src/llama_stack/apis/agents/agents.py (2 additions, 0 deletions)
@@ -87,6 +87,7 @@ async def create_openai_response(
"List of guardrails to apply during response generation. Guardrails provide safety and content moderation."
),
] = None,
max_tool_calls: int | None = None,
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
"""Create a model response.

@@ -97,6 +98,7 @@ async def create_openai_response(
:param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
:param include: (Optional) Additional fields to include in the response.
:param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications.
:param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response.
:returns: An OpenAIResponseObject.
"""
...
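The same keyword threads through SDK callers once clients are regenerated from the updated spec. A minimal sketch, assuming a llama-stack-client build that already exposes the parameter; base URL and model name are placeholders:

```python
# Sketch only: assumes an SDK build regenerated from the spec above.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")
response = client.responses.create(
    model="llama3.2:3b",
    input="Search the web for the current Llama Stack release.",
    tools=[{"type": "web_search"}],
    max_tool_calls=1,  # at most one built-in/MCP tool call, then answer
)
print(response.id)
```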
src/llama_stack/apis/agents/openai_responses.py (2 additions, 0 deletions)
@@ -594,6 +594,7 @@ class OpenAIResponseObject(BaseModel):
:param truncation: (Optional) Truncation strategy applied to the response
:param usage: (Optional) Token usage information for the response
:param instructions: (Optional) System message inserted into the model's context
:param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response
"""

created_at: int
@@ -615,6 +616,7 @@ class OpenAIResponseObject(BaseModel):
truncation: str | None = None
usage: OpenAIResponseUsage | None = None
instructions: str | None = None
max_tool_calls: int | None = None


@json_schema_type
@@ -102,6 +102,7 @@ async def create_openai_response(
include: list[str] | None = None,
max_infer_iters: int | None = 10,
guardrails: list[ResponseGuardrail] | None = None,
max_tool_calls: int | None = None,
) -> OpenAIResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
result = await self.openai_responses_impl.create_openai_response(
@@ -119,6 +120,7 @@
include,
max_infer_iters,
guardrails,
max_tool_calls,
)
return result # type: ignore[no-any-return]

@@ -255,6 +255,7 @@ async def create_openai_response(
include: list[str] | None = None,
max_infer_iters: int | None = 10,
guardrails: list[str | ResponseGuardrailSpec] | None = None,
max_tool_calls: int | None = None,
):
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
@@ -270,6 +271,9 @@
if not conversation.startswith("conv_"):
raise InvalidConversationIdError(conversation)

if max_tool_calls is not None and max_tool_calls < 1:
raise ValueError(f"Invalid {max_tool_calls=}; should be >= 1")

stream_gen = self._create_streaming_response(
input=input,
conversation=conversation,
Expand All @@ -282,6 +286,7 @@ async def create_openai_response(
tools=tools,
max_infer_iters=max_infer_iters,
guardrail_ids=guardrail_ids,
max_tool_calls=max_tool_calls,
)

if stream:
@@ -331,6 +336,7 @@ async def _create_streaming_response(
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
guardrail_ids: list[str] | None = None,
max_tool_calls: int | None = None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# These should never be None when called from create_openai_response (which sets defaults)
# but we assert here to help mypy understand the types
@@ -373,6 +379,7 @@
safety_api=self.safety_api,
guardrail_ids=guardrail_ids,
instructions=instructions,
max_tool_calls=max_tool_calls,
)

# Stream the response
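The guard above rejects non-positive caps before any streaming work starts. A standalone sketch that mirrors the check (it re-implements the one-line validation rather than importing the actual implementation):

```python
# Mirrors the validation added above; not the actual implementation.
def validate_max_tool_calls(max_tool_calls: int | None) -> None:
    if max_tool_calls is not None and max_tool_calls < 1:
        raise ValueError(f"Invalid {max_tool_calls=}; should be >= 1")

validate_max_tool_calls(None)  # OK: no cap set, tool calls are unlimited
validate_max_tool_calls(2)     # OK: positive cap
try:
    validate_max_tool_calls(0)  # rejected before any stream is created
except ValueError as err:
    print(err)  # Invalid max_tool_calls=0; should be >= 1
```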
@@ -115,6 +115,7 @@ def __init__(
safety_api,
guardrail_ids: list[str] | None = None,
prompt: OpenAIResponsePrompt | None = None,
max_tool_calls: int | None = None,
):
self.inference_api = inference_api
self.ctx = ctx
@@ -126,6 +127,10 @@ def __init__(
self.safety_api = safety_api
self.guardrail_ids = guardrail_ids or []
self.prompt = prompt
# System message that is inserted into the model's context
self.instructions = instructions
# Max number of total calls to built-in tools that can be processed in a response
self.max_tool_calls = max_tool_calls
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
@@ -139,8 +144,8 @@ def __init__(
self.accumulated_usage: OpenAIResponseUsage | None = None
# Track if we've sent a refusal response
self.violation_detected = False
# system message that is inserted into the model's context
self.instructions = instructions
# Track total calls made to built-in tools
self.accumulated_builtin_tool_calls = 0
Contributor: could we just name it accumulated_tool_calls? why the prefix "builtin" here?

Contributor (Author): Would that cause confusion around function tools? As in, to make it clear that we are only tracking non-function tools.

Contributor: @s-akhtar-baig that makes sense, thanks!


async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
"""Create a refusal response to replace streaming content."""
@@ -186,6 +191,7 @@ def _snapshot_response(
usage=self.accumulated_usage,
instructions=self.instructions,
prompt=self.prompt,
max_tool_calls=self.max_tool_calls,
)

async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
@@ -894,6 +900,11 @@ async def _coordinate_tool_execution(
"""Coordinate execution of both function and non-function tool calls."""
# Execute non-function tool calls
for tool_call in non_function_tool_calls:
# Stop if the number of calls made to built-in and MCP tools has reached max_tool_calls
if self.max_tool_calls is not None and self.accumulated_builtin_tool_calls >= self.max_tool_calls:
logger.info(f"Ignoring built-in/MCP tool call: reached the limit of {self.max_tool_calls=}.")
break

# Find the item_id for this tool call
matching_item_id = None
for index, item_id in completion_result_data.tool_call_item_ids.items():
@@ -974,6 +985,9 @@ async def _coordinate_tool_execution(
if tool_response_message:
next_turn_messages.append(tool_response_message)

# Track the number of calls made to built-in and MCP tools
self.accumulated_builtin_tool_calls += 1

# Execute function tool calls (client-side)
for tool_call in function_tool_calls:
# Find the item_id for this tool call from our tracking dictionary
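Net effect of the two hunks above: only non-function (built-in and MCP) tool calls count against the cap, and once it is reached the remaining built-in calls in the batch are skipped; client-side function calls are never throttled. A self-contained sketch of that loop shape, where the result strings stand in for the real tool-execution path (hypothetical helper, not the PR's code):

```python
# Standalone sketch of the capping behavior; the "executed:" strings are
# hypothetical stand-ins for real built-in/MCP tool execution.
def run_non_function_tool_calls(tool_calls: list[str], max_tool_calls: int | None) -> list[str]:
    results: list[str] = []
    accumulated = 0  # mirrors accumulated_builtin_tool_calls
    for call in tool_calls:
        # Once the cap is reached, the remaining built-in calls are skipped.
        if max_tool_calls is not None and accumulated >= max_tool_calls:
            break
        results.append(f"executed:{call}")
        accumulated += 1
    return results

print(run_non_function_tool_calls(["web_search", "file_search", "web_search"], 2))
# ['executed:web_search', 'executed:file_search']
```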