diff --git a/src/agents/tool.py b/src/agents/tool.py index 39db129b7..b122cc1fe 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -185,6 +185,9 @@ class FunctionTool: tool_output_guardrails: list[ToolOutputGuardrail[Any]] | None = None """Optional list of output guardrails to run after invoking this tool.""" + custom_metadata: dict[str, Any] | None = None + """Optional custom metadata to attach to this tool.""" + def __post_init__(self): if self.strict_json_schema: self.params_json_schema = ensure_strict_json_schema(self.params_json_schema) @@ -211,6 +214,9 @@ class FileSearchTool: filters: Filters | None = None """A filter to apply based on file attributes.""" + custom_metadata: dict[str, Any] | None = None + """Optional custom metadata to attach to this tool.""" + @property def name(self): return "file_search" @@ -231,6 +237,9 @@ class WebSearchTool: search_context_size: Literal["low", "medium", "high"] = "medium" """The amount of context to use for the search.""" + custom_metadata: dict[str, Any] | None = None + """Optional custom metadata to attach to this tool.""" + @property def name(self): return "web_search" @@ -248,6 +257,9 @@ class ComputerTool: on_safety_check: Callable[[ComputerToolSafetyCheckData], MaybeAwaitable[bool]] | None = None """Optional callback to acknowledge computer tool safety checks.""" + custom_metadata: dict[str, Any] | None = None + """Optional custom metadata to attach to this tool.""" + @property def name(self): return "computer_use_preview" @@ -313,6 +325,9 @@ class HostedMCPTool: provided, you will need to manually add approvals/rejections to the input and call `Runner.run(...)` again.""" + custom_metadata: dict[str, Any] | None = None + """Optional custom metadata to attach to this tool.""" + @property def name(self): return "hosted_mcp" @@ -325,6 +340,9 @@ class CodeInterpreterTool: tool_config: CodeInterpreter """The tool config, which includes the container and other settings.""" + custom_metadata: dict[str, Any] | None = None + """Optional custom metadata to attach to this tool.""" + @property def name(self): return "code_interpreter" @@ -337,6 +355,9 @@ class ImageGenerationTool: tool_config: ImageGeneration """The tool config, which image generation settings.""" + custom_metadata: dict[str, Any] | None = None + """Optional custom metadata to attach to this tool.""" + @property def name(self): return "image_generation" @@ -368,6 +389,9 @@ class LocalShellTool: executor: LocalShellExecutor """A function that executes a command on a shell.""" + custom_metadata: dict[str, Any] | None = None + """Optional custom metadata to attach to this tool.""" + @property def name(self): return "local_shell" @@ -405,6 +429,7 @@ def function_tool( failure_error_function: ToolErrorFunction | None = None, strict_mode: bool = True, is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, + custom_metadata: dict[str, Any] | None = None, ) -> FunctionTool: """Overload for usage as @function_tool (no parentheses).""" ... @@ -420,6 +445,7 @@ def function_tool( failure_error_function: ToolErrorFunction | None = None, strict_mode: bool = True, is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, + custom_metadata: dict[str, Any] | None = None, ) -> Callable[[ToolFunction[...]], FunctionTool]: """Overload for usage as @function_tool(...).""" ... @@ -435,6 +461,7 @@ def function_tool( failure_error_function: ToolErrorFunction | None = default_tool_error_function, strict_mode: bool = True, is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, + custom_metadata: dict[str, Any] | None = None, ) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]: """ Decorator to create a FunctionTool from a function. By default, we will: @@ -466,6 +493,7 @@ def function_tool( is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run context and agent and returns whether the tool is enabled. Disabled tools are hidden from the LLM at runtime. + custom_metadata: Optional metadata to attach to the resulting tool instance. """ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool: @@ -556,6 +584,7 @@ async def _on_invoke_tool(ctx: ToolContext[Any], input: str) -> Any: on_invoke_tool=_on_invoke_tool, strict_json_schema=strict_mode, is_enabled=is_enabled, + custom_metadata=custom_metadata, ) # If func is actually a callable, we were used as @function_tool with no parentheses diff --git a/tests/test_tool_custom_metadata.py b/tests/test_tool_custom_metadata.py new file mode 100644 index 000000000..578c4b77d --- /dev/null +++ b/tests/test_tool_custom_metadata.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +from typing import Any, Callable, Dict, List, cast + +import pytest + +from agents.agent import Agent +from agents.computer import Computer +from agents.lifecycle import AgentHooks +from agents.run import Runner +from agents.tool import ( + CodeInterpreterTool, + ComputerTool, + FileSearchTool, + FunctionTool, + HostedMCPTool, + ImageGenerationTool, + LocalShellCommandRequest, + LocalShellTool, + function_tool, + Tool, + WebSearchTool, +) +from tests.fake_model import FakeModel +from tests.test_responses import get_function_tool, get_function_tool_call, get_text_message + + +async def _noop_invoke(context: Any, params_json: str) -> str: + return "ok" + + +class _DummyComputer(Computer): + @property + def environment(self) -> str: + return "windows" + + @property + def dimensions(self) -> tuple[int, int]: + return (800, 600) + + def screenshot(self) -> str: + return "" + + def click(self, x: int, y: int, button: str) -> None: + return None + + def double_click(self, x: int, y: int) -> None: + return None + + def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None: + return None + + def type(self, text: str) -> None: + return None + + def wait(self) -> None: + return None + + def move(self, x: int, y: int) -> None: + return None + + def keypress(self, keys: List[str]) -> None: + return None + + def drag(self, path: List[tuple[int, int]]) -> None: + return None + + +def _make_function_tool(with_metadata: bool) -> FunctionTool: + kwargs: Dict[str, Any] = {} + if with_metadata: + kwargs["custom_metadata"] = {"key": "value"} + + return FunctionTool( + name="func", + description="desc", + params_json_schema={"type": "object", "properties": {}}, + on_invoke_tool=_noop_invoke, + **kwargs, + ) + + +def _make_file_search_tool(with_metadata: bool) -> FileSearchTool: + kwargs: Dict[str, Any] = {} + if with_metadata: + kwargs["custom_metadata"] = {"key": "value"} + return FileSearchTool(vector_store_ids=["vs"], **kwargs) + + +def _make_web_search_tool(with_metadata: bool) -> WebSearchTool: + kwargs: Dict[str, Any] = {} + if with_metadata: + kwargs["custom_metadata"] = {"key": "value"} + return WebSearchTool(**kwargs) + + +def _make_computer_tool(with_metadata: bool) -> ComputerTool: + kwargs: Dict[str, Any] = {} + if with_metadata: + kwargs["custom_metadata"] = {"key": "value"} + return ComputerTool(computer=_DummyComputer(), **kwargs) + + +def _make_hosted_mcp_tool(with_metadata: bool) -> HostedMCPTool: + kwargs: Dict[str, Any] = {} + if with_metadata: + kwargs["custom_metadata"] = {"key": "value"} + tool_config = cast(Any, {"server_url": "https://example.com"}) + return HostedMCPTool(tool_config=tool_config, **kwargs) + + +def _make_code_interpreter_tool(with_metadata: bool) -> CodeInterpreterTool: + kwargs: Dict[str, Any] = {} + if with_metadata: + kwargs["custom_metadata"] = {"key": "value"} + tool_config = cast(Any, {"runtime": "python"}) + return CodeInterpreterTool(tool_config=tool_config, **kwargs) + + +def _make_image_generation_tool(with_metadata: bool) -> ImageGenerationTool: + kwargs: Dict[str, Any] = {} + if with_metadata: + kwargs["custom_metadata"] = {"key": "value"} + tool_config = cast(Any, {"model": "image"}) + return ImageGenerationTool(tool_config=tool_config, **kwargs) + + +def _make_local_shell_tool(with_metadata: bool) -> LocalShellTool: + kwargs: Dict[str, Any] = {} + if with_metadata: + kwargs["custom_metadata"] = {"key": "value"} + + def _executor(request: LocalShellCommandRequest) -> str: + return "executed" + + return LocalShellTool(executor=_executor, **kwargs) + + +@pytest.mark.parametrize( + "factory", + [ + _make_function_tool, + _make_file_search_tool, + _make_web_search_tool, + _make_computer_tool, + _make_hosted_mcp_tool, + _make_code_interpreter_tool, + _make_image_generation_tool, + _make_local_shell_tool, + ], +) +def test_custom_metadata_defaults_to_none(factory: Callable[[bool], Any]) -> None: + tool = factory(False) + assert tool.custom_metadata is None + + +@pytest.mark.parametrize( + "factory", + [ + _make_function_tool, + _make_file_search_tool, + _make_web_search_tool, + _make_computer_tool, + _make_hosted_mcp_tool, + _make_code_interpreter_tool, + _make_image_generation_tool, + _make_local_shell_tool, + ], +) +def test_custom_metadata_can_be_provided(factory: Callable[[bool], Any]) -> None: + tool = factory(True) + assert tool.custom_metadata == {"key": "value"} + + +def test_function_tool_decorator_allows_custom_metadata() -> None: + metadata = {"foo": "bar"} + + @function_tool(custom_metadata=metadata) + def _decorated() -> str: + return "ok" + + assert _decorated.custom_metadata is metadata + + +def test_function_tool_direct_call_allows_custom_metadata() -> None: + metadata = {"alpha": "beta"} + + def _fn() -> str: + return "ok" + + tool = function_tool(_fn, custom_metadata=metadata) + assert tool.custom_metadata is metadata + + +class _MetadataCapturingHooks(AgentHooks): + def __init__(self) -> None: + self.start_metadata: list[dict[str, Any] | None] = [] + self.end_metadata: list[dict[str, Any] | None] = [] + + async def on_tool_start( + self, + context: Any, + agent: Agent[Any], + tool: Tool, + ) -> None: + self.start_metadata.append(tool.custom_metadata) + + async def on_tool_end( + self, + context: Any, + agent: Agent[Any], + tool: Tool, + result: str, + ) -> None: + self.end_metadata.append(tool.custom_metadata) + + +@pytest.mark.asyncio +async def test_custom_metadata_available_in_hooks() -> None: + hooks = _MetadataCapturingHooks() + fake_model = FakeModel() + + tool = get_function_tool("custom_tool", return_value="tool result") + metadata = {"source": "unit_test"} + tool.custom_metadata = metadata + + agent = Agent(name="metadata_agent", model=fake_model, tools=[tool], hooks=hooks) + + fake_model.add_multiple_turn_outputs( + [ + [get_function_tool_call("custom_tool", "{}")] , + [get_text_message("Final response")], + ] + ) + + result = await Runner.run(agent, "metadata input") + assert result.final_output == "Final response" + assert hooks.start_metadata == [metadata] + assert hooks.end_metadata == [metadata]