Skip to content

Commit 296853b

Browse files
committed
feat(tools): Add support for Annotated type hints in @tool decorator
1 parent de802fb commit 296853b

File tree

2 files changed

+297
-6
lines changed

2 files changed

+297
-6
lines changed

src/strands/tools/decorator.py

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ def my_tool(param1: str, param2: int = 42) -> dict:
4444
import functools
4545
import inspect
4646
import logging
47+
from copy import copy
4748
from typing import (
49+
Annotated,
4850
Any,
4951
Callable,
5052
Generic,
@@ -54,12 +56,15 @@ def my_tool(param1: str, param2: int = 42) -> dict:
5456
TypeVar,
5557
Union,
5658
cast,
59+
get_args,
60+
get_origin,
5761
get_type_hints,
5862
overload,
5963
)
6064

6165
import docstring_parser
6266
from pydantic import BaseModel, Field, create_model
67+
from pydantic.fields import FieldInfo
6368
from typing_extensions import override
6469

6570
from ..interrupt import InterruptException
@@ -97,7 +102,12 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) -
97102
"""
98103
self.func = func
99104
self.signature = inspect.signature(func)
100-
self.type_hints = get_type_hints(func)
105+
# Preserve Annotated extras when possible (Python 3.9+ / 3.10+ support include_extras)
106+
try:
107+
self.type_hints = get_type_hints(func, include_extras=True)
108+
except TypeError:
109+
# Older Python versions / typing implementations may not accept include_extras
110+
self.type_hints = get_type_hints(func)
101111
self._context_param = context_param
102112

103113
self._validate_signature()
@@ -114,6 +124,32 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) -
114124
# Create a Pydantic model for validation
115125
self.input_model = self._create_input_model()
116126

127+
def _extract_annotated_metadata(self, annotation: Any) -> tuple[Any, Optional[Any]]:
128+
"""Extract type and metadata from Annotated type hint.
129+
130+
Returns:
131+
(actual_type, metadata) where metadata is either:
132+
- a string description
133+
- a pydantic.fields.FieldInfo instance (from Field(...))
134+
- None if no Annotated extras were found
135+
"""
136+
if get_origin(annotation) is Annotated:
137+
args = get_args(annotation)
138+
actual_type = args[0] # Keep the type as-is (including Optional[T])
139+
140+
# Look through metadata for description
141+
for meta in args[1:]:
142+
if isinstance(meta, str):
143+
return actual_type, meta
144+
if isinstance(meta, FieldInfo):
145+
return actual_type, meta
146+
147+
# Annotated but no useful metadata
148+
return actual_type, None
149+
150+
# Not annotated
151+
return annotation, None
152+
117153
def _validate_signature(self) -> None:
118154
"""Verify that ToolContext is used correctly in the function signature."""
119155
for param in self.signature.parameters.values():
@@ -146,13 +182,38 @@ def _create_input_model(self) -> Type[BaseModel]:
146182
if self._is_special_parameter(name):
147183
continue
148184

149-
# Get parameter type and default
185+
# Get parameter type hint and any Annotated metadata
150186
param_type = self.type_hints.get(name, Any)
187+
actual_type, annotated_meta = self._extract_annotated_metadata(param_type)
188+
189+
# Determine parameter default value
151190
default = ... if param.default is inspect.Parameter.empty else param.default
152-
description = self.param_descriptions.get(name, f"Parameter {name}")
153191

154-
# Create Field with description and default
155-
field_definitions[name] = (param_type, Field(default=default, description=description))
192+
# Determine description (priority: Annotated > docstring > generic)
193+
description: str
194+
if isinstance(annotated_meta, str):
195+
description = annotated_meta
196+
elif isinstance(annotated_meta, FieldInfo) and annotated_meta.description is not None:
197+
description = annotated_meta.description
198+
elif name in self.param_descriptions:
199+
description = self.param_descriptions[name]
200+
else:
201+
description = f"Parameter {name}"
202+
203+
# Create Field definition for create_model
204+
if isinstance(annotated_meta, FieldInfo):
205+
# Create a defensive copy to avoid mutating a shared FieldInfo instance.
206+
field_info_copy = copy(annotated_meta)
207+
field_info_copy.description = description
208+
209+
# Update default if specified in the function signature.
210+
if default is not ...:
211+
field_info_copy.default = default
212+
213+
field_definitions[name] = (actual_type, field_info_copy)
214+
else:
215+
# For non-FieldInfo metadata, create a new Field.
216+
field_definitions[name] = (actual_type, Field(default=default, description=description))
156217

157218
# Create model name based on function name
158219
model_name = f"{self.func.__name__.capitalize()}Tool"

tests/strands/tools/test_decorator.py

Lines changed: 231 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
"""
44

55
from asyncio import Queue
6-
from typing import Any, AsyncGenerator, Dict, Optional, Union
6+
from typing import Annotated, Any, AsyncGenerator, Dict, List, Optional, Union
77
from unittest.mock import MagicMock
88

99
import pytest
10+
from pydantic import Field
1011

1112
import strands
1213
from strands import Agent
@@ -1450,3 +1451,232 @@ def test_function_tool_metadata_validate_signature_missing_context_config():
14501451
@strands.tool
14511452
def my_tool(tool_context: ToolContext):
14521453
pass
1454+
1455+
1456+
def test_tool_decorator_annotated_string_description():
1457+
"""Test tool decorator with Annotated type hints for descriptions."""
1458+
1459+
@strands.tool
1460+
def annotated_tool(
1461+
name: Annotated[str, "The user's full name"],
1462+
age: Annotated[int, "The user's age in years"],
1463+
city: str, # No annotation - should use docstring or generic
1464+
) -> str:
1465+
"""Tool with annotated parameters.
1466+
1467+
Args:
1468+
city: The user's city (from docstring)
1469+
"""
1470+
return f"{name}, {age}, {city}"
1471+
1472+
spec = annotated_tool.tool_spec
1473+
schema = spec["inputSchema"]["json"]
1474+
1475+
# Check that annotated descriptions are used
1476+
assert schema["properties"]["name"]["description"] == "The user's full name"
1477+
assert schema["properties"]["age"]["description"] == "The user's age in years"
1478+
1479+
# Check that docstring is still used for non-annotated params
1480+
assert schema["properties"]["city"]["description"] == "The user's city (from docstring)"
1481+
1482+
# Verify all are required
1483+
assert set(schema["required"]) == {"name", "age", "city"}
1484+
1485+
1486+
def test_tool_decorator_annotated_pydantic_field_constraints():
1487+
"""Test tool decorator with Pydantic Field in Annotated."""
1488+
1489+
@strands.tool
1490+
def field_annotated_tool(
1491+
email: Annotated[str, Field(description="User's email address", pattern=r"^[\w\.-]+@[\w\.-]+\.\w+$")],
1492+
score: Annotated[int, Field(description="Score between 0-100", ge=0, le=100)] = 50,
1493+
) -> str:
1494+
"""Tool with Pydantic Field annotations."""
1495+
return f"{email}: {score}"
1496+
1497+
spec = field_annotated_tool.tool_spec
1498+
schema = spec["inputSchema"]["json"]
1499+
1500+
# Check descriptions from Field
1501+
assert schema["properties"]["email"]["description"] == "User's email address"
1502+
assert schema["properties"]["score"]["description"] == "Score between 0-100"
1503+
1504+
# Check that constraints are preserved
1505+
assert schema["properties"]["score"]["minimum"] == 0
1506+
assert schema["properties"]["score"]["maximum"] == 100
1507+
1508+
# Check required fields
1509+
assert "email" in schema["required"]
1510+
assert "score" not in schema["required"] # Has default
1511+
1512+
1513+
def test_tool_decorator_annotated_overrides_docstring():
1514+
"""Test that Annotated descriptions override docstring descriptions."""
1515+
1516+
@strands.tool
1517+
def override_tool(param: Annotated[str, "Description from annotation"]) -> str:
1518+
"""Tool with both annotation and docstring.
1519+
1520+
Args:
1521+
param: Description from docstring (should be overridden)
1522+
"""
1523+
return param
1524+
1525+
spec = override_tool.tool_spec
1526+
schema = spec["inputSchema"]["json"]
1527+
1528+
# Annotated description should win
1529+
assert schema["properties"]["param"]["description"] == "Description from annotation"
1530+
1531+
1532+
def test_tool_decorator_annotated_optional_type():
1533+
"""Test tool with Optional types in Annotated."""
1534+
1535+
@strands.tool
1536+
def optional_annotated_tool(
1537+
required: Annotated[str, "Required parameter"], optional: Annotated[Optional[str], "Optional parameter"] = None
1538+
) -> str:
1539+
"""Tool with optional annotated parameter."""
1540+
return f"{required}, {optional}"
1541+
1542+
spec = optional_annotated_tool.tool_spec
1543+
schema = spec["inputSchema"]["json"]
1544+
1545+
# Check descriptions
1546+
assert schema["properties"]["required"]["description"] == "Required parameter"
1547+
assert schema["properties"]["optional"]["description"] == "Optional parameter"
1548+
1549+
# Check required list
1550+
assert "required" in schema["required"]
1551+
assert "optional" not in schema["required"]
1552+
1553+
1554+
def test_tool_decorator_annotated_complex_types():
1555+
"""Test tool with complex types in Annotated."""
1556+
1557+
@strands.tool
1558+
def complex_annotated_tool(
1559+
tags: Annotated[List[str], "List of tag strings"], config: Annotated[Dict[str, Any], "Configuration dictionary"]
1560+
) -> str:
1561+
"""Tool with complex annotated types."""
1562+
return f"Tags: {len(tags)}, Config: {len(config)}"
1563+
1564+
spec = complex_annotated_tool.tool_spec
1565+
schema = spec["inputSchema"]["json"]
1566+
1567+
# Check descriptions
1568+
assert schema["properties"]["tags"]["description"] == "List of tag strings"
1569+
assert schema["properties"]["config"]["description"] == "Configuration dictionary"
1570+
1571+
# Check types are preserved
1572+
assert schema["properties"]["tags"]["type"] == "array"
1573+
assert schema["properties"]["config"]["type"] == "object"
1574+
1575+
1576+
def test_tool_decorator_annotated_mixed_styles():
1577+
"""Test tool with mixed annotation styles."""
1578+
1579+
@strands.tool
1580+
def mixed_tool(
1581+
plain: str,
1582+
annotated_str: Annotated[str, "String description"],
1583+
annotated_field: Annotated[int, Field(description="Field description", ge=0)],
1584+
docstring_only: int,
1585+
) -> str:
1586+
"""Tool with mixed parameter styles.
1587+
1588+
Args:
1589+
plain: Plain parameter description
1590+
docstring_only: Docstring description for this param
1591+
"""
1592+
return "mixed"
1593+
1594+
spec = mixed_tool.tool_spec
1595+
schema = spec["inputSchema"]["json"]
1596+
1597+
# Check each style works correctly
1598+
assert schema["properties"]["plain"]["description"] == "Plain parameter description"
1599+
assert schema["properties"]["annotated_str"]["description"] == "String description"
1600+
assert schema["properties"]["annotated_field"]["description"] == "Field description"
1601+
assert schema["properties"]["docstring_only"]["description"] == "Docstring description for this param"
1602+
1603+
1604+
@pytest.mark.asyncio
1605+
async def test_tool_decorator_annotated_execution(alist):
1606+
"""Test that annotated tools execute correctly."""
1607+
1608+
@strands.tool
1609+
def execution_test(name: Annotated[str, "User name"], count: Annotated[int, "Number of times"] = 1) -> str:
1610+
"""Test execution with annotations."""
1611+
return f"Hello {name} " * count
1612+
1613+
# Test tool use
1614+
tool_use = {"toolUseId": "test-id", "input": {"name": "Alice", "count": 2}}
1615+
stream = execution_test.stream(tool_use, {})
1616+
1617+
result = (await alist(stream))[-1]
1618+
assert result["tool_result"]["status"] == "success"
1619+
assert "Hello Alice Hello Alice" in result["tool_result"]["content"][0]["text"]
1620+
1621+
# Test direct call
1622+
direct_result = execution_test("Bob", 3)
1623+
assert direct_result == "Hello Bob Hello Bob Hello Bob "
1624+
1625+
1626+
def test_tool_decorator_annotated_no_description_fallback():
1627+
"""Test that Annotated without description falls back to docstring."""
1628+
1629+
@strands.tool
1630+
def no_desc_annotated(
1631+
param: Annotated[str, Field()], # Field without description
1632+
) -> str:
1633+
"""Tool with Annotated but no description.
1634+
1635+
Args:
1636+
param: Docstring description
1637+
"""
1638+
return param
1639+
1640+
spec = no_desc_annotated.tool_spec
1641+
schema = spec["inputSchema"]["json"]
1642+
1643+
# Should fall back to docstring
1644+
assert schema["properties"]["param"]["description"] == "Docstring description"
1645+
1646+
1647+
def test_tool_decorator_annotated_empty_string_description():
1648+
"""Test handling of empty string descriptions in Annotated."""
1649+
1650+
@strands.tool
1651+
def empty_desc_tool(
1652+
param: Annotated[str, ""], # Empty string description
1653+
) -> str:
1654+
"""Tool with empty annotation description.
1655+
1656+
Args:
1657+
param: Docstring description
1658+
"""
1659+
return param
1660+
1661+
spec = empty_desc_tool.tool_spec
1662+
schema = spec["inputSchema"]["json"]
1663+
1664+
# Empty string is still a valid description, should not fall back
1665+
assert schema["properties"]["param"]["description"] == ""
1666+
1667+
1668+
@pytest.mark.asyncio
1669+
async def test_tool_decorator_annotated_validation_error(alist):
1670+
"""Test that validation works correctly with annotated parameters."""
1671+
1672+
@strands.tool
1673+
def validation_tool(age: Annotated[int, "User age"]) -> str:
1674+
"""Tool for validation testing."""
1675+
return f"Age: {age}"
1676+
1677+
# Test with wrong type
1678+
tool_use = {"toolUseId": "test-id", "input": {"age": "not an int"}}
1679+
stream = validation_tool.stream(tool_use, {})
1680+
1681+
result = (await alist(stream))[-1]
1682+
assert result["tool_result"]["status"] == "error"

0 commit comments

Comments
 (0)