Skip to content

Commit 6cf4f7e

Browse files
authored
fix(tool/decorator): validate ToolContext parameter name and raise clear error (#1028)
1 parent c3e5f6b commit 6cf4f7e

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

src/strands/tools/decorator.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) -
9999
self.type_hints = get_type_hints(func)
100100
self._context_param = context_param
101101

102+
self._validate_signature()
103+
102104
# Parse the docstring with docstring_parser
103105
doc_str = inspect.getdoc(func) or ""
104106
self.doc = docstring_parser.parse(doc_str)
@@ -111,6 +113,20 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) -
111113
# Create a Pydantic model for validation
112114
self.input_model = self._create_input_model()
113115

116+
def _validate_signature(self) -> None:
117+
"""Verify that ToolContext is used correctly in the function signature."""
118+
for param in self.signature.parameters.values():
119+
if param.annotation is ToolContext:
120+
if self._context_param is None:
121+
raise ValueError("@tool(context) must be set if passing in ToolContext param")
122+
123+
if param.name != self._context_param:
124+
raise ValueError(
125+
f"param_name=<{param.name}> | ToolContext param must be named '{self._context_param}'"
126+
)
127+
# Found the parameter, no need to check further
128+
break
129+
114130
def _create_input_model(self) -> Type[BaseModel]:
115131
"""Create a Pydantic model from function signature for input validation.
116132

tests/strands/tools/test_decorator.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,3 +1363,27 @@ async def async_generator() -> AsyncGenerator:
13631363
]
13641364

13651365
assert act_results == exp_results
1366+
1367+
1368+
def test_function_tool_metadata_validate_signature_default_context_name_mismatch():
1369+
with pytest.raises(ValueError, match=r"param_name=<context> | ToolContext param must be named 'tool_context'"):
1370+
1371+
@strands.tool(context=True)
1372+
def my_tool(context: ToolContext):
1373+
pass
1374+
1375+
1376+
def test_function_tool_metadata_validate_signature_custom_context_name_mismatch():
1377+
with pytest.raises(ValueError, match=r"param_name=<tool_context> | ToolContext param must be named 'my_context'"):
1378+
1379+
@strands.tool(context="my_context")
1380+
def my_tool(tool_context: ToolContext):
1381+
pass
1382+
1383+
1384+
def test_function_tool_metadata_validate_signature_missing_context_config():
1385+
with pytest.raises(ValueError, match=r"@tool\(context\) must be set if passing in ToolContext param"):
1386+
1387+
@strands.tool
1388+
def my_tool(tool_context: ToolContext):
1389+
pass

0 commit comments

Comments
 (0)