Skip to content

Commit bdea552

Browse files
committed
feat(tools): refactor Annotated metadata extraction in decorator
1 parent 8637653 commit bdea552

File tree

2 files changed

+65
-55
lines changed

2 files changed

+65
-55
lines changed

src/strands/tools/decorator.py

Lines changed: 48 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -102,48 +102,73 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) -
102102
"""
103103
self.func = func
104104
self.signature = inspect.signature(func)
105+
# include_extras=True is key for reading Annotated metadata
105106
self.type_hints = get_type_hints(func, include_extras=True)
106107
self._context_param = context_param
107108

108109
self._validate_signature()
109110

110-
# Parse the docstring with docstring_parser
111+
# Parse the docstring once for all parameters
111112
doc_str = inspect.getdoc(func) or ""
112113
self.doc = docstring_parser.parse(doc_str)
113-
114-
# Get parameter descriptions from parsed docstring
115-
self.param_descriptions = {
116-
param.arg_name: param.description or f"Parameter {param.arg_name}" for param in self.doc.params
117-
}
114+
self.param_descriptions = {param.arg_name: param.description for param in self.doc.params if param.description}
118115

119116
# Create a Pydantic model for validation
120117
self.input_model = self._create_input_model()
121118

122-
def _extract_annotated_metadata(self, annotation: Any) -> tuple[Any, Optional[Any]]:
123-
"""Extract type and metadata from Annotated type hint.
119+
def _extract_annotated_metadata(
120+
self, annotation: Any, param_name: str, param_default: Any
121+
) -> tuple[Any, FieldInfo]:
122+
"""Extract type and create FieldInfo from Annotated type hint.
124123
125124
Returns:
126-
(actual_type, metadata) where metadata is either:
127-
- a string description
128-
- a pydantic.fields.FieldInfo instance (from Field(...))
129-
- None if no Annotated extras were found
125+
(actual_type, field_info) where field_info is always a FieldInfo instance
130126
"""
127+
actual_type = annotation
128+
field_info: FieldInfo | None = None
129+
description: str | None = None
130+
131131
if get_origin(annotation) is Annotated:
132132
args = get_args(annotation)
133-
actual_type = args[0] # Keep the type as-is (including Optional[T])
133+
actual_type = args[0]
134134

135-
# Look through metadata for description
135+
# Look through metadata for FieldInfo and string descriptions
136136
for meta in args[1:]:
137-
if isinstance(meta, str):
138-
return actual_type, meta
139137
if isinstance(meta, FieldInfo):
140-
return actual_type, meta
138+
field_info = meta
139+
elif isinstance(meta, str):
140+
description = meta
141+
142+
# Determine Final Description
143+
# Priority: 1. Annotated string, 2. FieldInfo description, 3. Docstring
144+
final_description = description
145+
146+
# An empty string is a valid description; only fall back if no description was found in the annotation.
147+
if final_description is None:
148+
if field_info and field_info.description:
149+
final_description = field_info.description
150+
else:
151+
final_description = self.param_descriptions.get(param_name)
152+
153+
# Final fallback if no description was found anywhere
154+
if final_description is None:
155+
final_description = f"Parameter {param_name}"
156+
157+
# Create Final FieldInfo
158+
if field_info:
159+
# If a Field was in Annotated, use it as the base
160+
final_field = copy(field_info)
161+
else:
162+
# Otherwise, create a new default Field
163+
final_field = Field()
141164

142-
# Annotated but no useful metadata
143-
return actual_type, None
165+
final_field.description = final_description
144166

145-
# Not annotated
146-
return annotation, None
167+
# Override default from function signature if present
168+
if param_default is not ...:
169+
final_field.default = param_default
170+
171+
return actual_type, final_field
147172

148173
def _validate_signature(self) -> None:
149174
"""Verify that ToolContext is used correctly in the function signature."""
@@ -173,51 +198,20 @@ def _create_input_model(self) -> Type[BaseModel]:
173198
field_definitions: dict[str, Any] = {}
174199

175200
for name, param in self.signature.parameters.items():
176-
# Skip parameters that will be automatically injected
177201
if self._is_special_parameter(name):
178202
continue
179203

180-
# Get parameter type hint and any Annotated metadata
181204
param_type = self.type_hints.get(name, Any)
182-
actual_type, annotated_meta = self._extract_annotated_metadata(param_type)
183-
184-
# Determine parameter default value
185205
default = ... if param.default is inspect.Parameter.empty else param.default
186206

187-
# Determine description (priority: Annotated > docstring > generic)
188-
description: str
189-
if isinstance(annotated_meta, str):
190-
description = annotated_meta
191-
elif isinstance(annotated_meta, FieldInfo) and annotated_meta.description is not None:
192-
description = annotated_meta.description
193-
elif name in self.param_descriptions:
194-
description = self.param_descriptions[name]
195-
else:
196-
description = f"Parameter {name}"
197-
198-
# Create Field definition for create_model
199-
if isinstance(annotated_meta, FieldInfo):
200-
# Create a defensive copy to avoid mutating a shared FieldInfo instance.
201-
field_info_copy = copy(annotated_meta)
202-
field_info_copy.description = description
203-
204-
# Update default if specified in the function signature.
205-
if default is not ...:
206-
field_info_copy.default = default
207-
208-
field_definitions[name] = (actual_type, field_info_copy)
209-
else:
210-
# For non-FieldInfo metadata, create a new Field.
211-
field_definitions[name] = (actual_type, Field(default=default, description=description))
207+
actual_type, field_info = self._extract_annotated_metadata(param_type, name, default)
208+
field_definitions[name] = (actual_type, field_info)
212209

213-
# Create model name based on function name
214210
model_name = f"{self.func.__name__.capitalize()}Tool"
215211

216-
# Create and return the model
217212
if field_definitions:
218213
return create_model(model_name, **field_definitions)
219214
else:
220-
# Handle case with no parameters
221215
return create_model(model_name)
222216

223217
def extract_metadata(self) -> ToolSpec:

tests/strands/tools/test_decorator.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1488,7 +1488,7 @@ def test_tool_decorator_annotated_pydantic_field_constraints():
14881488

14891489
@strands.tool
14901490
def field_annotated_tool(
1491-
email: Annotated[str, Field(description="User's email address", pattern=r"^[\w\.-]+@[\w\.-]+\.\w+$")],
1491+
email: Annotated[str, Field(description="User's email address", pattern=r"^[\w\.-]+@[\w\.-]+\\.\w+$")],
14921492
score: Annotated[int, Field(description="Score between 0-100", ge=0, le=100)] = 50,
14931493
) -> str:
14941494
"""Tool with Pydantic Field annotations."""
@@ -1680,3 +1680,19 @@ def validation_tool(age: Annotated[int, "User age"]) -> str:
16801680

16811681
result = (await alist(stream))[-1]
16821682
assert result["tool_result"]["status"] == "error"
1683+
1684+
1685+
def test_tool_decorator_annotated_field_with_inner_default():
1686+
"""Test that a default value in an Annotated Field is respected."""
1687+
1688+
@strands.tool
1689+
def inner_default_tool(name: str, level: Annotated[int, Field(description="A level value", default=10)]) -> str:
1690+
return f"{name} is at level {level}"
1691+
1692+
spec = inner_default_tool.tool_spec
1693+
schema = spec["inputSchema"]["json"]
1694+
1695+
# 'level' should not be required because its Field has a default
1696+
assert "name" in schema["required"]
1697+
assert "level" not in schema["required"]
1698+
assert schema["properties"]["level"]["default"] == 10

0 commit comments

Comments
 (0)