diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 5c49f4b58..009cf76d7 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -45,6 +45,7 @@ def my_tool(param1: str, param2: int = 42) -> dict: import inspect import logging from typing import ( + Annotated, Any, Callable, Generic, @@ -54,12 +55,14 @@ def my_tool(param1: str, param2: int = 42) -> dict: TypeVar, Union, cast, - get_type_hints, + get_args, + get_origin, overload, ) import docstring_parser from pydantic import BaseModel, Field, create_model +from pydantic.fields import FieldInfo from typing_extensions import override from ..interrupt import InterruptException @@ -97,23 +100,74 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) - """ self.func = func self.signature = inspect.signature(func) - self.type_hints = get_type_hints(func) self._context_param = context_param self._validate_signature() - # Parse the docstring with docstring_parser + # Parse the docstring once for all parameters doc_str = inspect.getdoc(func) or "" self.doc = docstring_parser.parse(doc_str) - - # Get parameter descriptions from parsed docstring - self.param_descriptions = { - param.arg_name: param.description or f"Parameter {param.arg_name}" for param in self.doc.params - } + self.param_descriptions = {param.arg_name: param.description for param in self.doc.params if param.description} # Create a Pydantic model for validation self.input_model = self._create_input_model() + def _extract_annotated_metadata( + self, annotation: Any, param_name: str, param_default: Any + ) -> tuple[Any, FieldInfo]: + """Extracts type and a simple string description from an Annotated type hint. + + Returns: + A tuple of (actual_type, field_info), where field_info is a new, simple + Pydantic FieldInfo instance created from the extracted metadata. + """ + actual_type = annotation + description: str | None = None + + if get_origin(annotation) is Annotated: + args = get_args(annotation) + actual_type = args[0] + + # Look through metadata for a string description or a FieldInfo object + for meta in args[1:]: + if isinstance(meta, str): + description = meta + elif isinstance(meta, FieldInfo): + # --- Future Contributor Note --- + # We are explicitly blocking the use of `pydantic.Field` within `Annotated` + # because of the complexities of Pydantic v2's immutable Core Schema. + # + # Once a Pydantic model's schema is built, its `FieldInfo` objects are + # effectively frozen. Attempts to mutate a `FieldInfo` object after + # creation (e.g., by copying it and setting `.description` or `.default`) + # are unreliable because the underlying Core Schema does not see these changes. + # + # The correct way to support this would be to reliably extract all + # constraints (ge, le, pattern, etc.) from the original FieldInfo and + # rebuild a new one from scratch. However, these constraints are not + # stored as public attributes, making them difficult to inspect reliably. + # + # Deferring this complexity until there is clear demand and a robust + # pattern for inspecting FieldInfo constraints is established. + raise NotImplementedError( + "Using pydantic.Field within Annotated is not yet supported for tool decorators. " + "Please use a simple string for the description, or define constraints in the function's " + "docstring." + ) + + # Determine the final description with a clear priority order + # Priority: 1. Annotated string -> 2. Docstring -> 3. Fallback + final_description = description + if final_description is None: + final_description = self.param_descriptions.get(param_name) + if final_description is None: + final_description = f"Parameter {param_name}" + + # Create FieldInfo object from scratch + final_field = Field(default=param_default, description=final_description) + + return actual_type, final_field + def _validate_signature(self) -> None: """Verify that ToolContext is used correctly in the function signature.""" for param in self.signature.parameters.values(): @@ -142,26 +196,25 @@ def _create_input_model(self) -> Type[BaseModel]: field_definitions: dict[str, Any] = {} for name, param in self.signature.parameters.items(): - # Skip parameters that will be automatically injected if self._is_special_parameter(name): continue - # Get parameter type and default - param_type = self.type_hints.get(name, Any) + # Use param.annotation directly to get the raw type hint. Using get_type_hints() + # can cause inconsistent behavior across Python versions for complex Annotated types. + param_type = param.annotation + if param_type is inspect.Parameter.empty: + param_type = Any + default = ... if param.default is inspect.Parameter.empty else param.default - description = self.param_descriptions.get(name, f"Parameter {name}") - # Create Field with description and default - field_definitions[name] = (param_type, Field(default=default, description=description)) + actual_type, field_info = self._extract_annotated_metadata(param_type, name, default) + field_definitions[name] = (actual_type, field_info) - # Create model name based on function name model_name = f"{self.func.__name__.capitalize()}Tool" - # Create and return the model if field_definitions: return create_model(model_name, **field_definitions) else: - # Handle case with no parameters return create_model(model_name) def extract_metadata(self) -> ToolSpec: diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 25f9bc39e..7733d27e5 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -3,10 +3,11 @@ """ from asyncio import Queue -from typing import Any, AsyncGenerator, Dict, Optional, Union +from typing import Annotated, Any, AsyncGenerator, Dict, List, Optional, Union from unittest.mock import MagicMock import pytest +from pydantic import Field import strands from strands import Agent @@ -1450,3 +1451,214 @@ def test_function_tool_metadata_validate_signature_missing_context_config(): @strands.tool def my_tool(tool_context: ToolContext): pass + + +def test_tool_decorator_annotated_string_description(): + """Test tool decorator with Annotated type hints for descriptions.""" + + @strands.tool + def annotated_tool( + name: Annotated[str, "The user's full name"], + age: Annotated[int, "The user's age in years"], + city: str, # No annotation - should use docstring or generic + ) -> str: + """Tool with annotated parameters. + + Args: + city: The user's city (from docstring) + """ + return f"{name}, {age}, {city}" + + spec = annotated_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # Check that annotated descriptions are used + assert schema["properties"]["name"]["description"] == "The user's full name" + assert schema["properties"]["age"]["description"] == "The user's age in years" + + # Check that docstring is still used for non-annotated params + assert schema["properties"]["city"]["description"] == "The user's city (from docstring)" + + # Verify all are required + assert set(schema["required"]) == {"name", "age", "city"} + + +def test_tool_decorator_annotated_pydantic_field_constraints(): + """Test that using pydantic.Field in Annotated raises a NotImplementedError.""" + with pytest.raises(NotImplementedError, match="Using pydantic.Field within Annotated is not yet supported"): + + @strands.tool + def field_annotated_tool( + email: Annotated[str, Field(description="User's email address", pattern=r"^[\w\.-]+@[\w\.-]+\\.w+$")], + score: Annotated[int, Field(description="Score between 0-100", ge=0, le=100)] = 50, + ) -> str: + """Tool with Pydantic Field annotations.""" + return f"{email}: {score}" + + +def test_tool_decorator_annotated_overrides_docstring(): + """Test that Annotated descriptions override docstring descriptions.""" + + @strands.tool + def override_tool(param: Annotated[str, "Description from annotation"]) -> str: + """Tool with both annotation and docstring. + + Args: + param: Description from docstring (should be overridden) + """ + return param + + spec = override_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # Annotated description should win + assert schema["properties"]["param"]["description"] == "Description from annotation" + + +def test_tool_decorator_annotated_optional_type(): + """Test tool with Optional types in Annotated.""" + + @strands.tool + def optional_annotated_tool( + required: Annotated[str, "Required parameter"], optional: Annotated[Optional[str], "Optional parameter"] = None + ) -> str: + """Tool with optional annotated parameter.""" + return f"{required}, {optional}" + + spec = optional_annotated_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # Check descriptions + assert schema["properties"]["required"]["description"] == "Required parameter" + assert schema["properties"]["optional"]["description"] == "Optional parameter" + + # Check required list + assert "required" in schema["required"] + assert "optional" not in schema["required"] + + +def test_tool_decorator_annotated_complex_types(): + """Test tool with complex types in Annotated.""" + + @strands.tool + def complex_annotated_tool( + tags: Annotated[List[str], "List of tag strings"], config: Annotated[Dict[str, Any], "Configuration dictionary"] + ) -> str: + """Tool with complex annotated types.""" + return f"Tags: {len(tags)}, Config: {len(config)}" + + spec = complex_annotated_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # Check descriptions + assert schema["properties"]["tags"]["description"] == "List of tag strings" + assert schema["properties"]["config"]["description"] == "Configuration dictionary" + + # Check types are preserved + assert schema["properties"]["tags"]["type"] == "array" + assert schema["properties"]["config"]["type"] == "object" + + +def test_tool_decorator_annotated_mixed_styles(): + """Test that using pydantic.Field in a mixed-style annotation raises NotImplementedError.""" + with pytest.raises(NotImplementedError, match="Using pydantic.Field within Annotated is not yet supported"): + + @strands.tool + def mixed_tool( + plain: str, + annotated_str: Annotated[str, "String description"], + annotated_field: Annotated[int, Field(description="Field description", ge=0)], + docstring_only: int, + ) -> str: + """Tool with mixed parameter styles. + + Args: + plain: Plain parameter description + docstring_only: Docstring description for this param + """ + return "mixed" + + +@pytest.mark.asyncio +async def test_tool_decorator_annotated_execution(alist): + """Test that annotated tools execute correctly.""" + + @strands.tool + def execution_test(name: Annotated[str, "User name"], count: Annotated[int, "Number of times"] = 1) -> str: + """Test execution with annotations.""" + return f"Hello {name} " * count + + # Test tool use + tool_use = {"toolUseId": "test-id", "input": {"name": "Alice", "count": 2}} + stream = execution_test.stream(tool_use, {}) + + result = (await alist(stream))[-1] + assert result["tool_result"]["status"] == "success" + assert "Hello Alice Hello Alice" in result["tool_result"]["content"][0]["text"] + + # Test direct call + direct_result = execution_test("Bob", 3) + assert direct_result == "Hello Bob Hello Bob Hello Bob " + + +def test_tool_decorator_annotated_no_description_fallback(): + """Test that Annotated with a Field raises NotImplementedError.""" + with pytest.raises(NotImplementedError, match="Using pydantic.Field within Annotated is not yet supported"): + + @strands.tool + def no_desc_annotated( + param: Annotated[str, Field()], # Field without description + ) -> str: + """Tool with Annotated but no description. + + Args: + param: Docstring description + """ + return param + + +def test_tool_decorator_annotated_empty_string_description(): + """Test handling of empty string descriptions in Annotated.""" + + @strands.tool + def empty_desc_tool( + param: Annotated[str, ""], # Empty string description + ) -> str: + """Tool with empty annotation description. + + Args: + param: Docstring description + """ + return param + + spec = empty_desc_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # Empty string is still a valid description, should not fall back + assert schema["properties"]["param"]["description"] == "" + + +@pytest.mark.asyncio +async def test_tool_decorator_annotated_validation_error(alist): + """Test that validation works correctly with annotated parameters.""" + + @strands.tool + def validation_tool(age: Annotated[int, "User age"]) -> str: + """Tool for validation testing.""" + return f"Age: {age}" + + # Test with wrong type + tool_use = {"toolUseId": "test-id", "input": {"age": "not an int"}} + stream = validation_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + assert result["tool_result"]["status"] == "error" + + +def test_tool_decorator_annotated_field_with_inner_default(): + """Test that a default value in an Annotated Field raises NotImplementedError.""" + with pytest.raises(NotImplementedError, match="Using pydantic.Field within Annotated is not yet supported"): + + @strands.tool + def inner_default_tool(name: str, level: Annotated[int, Field(description="A level value", default=10)]) -> str: + return f"{name} is at level {level}"