Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/agents/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ class FunctionTool:
tool_output_guardrails: list[ToolOutputGuardrail[Any]] | None = None
"""Optional list of output guardrails to run after invoking this tool."""

custom_metadata: dict[str, Any] | None = None
"""Optional custom metadata to attach to this tool."""

def __post_init__(self):
if self.strict_json_schema:
self.params_json_schema = ensure_strict_json_schema(self.params_json_schema)
Expand All @@ -211,6 +214,9 @@ class FileSearchTool:
filters: Filters | None = None
"""A filter to apply based on file attributes."""

custom_metadata: dict[str, Any] | None = None
"""Optional custom metadata to attach to this tool."""

@property
def name(self):
return "file_search"
Expand All @@ -231,6 +237,9 @@ class WebSearchTool:
search_context_size: Literal["low", "medium", "high"] = "medium"
"""The amount of context to use for the search."""

custom_metadata: dict[str, Any] | None = None
"""Optional custom metadata to attach to this tool."""

@property
def name(self):
return "web_search"
Expand All @@ -248,6 +257,9 @@ class ComputerTool:
on_safety_check: Callable[[ComputerToolSafetyCheckData], MaybeAwaitable[bool]] | None = None
"""Optional callback to acknowledge computer tool safety checks."""

custom_metadata: dict[str, Any] | None = None
"""Optional custom metadata to attach to this tool."""

@property
def name(self):
return "computer_use_preview"
Expand Down Expand Up @@ -313,6 +325,9 @@ class HostedMCPTool:
provided, you will need to manually add approvals/rejections to the input and call
`Runner.run(...)` again."""

custom_metadata: dict[str, Any] | None = None
"""Optional custom metadata to attach to this tool."""

@property
def name(self):
return "hosted_mcp"
Expand All @@ -325,6 +340,9 @@ class CodeInterpreterTool:
tool_config: CodeInterpreter
"""The tool config, which includes the container and other settings."""

custom_metadata: dict[str, Any] | None = None
"""Optional custom metadata to attach to this tool."""

@property
def name(self):
return "code_interpreter"
Expand All @@ -337,6 +355,9 @@ class ImageGenerationTool:
tool_config: ImageGeneration
"""The tool config, which image generation settings."""

custom_metadata: dict[str, Any] | None = None
"""Optional custom metadata to attach to this tool."""

@property
def name(self):
return "image_generation"
Expand Down Expand Up @@ -368,6 +389,9 @@ class LocalShellTool:
executor: LocalShellExecutor
"""A function that executes a command on a shell."""

custom_metadata: dict[str, Any] | None = None
"""Optional custom metadata to attach to this tool."""

@property
def name(self):
return "local_shell"
Expand Down
218 changes: 218 additions & 0 deletions tests/test_tool_custom_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
from __future__ import annotations

from typing import Any, Callable, Dict, List, cast

import pytest

from agents.agent import Agent
from agents.computer import Computer
from agents.lifecycle import AgentHooks
from agents.run import Runner
from agents.tool import (
CodeInterpreterTool,
ComputerTool,
FileSearchTool,
FunctionTool,
HostedMCPTool,
ImageGenerationTool,
LocalShellCommandRequest,
LocalShellTool,
Tool,
WebSearchTool,
)
from tests.fake_model import FakeModel
from tests.test_responses import get_function_tool, get_function_tool_call, get_text_message


async def _noop_invoke(context: Any, params_json: str) -> str:
return "ok"


class _DummyComputer(Computer):
@property
def environment(self) -> str:
return "windows"

@property
def dimensions(self) -> tuple[int, int]:
return (800, 600)

def screenshot(self) -> str:
return ""

def click(self, x: int, y: int, button: str) -> None:
return None

def double_click(self, x: int, y: int) -> None:
return None

def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
return None

def type(self, text: str) -> None:
return None

def wait(self) -> None:
return None

def move(self, x: int, y: int) -> None:
return None

def keypress(self, keys: List[str]) -> None:
return None

def drag(self, path: List[tuple[int, int]]) -> None:
return None


def _make_function_tool(with_metadata: bool) -> FunctionTool:
kwargs: Dict[str, Any] = {}
if with_metadata:
kwargs["custom_metadata"] = {"key": "value"}

return FunctionTool(
name="func",
description="desc",
params_json_schema={"type": "object", "properties": {}},
on_invoke_tool=_noop_invoke,
**kwargs,
)


def _make_file_search_tool(with_metadata: bool) -> FileSearchTool:
kwargs: Dict[str, Any] = {}
if with_metadata:
kwargs["custom_metadata"] = {"key": "value"}
return FileSearchTool(vector_store_ids=["vs"], **kwargs)


def _make_web_search_tool(with_metadata: bool) -> WebSearchTool:
kwargs: Dict[str, Any] = {}
if with_metadata:
kwargs["custom_metadata"] = {"key": "value"}
return WebSearchTool(**kwargs)


def _make_computer_tool(with_metadata: bool) -> ComputerTool:
kwargs: Dict[str, Any] = {}
if with_metadata:
kwargs["custom_metadata"] = {"key": "value"}
return ComputerTool(computer=_DummyComputer(), **kwargs)


def _make_hosted_mcp_tool(with_metadata: bool) -> HostedMCPTool:
kwargs: Dict[str, Any] = {}
if with_metadata:
kwargs["custom_metadata"] = {"key": "value"}
tool_config = cast(Any, {"server_url": "https://example.com"})
return HostedMCPTool(tool_config=tool_config, **kwargs)


def _make_code_interpreter_tool(with_metadata: bool) -> CodeInterpreterTool:
kwargs: Dict[str, Any] = {}
if with_metadata:
kwargs["custom_metadata"] = {"key": "value"}
tool_config = cast(Any, {"runtime": "python"})
return CodeInterpreterTool(tool_config=tool_config, **kwargs)


def _make_image_generation_tool(with_metadata: bool) -> ImageGenerationTool:
kwargs: Dict[str, Any] = {}
if with_metadata:
kwargs["custom_metadata"] = {"key": "value"}
tool_config = cast(Any, {"model": "image"})
return ImageGenerationTool(tool_config=tool_config, **kwargs)


def _make_local_shell_tool(with_metadata: bool) -> LocalShellTool:
kwargs: Dict[str, Any] = {}
if with_metadata:
kwargs["custom_metadata"] = {"key": "value"}

def _executor(request: LocalShellCommandRequest) -> str:
return "executed"

return LocalShellTool(executor=_executor, **kwargs)


@pytest.mark.parametrize(
"factory",
[
_make_function_tool,
_make_file_search_tool,
_make_web_search_tool,
_make_computer_tool,
_make_hosted_mcp_tool,
_make_code_interpreter_tool,
_make_image_generation_tool,
_make_local_shell_tool,
],
)
def test_custom_metadata_defaults_to_none(factory: Callable[[bool], Any]) -> None:
tool = factory(False)
assert tool.custom_metadata is None


@pytest.mark.parametrize(
"factory",
[
_make_function_tool,
_make_file_search_tool,
_make_web_search_tool,
_make_computer_tool,
_make_hosted_mcp_tool,
_make_code_interpreter_tool,
_make_image_generation_tool,
_make_local_shell_tool,
],
)
def test_custom_metadata_can_be_provided(factory: Callable[[bool], Any]) -> None:
tool = factory(True)
assert tool.custom_metadata == {"key": "value"}


class _MetadataCapturingHooks(AgentHooks):
def __init__(self) -> None:
self.start_metadata: list[dict[str, Any] | None] = []
self.end_metadata: list[dict[str, Any] | None] = []

async def on_tool_start(
self,
context: Any,
agent: Agent[Any],
tool: Tool,
) -> None:
self.start_metadata.append(tool.custom_metadata)

async def on_tool_end(
self,
context: Any,
agent: Agent[Any],
tool: Tool,
result: str,
) -> None:
self.end_metadata.append(tool.custom_metadata)


@pytest.mark.asyncio
async def test_custom_metadata_available_in_hooks() -> None:
hooks = _MetadataCapturingHooks()
fake_model = FakeModel()

tool = get_function_tool("custom_tool", return_value="tool result")
metadata = {"source": "unit_test"}
tool.custom_metadata = metadata

agent = Agent(name="metadata_agent", model=fake_model, tools=[tool], hooks=hooks)

fake_model.add_multiple_turn_outputs(
[
[get_function_tool_call("custom_tool", "{}")] ,
[get_text_message("Final response")],
]
)

result = await Runner.run(agent, "metadata input")
assert result.final_output == "Final response"
assert hooks.start_metadata == [metadata]
assert hooks.end_metadata == [metadata]