@@ -46,7 +46,6 @@ def my_tool(param1: str, param2: int = 42) -> dict:
4646from typing import (
4747 Any ,
4848 Callable ,
49- Dict ,
5049 Generic ,
5150 Optional ,
5251 ParamSpec ,
@@ -62,7 +61,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
6261from pydantic import BaseModel , Field , create_model
6362from typing_extensions import override
6463
65- from ..types .tools import AgentTool , JSONSchema , ToolResult , ToolSpec , ToolUse
64+ from ..types .tools import AgentTool , JSONSchema , ToolGenerator , ToolResult , ToolSpec , ToolUse
6665
6766logger = logging .getLogger (__name__ )
6867
@@ -119,7 +118,7 @@ def _create_input_model(self) -> Type[BaseModel]:
119118 Returns:
120119 A Pydantic BaseModel class customized for the function's parameters.
121120 """
122- field_definitions : Dict [str , Any ] = {}
121+ field_definitions : dict [str , Any ] = {}
123122
124123 for name , param in self .signature .parameters .items ():
125124 # Skip special parameters
@@ -179,7 +178,7 @@ def extract_metadata(self) -> ToolSpec:
179178
180179 return tool_spec
181180
182- def _clean_pydantic_schema (self , schema : Dict [str , Any ]) -> None :
181+ def _clean_pydantic_schema (self , schema : dict [str , Any ]) -> None :
183182 """Clean up Pydantic schema to match Strands' expected format.
184183
185184 Pydantic's JSON schema output includes several elements that aren't needed for Strands Agent tools and could
@@ -227,7 +226,7 @@ def _clean_pydantic_schema(self, schema: Dict[str, Any]) -> None:
227226 if key in prop_schema :
228227 del prop_schema [key ]
229228
230- def validate_input (self , input_data : Dict [str , Any ]) -> Dict [str , Any ]:
229+ def validate_input (self , input_data : dict [str , Any ]) -> dict [str , Any ]:
231230 """Validate input data using the Pydantic model.
232231
233232 This method ensures that the input data meets the expected schema before it's passed to the actual function. It
@@ -270,32 +269,32 @@ class DecoratedFunctionTool(AgentTool, Generic[P, R]):
270269
271270 _tool_name : str
272271 _tool_spec : ToolSpec
272+ _tool_func : Callable [P , R ]
273273 _metadata : FunctionToolMetadata
274- original_function : Callable [P , R ]
275274
276275 def __init__ (
277276 self ,
278- function : Callable [P , R ],
279277 tool_name : str ,
280278 tool_spec : ToolSpec ,
279+ tool_func : Callable [P , R ],
281280 metadata : FunctionToolMetadata ,
282281 ):
283282 """Initialize the decorated function tool.
284283
285284 Args:
286- function: The original function being decorated.
287285 tool_name: The name to use for the tool (usually the function name).
288286 tool_spec: The tool specification containing metadata for Agent integration.
287+ tool_func: The original function being decorated.
289288 metadata: The FunctionToolMetadata object with extracted function information.
290289 """
291290 super ().__init__ ()
292291
293- self .original_function = function
292+ self ._tool_name = tool_name
294293 self ._tool_spec = tool_spec
294+ self ._tool_func = tool_func
295295 self ._metadata = metadata
296- self ._tool_name = tool_name
297296
298- functools .update_wrapper (wrapper = self , wrapped = self .original_function )
297+ functools .update_wrapper (wrapper = self , wrapped = self ._tool_func )
299298
300299 def __get__ (self , instance : Any , obj_type : Optional [Type ] = None ) -> "DecoratedFunctionTool[P, R]" :
301300 """Descriptor protocol implementation for proper method binding.
@@ -323,12 +322,10 @@ def my_tool():
323322 tool = instance.my_tool
324323 ```
325324 """
326- if instance is not None and not inspect .ismethod (self .original_function ):
325+ if instance is not None and not inspect .ismethod (self ._tool_func ):
327326 # Create a bound method
328- new_callback = self .original_function .__get__ (instance , instance .__class__ )
329- return DecoratedFunctionTool (
330- function = new_callback , tool_name = self .tool_name , tool_spec = self .tool_spec , metadata = self ._metadata
331- )
327+ tool_func = self ._tool_func .__get__ (instance , instance .__class__ )
328+ return DecoratedFunctionTool (self ._tool_name , self ._tool_spec , tool_func , self ._metadata )
332329
333330 return self
334331
@@ -360,7 +357,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
360357
361358 return cast (R , self .invoke (tool_use , ** kwargs ))
362359
363- return self .original_function (* args , ** kwargs )
360+ return self ._tool_func (* args , ** kwargs )
364361
365362 @property
366363 def tool_name (self ) -> str :
@@ -389,10 +386,11 @@ def tool_type(self) -> str:
389386 """
390387 return "function"
391388
392- def invoke (self , tool : ToolUse , * args : Any , ** kwargs : dict [str , Any ]) -> ToolResult :
393- """Invoke the tool with a tool use specification.
389+ @override
390+ def stream (self , tool_use : ToolUse , * args : Any , ** kwargs : dict [str , Any ]) -> ToolGenerator :
391+ """Stream the tool with a tool use specification.
394392
395- This method handles tool use invocations from a Strands Agent. It validates the input,
393+ This method handles tool use streams from a Strands Agent. It validates the input,
396394 calls the function, and formats the result according to the expected tool result format.
397395
398396 Key operations:
@@ -404,15 +402,17 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolRes
404402 5. Handle and format any errors that occur
405403
406404 Args:
407- tool : The tool use specification from the Agent.
405+ tool_use : The tool use specification from the Agent.
408406 *args: Additional positional arguments (not typically used).
409407 **kwargs: Additional keyword arguments, may include 'agent' reference.
410408
409+ Yields:
410+ Events of the tool stream.
411+
411412 Returns:
412413 A standardized tool result dictionary with status and content.
413414 """
414415 # This is a tool use call - process accordingly
415- tool_use = tool
416416 tool_use_id = tool_use .get ("toolUseId" , "unknown" )
417417 tool_input = tool_use .get ("input" , {})
418418
@@ -424,8 +424,9 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolRes
424424 if "agent" in kwargs and "agent" in self ._metadata .signature .parameters :
425425 validated_input ["agent" ] = kwargs .get ("agent" )
426426
427- # We get "too few arguments here" but because that's because fof the way we're calling it
428- result = self .original_function (** validated_input ) # type: ignore
427+ result = self ._tool_func (** validated_input ) # type: ignore # "Too few arguments" expected
428+ if inspect .isgenerator (result ):
429+ result = yield from result
429430
430431 # FORMAT THE RESULT for Strands Agent
431432 if isinstance (result , dict ) and "status" in result and "content" in result :
@@ -476,7 +477,7 @@ def get_display_properties(self) -> dict[str, str]:
476477 Function properties (e.g., function name).
477478 """
478479 properties = super ().get_display_properties ()
479- properties ["Function" ] = self .original_function .__name__
480+ properties ["Function" ] = self ._tool_func .__name__
480481 return properties
481482
482483
@@ -573,7 +574,7 @@ def decorator(f: T) -> "DecoratedFunctionTool[P, R]":
573574 if not isinstance (tool_name , str ):
574575 raise ValueError (f"Tool name must be a string, got { type (tool_name )} " )
575576
576- return DecoratedFunctionTool (function = f , tool_name = tool_name , tool_spec = tool_spec , metadata = tool_meta )
577+ return DecoratedFunctionTool (tool_name , tool_spec , f , tool_meta )
577578
578579 # Handle both @tool and @tool() syntax
579580 if func is None :
0 commit comments