diff --git a/docs/tools.md b/docs/tools.md index 4e9a20d32..2550d3dae 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -100,7 +100,7 @@ for tool in agent.tools: 1. You can use any Python types as arguments to your functions, and the function can be sync or async. 2. Docstrings, if present, are used to capture descriptions and argument descriptions -3. Functions can optionally take the `context` (must be the first argument). You can also set overrides, like the name of the tool, description, which docstring style to use, etc. +3. Functions can optionally take the `context` (must be the first argument). You can also set overrides, like the name of the tool, description, which docstring style to use, and exclude specific parameters from the schema, etc. 4. You can pass the decorated functions to the list of tools. ??? note "Expand to see output" @@ -284,6 +284,45 @@ async def run_my_agent() -> str: return str(result.final_output) ``` +## Excluding parameters from the schema + +Sometimes, you might want to exclude certain parameters from the JSON schema that's presented to the LLM, while still making them available to your function with their default values. This can be useful for: + +- Keeping implementation details hidden from the LLM +- Simplifying the tool interface presented to the model +- Maintaining backward compatibility when adding new parameters +- Supporting internal parameters that should always use default values + +You can do this using the `exclude_params` parameter of the `@function_tool` decorator: + +```python +from typing import Optional +from agents import function_tool, RunContextWrapper + +@function_tool(exclude_params=["timestamp", "internal_id"]) +def search_database( + query: str, + limit: int = 10, + timestamp: Optional[str] = None, + internal_id: Optional[str] = None +) -> str: + """ + Search the database for records matching the query. + + Args: + query: The search query string + limit: Maximum number of results to return + timestamp: The timestamp to use for the search (hidden from schema) + internal_id: Internal tracking ID for telemetry (hidden from schema) + """ + # Implementation... +``` + +In this example: +- The LLM will only see `query` and `limit` parameters in the tool schema +- `timestamp` and `internal_id` will be automatically set to their default values when the function runs +- All excluded parameters must have default values (either `None` or a specific value) + ## Handling errors in function tools When you create a function tool via `@function_tool`, you can pass a `failure_error_function`. This is a function that provides an error response to the LLM in case the tool call crashes. diff --git a/src/agents/function_schema.py b/src/agents/function_schema.py index 0e5868965..1eba541df 100644 --- a/src/agents/function_schema.py +++ b/src/agents/function_schema.py @@ -45,6 +45,9 @@ def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]: positional_args: list[Any] = [] keyword_args: dict[str, Any] = {} seen_var_positional = False + + # Get excluded parameter defaults if they exist + excluded_param_defaults = getattr(self.params_pydantic_model, "__excluded_param_defaults__", {}) # Use enumerate() so we can skip the first parameter if it's context. for idx, (name, param) in enumerate(self.signature.parameters.items()): @@ -52,7 +55,12 @@ def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]: if self.takes_context and idx == 0: continue - value = getattr(data, name, None) + # For excluded parameters, use their default value + if name in excluded_param_defaults: + value = excluded_param_defaults[name] + else: + value = getattr(data, name, None) + if param.kind == param.VAR_POSITIONAL: # e.g. *args: extend positional args and mark that *args is now seen positional_args.extend(value or []) @@ -190,6 +198,7 @@ def function_schema( description_override: str | None = None, use_docstring_info: bool = True, strict_json_schema: bool = True, + exclude_params: list[str] | None = None, ) -> FuncSchema: """ Given a python function, extracts a `FuncSchema` from it, capturing the name, description, @@ -208,6 +217,9 @@ def function_schema( the schema adheres to the "strict" standard the OpenAI API expects. We **strongly** recommend setting this to True, as it increases the likelihood of the LLM providing correct JSON input. + exclude_params: If provided, these parameters will be excluded from the JSON schema + presented to the LLM. The parameters will still be available to the function with + their default values. All excluded parameters must have default values. Returns: A `FuncSchema` object containing the function's name, description, parameter descriptions, @@ -231,11 +243,24 @@ def function_schema( takes_context = False filtered_params = [] + # Store default values for excluded parameters + excluded_param_defaults = {} + if params: first_name, first_param = params[0] # Prefer the evaluated type hint if available ann = type_hints.get(first_name, first_param.annotation) - if ann != inspect._empty: + + # Check if this parameter should be excluded + if exclude_params and first_name in exclude_params: + # Ensure the parameter has a default value + if first_param.default is inspect._empty: + raise UserError( + f"Parameter '{first_name}' specified in exclude_params must have a default value" + ) + # Store default value + excluded_param_defaults[first_name] = first_param.default + elif ann != inspect._empty: origin = get_origin(ann) or ann if origin is RunContextWrapper: takes_context = True # Mark that the function takes context @@ -246,6 +271,17 @@ def function_schema( # For parameters other than the first, raise error if any use RunContextWrapper. for name, param in params[1:]: + # Check if this parameter should be excluded + if exclude_params and name in exclude_params: + # Ensure the parameter has a default value + if param.default is inspect._empty: + raise UserError( + f"Parameter '{name}' specified in exclude_params must have a default value" + ) + # Store default value + excluded_param_defaults[name] = param.default + continue + ann = type_hints.get(name, param.annotation) if ann != inspect._empty: origin = get_origin(ann) or ann @@ -326,6 +362,9 @@ def function_schema( # 3. Dynamically build a Pydantic model dynamic_model = create_model(f"{func_name}_args", __base__=BaseModel, **fields) + + # Store excluded parameter defaults in the model for later use + setattr(dynamic_model, "__excluded_param_defaults__", excluded_param_defaults) # 4. Build JSON schema from that model json_schema = dynamic_model.model_json_schema() diff --git a/src/agents/tool.py b/src/agents/tool.py index fd5a21c89..503d6b447 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -262,6 +262,7 @@ def function_tool( use_docstring_info: bool = True, failure_error_function: ToolErrorFunction | None = None, strict_mode: bool = True, + exclude_params: list[str] | None = None, ) -> FunctionTool: """Overload for usage as @function_tool (no parentheses).""" ... @@ -276,6 +277,7 @@ def function_tool( use_docstring_info: bool = True, failure_error_function: ToolErrorFunction | None = None, strict_mode: bool = True, + exclude_params: list[str] | None = None, ) -> Callable[[ToolFunction[...]], FunctionTool]: """Overload for usage as @function_tool(...).""" ... @@ -290,6 +292,7 @@ def function_tool( use_docstring_info: bool = True, failure_error_function: ToolErrorFunction | None = default_tool_error_function, strict_mode: bool = True, + exclude_params: list[str] | None = None, ) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]: """ Decorator to create a FunctionTool from a function. By default, we will: @@ -318,9 +321,26 @@ def function_tool( If False, it allows non-strict JSON schemas. For example, if a parameter has a default value, it will be optional, additional properties are allowed, etc. See here for more: https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas + exclude_params: If provided, these parameters will be excluded from the JSON schema + presented to the LLM. The parameters will still be available to the function with + their default values. All excluded parameters must have default values. """ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool: + # Check that all excluded parameters have default values + if exclude_params: + sig = inspect.signature(the_func) + for param_name in exclude_params: + if param_name not in sig.parameters: + raise UserError( + f"Parameter '{param_name}' specified in exclude_params doesn't exist in function {the_func.__name__}" + ) + param = sig.parameters[param_name] + if param.default is inspect._empty: + raise UserError( + f"Parameter '{param_name}' specified in exclude_params must have a default value" + ) + schema = function_schema( func=the_func, name_override=name_override, @@ -328,6 +348,7 @@ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool: docstring_style=docstring_style, use_docstring_info=use_docstring_info, strict_json_schema=strict_mode, + exclude_params=exclude_params, ) async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any: diff --git a/tests/test_function_schema.py b/tests/test_function_schema.py index 5618d8ae9..bbd23be83 100644 --- a/tests/test_function_schema.py +++ b/tests/test_function_schema.py @@ -439,3 +439,67 @@ def func_with_mapping(test_one: Mapping[str, int]) -> str: with pytest.raises(UserError): function_schema(func_with_mapping) + + +def function_with_optional_params(a: int, b: int = 5, c: str = "default"): + """Function with multiple optional parameters.""" + return f"{a}-{b}-{c}" + + +def test_exclude_params_feature(): + """Test the exclude_params feature works correctly.""" + # Test excluding a single optional parameter + func_schema = function_schema( + function_with_optional_params, + exclude_params=["c"], + ) + + # Verify 'c' is not in the schema properties + assert "c" not in func_schema.params_json_schema.get("properties", {}) + + # Verify the excluded parameter defaults are stored + excluded_defaults = getattr(func_schema.params_pydantic_model, "__excluded_param_defaults__", {}) + assert "c" in excluded_defaults + assert excluded_defaults["c"] == "default" + + # Test function still works correctly with excluded parameter + valid_input = {"a": 10, "b": 20} + parsed = func_schema.params_pydantic_model(**valid_input) + args, kwargs_dict = func_schema.to_call_args(parsed) + result = function_with_optional_params(*args, **kwargs_dict) + assert result == "10-20-default" # 'c' should use its default value + + # Test excluding multiple parameters + func_schema_multi = function_schema( + function_with_optional_params, + exclude_params=["b", "c"], + ) + + # Verify both 'b' and 'c' are not in the schema properties + assert "b" not in func_schema_multi.params_json_schema.get("properties", {}) + assert "c" not in func_schema_multi.params_json_schema.get("properties", {}) + + # Test function still works correctly with multiple excluded parameters + valid_input = {"a": 10} + parsed = func_schema_multi.params_pydantic_model(**valid_input) + args, kwargs_dict = func_schema_multi.to_call_args(parsed) + result = function_with_optional_params(*args, **kwargs_dict) + assert result == "10-5-default" # 'b' and 'c' should use their default values + + +def function_with_required_param(a: int, b: str): + """Function with required parameters only.""" + return f"{a}-{b}" + + +def test_exclude_params_requires_default_value(): + """Test that excluding a parameter without a default value raises an error.""" + # Attempt to exclude a parameter without a default value + with pytest.raises(UserError) as excinfo: + function_schema( + function_with_required_param, + exclude_params=["b"], + ) + + # Check the error message + assert "must have a default value" in str(excinfo.value) diff --git a/tests/test_function_tool_decorator.py b/tests/test_function_tool_decorator.py index 3b52788fb..010f614af 100644 --- a/tests/test_function_tool_decorator.py +++ b/tests/test_function_tool_decorator.py @@ -233,3 +233,39 @@ async def test_extract_descriptions_from_docstring(): "additionalProperties": False, } ) + + +@function_tool(exclude_params=["timestamp"]) +def function_with_excluded_param( + city: str, country: str = "US", timestamp: Optional[str] = None +) -> str: + """Get the weather for a given city with timestamp. + + Args: + city: The city to get the weather for. + country: The country the city is in. + timestamp: The timestamp for the weather data (hidden from schema). + """ + time_str = f" at {timestamp}" if timestamp else "" + return f"The weather in {city}, {country}{time_str} is sunny." + + +@pytest.mark.asyncio +async def test_exclude_params_from_schema(): + """Test that excluded parameters are not included in the schema.""" + tool = function_with_excluded_param + + # Check that the parameter is not in the schema + assert "timestamp" not in tool.params_json_schema.get("properties", {}) + + # Check that only non-excluded parameters are required + assert set(tool.params_json_schema.get("required", [])) == {"city"} + + # Test function still works with excluded parameter + input_data = {"city": "Seattle", "country": "US"} + output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data)) + assert output == "The weather in Seattle, US is sunny." + + # Test function works when we supply a default excluded parameter value in the code + function_result = function_with_excluded_param("Seattle", "US", "2023-05-29T12:00:00Z") + assert function_result == "The weather in Seattle, US at 2023-05-29T12:00:00Z is sunny."