@@ -61,7 +61,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
6161from pydantic import BaseModel , Field , create_model
6262from typing_extensions import override
6363
64- from ..types .tools import AgentTool , JSONSchema , ToolGenerator , ToolSpec , ToolUse
64+ from ..types .tools import AgentTool , JSONSchema , ToolContext , ToolGenerator , ToolSpec , ToolUse
6565
6666logger = logging .getLogger (__name__ )
6767
@@ -84,16 +84,18 @@ class FunctionToolMetadata:
8484 validate tool usage.
8585 """
8686
87- def __init__ (self , func : Callable [..., Any ]) -> None :
87+ def __init__ (self , func : Callable [..., Any ], context_param : str | None = None ) -> None :
8888 """Initialize with the function to process.
8989
9090 Args:
9191 func: The function to extract metadata from.
9292 Can be a standalone function or a class method.
93+ context_param: Name of the context parameter to inject, if any.
9394 """
9495 self .func = func
9596 self .signature = inspect .signature (func )
9697 self .type_hints = get_type_hints (func )
98+ self ._context_param = context_param
9799
98100 # Parse the docstring with docstring_parser
99101 doc_str = inspect .getdoc (func ) or ""
@@ -113,16 +115,16 @@ def _create_input_model(self) -> Type[BaseModel]:
113115 This method analyzes the function's signature, type hints, and docstring to create a Pydantic model that can
114116 validate input data before passing it to the function.
115117
116- Special parameters like 'self', 'cls', and 'agent' are excluded from the model.
118+ Special parameters that can be automatically injected are excluded from the model.
117119
118120 Returns:
119121 A Pydantic BaseModel class customized for the function's parameters.
120122 """
121123 field_definitions : dict [str , Any ] = {}
122124
123125 for name , param in self .signature .parameters .items ():
124- # Skip special parameters
125- if name in ( " self" , "cls" , "agent" ):
126+ # Skip parameters that will be automatically injected
127+ if self . _is_special_parameter ( name ):
126128 continue
127129
128130 # Get parameter type and default
@@ -252,6 +254,49 @@ def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]:
252254 error_msg = str (e )
253255 raise ValueError (f"Validation failed for input parameters: { error_msg } " ) from e
254256
257+ def inject_special_parameters (
258+ self , validated_input : dict [str , Any ], tool_use : ToolUse , invocation_state : dict [str , Any ]
259+ ) -> None :
260+ """Inject special framework-provided parameters into the validated input.
261+
262+ This method automatically provides framework-level context to tools that request it
263+ through their function signature.
264+
265+ Args:
266+ validated_input: The validated input parameters (modified in place).
267+ tool_use: The tool use request containing tool invocation details.
268+ invocation_state: Context for the tool invocation, including agent state.
269+ """
270+ if self ._context_param and self ._context_param in self .signature .parameters :
271+ tool_context = ToolContext (tool_use = tool_use , agent = invocation_state ["agent" ])
272+ validated_input [self ._context_param ] = tool_context
273+
274+ # Inject agent if requested (backward compatibility)
275+ if "agent" in self .signature .parameters and "agent" in invocation_state :
276+ validated_input ["agent" ] = invocation_state ["agent" ]
277+
278+ def _is_special_parameter (self , param_name : str ) -> bool :
279+ """Check if a parameter should be automatically injected by the framework or is a standard Python method param.
280+
281+ Special parameters include:
282+ - Standard Python method parameters: self, cls
283+ - Framework-provided context parameters: agent, and configurable context parameter (defaults to tool_context)
284+
285+ Args:
286+ param_name: The name of the parameter to check.
287+
288+ Returns:
289+ True if the parameter should be excluded from input validation and
290+ handled specially during tool execution.
291+ """
292+ special_params = {"self" , "cls" , "agent" }
293+
294+ # Add context parameter if configured
295+ if self ._context_param :
296+ special_params .add (self ._context_param )
297+
298+ return param_name in special_params
299+
255300
256301P = ParamSpec ("P" ) # Captures all parameters
257302R = TypeVar ("R" ) # Return type
@@ -402,9 +447,8 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
402447 # Validate input against the Pydantic model
403448 validated_input = self ._metadata .validate_input (tool_input )
404449
405- # Pass along the agent if provided and expected by the function
406- if "agent" in invocation_state and "agent" in self ._metadata .signature .parameters :
407- validated_input ["agent" ] = invocation_state .get ("agent" )
450+ # Inject special framework-provided parameters
451+ self ._metadata .inject_special_parameters (validated_input , tool_use , invocation_state )
408452
409453 # "Too few arguments" expected, hence the type ignore
410454 if inspect .iscoroutinefunction (self ._tool_func ):
@@ -474,6 +518,7 @@ def tool(
474518 description : Optional [str ] = None ,
475519 inputSchema : Optional [JSONSchema ] = None ,
476520 name : Optional [str ] = None ,
521+ context : bool | str = False ,
477522) -> Callable [[Callable [P , R ]], DecoratedFunctionTool [P , R ]]: ...
478523# Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the
479524# call site, but the actual implementation handles that and it's not representable via the type-system
@@ -482,6 +527,7 @@ def tool( # type: ignore
482527 description : Optional [str ] = None ,
483528 inputSchema : Optional [JSONSchema ] = None ,
484529 name : Optional [str ] = None ,
530+ context : bool | str = False ,
485531) -> Union [DecoratedFunctionTool [P , R ], Callable [[Callable [P , R ]], DecoratedFunctionTool [P , R ]]]:
486532 """Decorator that transforms a Python function into a Strands tool.
487533
@@ -507,6 +553,9 @@ def tool( # type: ignore
507553 description: Optional custom description to override the function's docstring.
508554 inputSchema: Optional custom JSON schema to override the automatically generated schema.
509555 name: Optional custom name to override the function's name.
556+ context: When provided, places an object in the designated parameter. If True, the param name
557+ defaults to 'tool_context', or if an override is needed, set context equal to a string to designate
558+ the param name.
510559
511560 Returns:
512561 An AgentTool that also mimics the original function when invoked
@@ -536,15 +585,24 @@ def my_tool(name: str, count: int = 1) -> str:
536585
537586 Example with parameters:
538587 ```python
539- @tool(name="custom_tool", description="A tool with a custom name and description")
540- def my_tool(name: str, count: int = 1) -> str:
541- return f"Processed {name} {count} times"
588+ @tool(name="custom_tool", description="A tool with a custom name and description", context=True)
589+ def my_tool(name: str, count: int = 1, tool_context: ToolContext) -> str:
590+ tool_id = tool_context["tool_use"]["toolUseId"]
591+ return f"Processed {name} {count} times with tool ID {tool_id}"
542592 ```
543593 """
544594
545595 def decorator (f : T ) -> "DecoratedFunctionTool[P, R]" :
596+ # Resolve context parameter name
597+ if isinstance (context , bool ):
598+ context_param = "tool_context" if context else None
599+ else :
600+ context_param = context .strip ()
601+ if not context_param :
602+ raise ValueError ("Context parameter name cannot be empty" )
603+
546604 # Create function tool metadata
547- tool_meta = FunctionToolMetadata (f )
605+ tool_meta = FunctionToolMetadata (f , context_param )
548606 tool_spec = tool_meta .extract_metadata ()
549607 if name is not None :
550608 tool_spec ["name" ] = name
0 commit comments