Skip to content

Commit 718bf5f

Browse files
zastrowmWorkshop Participant
authored andcommitted
feat: Add hooks for before/after tool calls + allow hooks to update values (#352)
Add the ability to intercept/modify tool calls by implementing support for BeforeToolInvocationEvent & AfterToolInvocationEvent hooks
1 parent 37b8a67 commit 718bf5f

File tree

8 files changed

+631
-53
lines changed

8 files changed

+631
-53
lines changed

src/strands/event_loop/event_loop.py

Lines changed: 75 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import uuid
1414
from typing import TYPE_CHECKING, Any, AsyncGenerator
1515

16+
from ..experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent
17+
from ..experimental.hooks.registry import get_registry
1618
from ..telemetry.metrics import Trace
1719
from ..telemetry.tracer import get_tracer
1820
from ..tools.executor import run_tools, validate_and_prepare_tools
@@ -271,46 +273,97 @@ def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolG
271273
The final tool result or an error response if the tool fails or is not found.
272274
"""
273275
logger.debug("tool_use=<%s> | streaming", tool_use)
274-
tool_use_id = tool_use["toolUseId"]
275276
tool_name = tool_use["name"]
276277

277278
# Get the tool info
278279
tool_info = agent.tool_registry.dynamic_tools.get(tool_name)
279280
tool_func = tool_info if tool_info is not None else agent.tool_registry.registry.get(tool_name)
280281

282+
# Add standard arguments to kwargs for Python tools
283+
kwargs.update(
284+
{
285+
"model": agent.model,
286+
"system_prompt": agent.system_prompt,
287+
"messages": agent.messages,
288+
"tool_config": agent.tool_config,
289+
}
290+
)
291+
292+
before_event = get_registry(agent).invoke_callbacks(
293+
BeforeToolInvocationEvent(
294+
agent=agent,
295+
selected_tool=tool_func,
296+
tool_use=tool_use,
297+
kwargs=kwargs,
298+
)
299+
)
300+
281301
try:
302+
selected_tool = before_event.selected_tool
303+
tool_use = before_event.tool_use
304+
282305
# Check if tool exists
283-
if not tool_func:
284-
logger.error(
285-
"tool_name=<%s>, available_tools=<%s> | tool not found in registry",
286-
tool_name,
287-
list(agent.tool_registry.registry.keys()),
288-
)
289-
return {
290-
"toolUseId": tool_use_id,
306+
if not selected_tool:
307+
if tool_func == selected_tool:
308+
logger.error(
309+
"tool_name=<%s>, available_tools=<%s> | tool not found in registry",
310+
tool_name,
311+
list(agent.tool_registry.registry.keys()),
312+
)
313+
else:
314+
logger.debug(
315+
"tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call",
316+
tool_name,
317+
str(tool_use.get("toolUseId")),
318+
)
319+
320+
result: ToolResult = {
321+
"toolUseId": str(tool_use.get("toolUseId")),
291322
"status": "error",
292323
"content": [{"text": f"Unknown tool: {tool_name}"}],
293324
}
294-
# Add standard arguments to kwargs for Python tools
295-
kwargs.update(
296-
{
297-
"model": agent.model,
298-
"system_prompt": agent.system_prompt,
299-
"messages": agent.messages,
300-
"tool_config": agent.tool_config,
301-
}
302-
)
325+
# for every Before event call, we need to have an AfterEvent call
326+
after_event = get_registry(agent).invoke_callbacks(
327+
AfterToolInvocationEvent(
328+
agent=agent,
329+
selected_tool=selected_tool,
330+
tool_use=tool_use,
331+
kwargs=kwargs,
332+
result=result,
333+
)
334+
)
335+
return after_event.result
303336

304-
result = yield from tool_func.stream(tool_use, **kwargs)
305-
return result
337+
result = yield from selected_tool.stream(tool_use, **kwargs)
338+
after_event = get_registry(agent).invoke_callbacks(
339+
AfterToolInvocationEvent(
340+
agent=agent,
341+
selected_tool=selected_tool,
342+
tool_use=tool_use,
343+
kwargs=kwargs,
344+
result=result,
345+
)
346+
)
347+
return after_event.result
306348

307349
except Exception as e:
308350
logger.exception("tool_name=<%s> | failed to process tool", tool_name)
309-
return {
310-
"toolUseId": tool_use_id,
351+
error_result: ToolResult = {
352+
"toolUseId": str(tool_use.get("toolUseId")),
311353
"status": "error",
312354
"content": [{"text": f"Error: {str(e)}"}],
313355
}
356+
after_event = get_registry(agent).invoke_callbacks(
357+
AfterToolInvocationEvent(
358+
agent=agent,
359+
selected_tool=selected_tool,
360+
tool_use=tool_use,
361+
kwargs=kwargs,
362+
result=error_result,
363+
exception=e,
364+
)
365+
)
366+
return after_event.result
314367

315368

316369
async def _handle_tool_execution(

src/strands/experimental/hooks/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,21 @@ def log_end(self, event: EndRequestEvent) -> None:
2929
type-safe system that supports multiple subscribers per event type.
3030
"""
3131

32-
from .events import AgentInitializedEvent, EndRequestEvent, StartRequestEvent
32+
from .events import (
33+
AfterToolInvocationEvent,
34+
AgentInitializedEvent,
35+
BeforeToolInvocationEvent,
36+
EndRequestEvent,
37+
StartRequestEvent,
38+
)
3339
from .registry import HookCallback, HookEvent, HookProvider, HookRegistry
3440

3541
__all__ = [
3642
"AgentInitializedEvent",
3743
"StartRequestEvent",
3844
"EndRequestEvent",
45+
"BeforeToolInvocationEvent",
46+
"AfterToolInvocationEvent",
3947
"HookEvent",
4048
"HookProvider",
4149
"HookCallback",

src/strands/experimental/hooks/events.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
"""
55

66
from dataclasses import dataclass
7+
from typing import Any, Optional
78

9+
from ...types.tools import AgentTool, ToolResult, ToolUse
810
from .registry import HookEvent
911

1012

@@ -56,9 +58,63 @@ class EndRequestEvent(HookEvent):
5658

5759
@property
5860
def should_reverse_callbacks(self) -> bool:
59-
"""Return True to invoke callbacks in reverse order for proper cleanup.
61+
"""True to invoke callbacks in reverse order."""
62+
return True
63+
64+
65+
@dataclass
66+
class BeforeToolInvocationEvent(HookEvent):
67+
"""Event triggered before a tool is invoked.
68+
69+
This event is fired just before the agent executes a tool, allowing hook
70+
providers to inspect, modify, or replace the tool that will be executed.
71+
The selected_tool can be modified by hook callbacks to change which tool
72+
gets executed.
73+
74+
Attributes:
75+
selected_tool: The tool that will be invoked. Can be modified by hooks
76+
to change which tool gets executed. This may be None if tool lookup failed.
77+
tool_use: The tool parameters that will be passed to selected_tool.
78+
kwargs: Keyword arguments that will be passed to the tool.
79+
"""
80+
81+
selected_tool: Optional[AgentTool]
82+
tool_use: ToolUse
83+
kwargs: dict[str, Any]
84+
85+
def _can_write(self, name: str) -> bool:
86+
return name in ["selected_tool", "tool_use"]
87+
88+
89+
@dataclass
90+
class AfterToolInvocationEvent(HookEvent):
91+
"""Event triggered after a tool invocation completes.
6092
61-
Returns:
62-
True, indicating callbacks should be invoked in reverse order.
63-
"""
93+
This event is fired after the agent has finished executing a tool,
94+
regardless of whether the execution was successful or resulted in an error.
95+
Hook providers can use this event for cleanup, logging, or post-processing.
96+
97+
Note: This event uses reverse callback ordering, meaning callbacks registered
98+
later will be invoked first during cleanup.
99+
100+
Attributes:
101+
selected_tool: The tool that was invoked. It may be None if tool lookup failed.
102+
tool_use: The tool parameters that were passed to the tool invoked.
103+
kwargs: Keyword arguments that were passed to the tool
104+
result: The result of the tool invocation. Either a ToolResult on success
105+
or an Exception if the tool execution failed.
106+
"""
107+
108+
selected_tool: Optional[AgentTool]
109+
tool_use: ToolUse
110+
kwargs: dict[str, Any]
111+
result: ToolResult
112+
exception: Optional[Exception] = None
113+
114+
def _can_write(self, name: str) -> bool:
115+
return name == "result"
116+
117+
@property
118+
def should_reverse_callbacks(self) -> bool:
119+
"""True to invoke callbacks in reverse order."""
64120
return True

src/strands/experimental/hooks/registry.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
"""
99

1010
from dataclasses import dataclass
11-
from typing import TYPE_CHECKING, Callable, Generator, Generic, Protocol, Type, TypeVar
11+
from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar
1212

1313
if TYPE_CHECKING:
1414
from ...agent import Agent
@@ -34,9 +34,43 @@ def should_reverse_callbacks(self) -> bool:
3434
"""
3535
return False
3636

37+
def _can_write(self, name: str) -> bool:
38+
"""Check if the given property can be written to.
39+
40+
Args:
41+
name: The name of the property to check.
42+
43+
Returns:
44+
True if the property can be written to, False otherwise.
45+
"""
46+
return False
47+
48+
def __post_init__(self) -> None:
49+
"""Disallow writes to non-approved properties."""
50+
# This is needed as otherwise the class can't be initialized at all, so we trigger
51+
# this after class initialization
52+
super().__setattr__("_disallow_writes", True)
53+
54+
def __setattr__(self, name: str, value: Any) -> None:
55+
"""Prevent setting attributes on hook events.
56+
57+
Raises:
58+
AttributeError: Always raised to prevent setting attributes on hook events.
59+
"""
60+
# Allow setting attributes:
61+
# - during init (when __dict__) doesn't exist
62+
# - if the subclass specifically said the property is writable
63+
if not hasattr(self, "_disallow_writes") or self._can_write(name):
64+
return super().__setattr__(name, value)
65+
66+
raise AttributeError(f"Property {name} is not writable")
67+
3768

38-
T = TypeVar("T", bound=Callable)
3969
TEvent = TypeVar("TEvent", bound=HookEvent, contravariant=True)
70+
"""Generic for adding callback handlers - contravariant to allow adding handlers which take in base classes."""
71+
72+
TInvokeEvent = TypeVar("TInvokeEvent", bound=HookEvent)
73+
"""Generic for invoking events - non-contravariant to enable returning events."""
4074

4175

4276
class HookProvider(Protocol):
@@ -144,7 +178,7 @@ def register_hooks(self, registry: HookRegistry):
144178
"""
145179
hook.register_hooks(self)
146180

147-
def invoke_callbacks(self, event: TEvent) -> None:
181+
def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent:
148182
"""Invoke all registered callbacks for the given event.
149183
150184
This method finds all callbacks registered for the event's type and
@@ -157,6 +191,9 @@ def invoke_callbacks(self, event: TEvent) -> None:
157191
Raises:
158192
Any exceptions raised by callback functions will propagate to the caller.
159193
194+
Returns:
195+
The event dispatched to registered callbacks.
196+
160197
Example:
161198
```python
162199
event = StartRequestEvent(agent=my_agent)
@@ -166,6 +203,8 @@ def invoke_callbacks(self, event: TEvent) -> None:
166203
for callback in self.get_callbacks_for(event):
167204
callback(event)
168205

206+
return event
207+
169208
def get_callbacks_for(self, event: TEvent) -> Generator[HookCallback[TEvent], None, None]:
170209
"""Get callbacks registered for the given event in the appropriate order.
171210
@@ -193,3 +232,18 @@ def get_callbacks_for(self, event: TEvent) -> Generator[HookCallback[TEvent], No
193232
yield from reversed(callbacks)
194233
else:
195234
yield from callbacks
235+
236+
237+
def get_registry(agent: "Agent") -> HookRegistry:
238+
"""*Experimental*: Get the hooks registry for the provided agent.
239+
240+
This function is available while hooks are in experimental preview.
241+
242+
Args:
243+
agent: The agent whose hook registry should be returned.
244+
245+
Returns:
246+
The HookRegistry for the given agent.
247+
248+
"""
249+
return agent._hooks
Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from collections import deque
2-
from typing import Type
1+
from typing import Iterator, Tuple, Type
32

43
from strands.experimental.hooks import HookEvent, HookProvider, HookRegistry
54

@@ -9,12 +8,12 @@ def __init__(self, event_types: list[Type]):
98
self.events_received = []
109
self.events_types = event_types
1110

12-
def get_events(self) -> deque[HookEvent]:
13-
return deque(self.events_received)
11+
def get_events(self) -> Tuple[int, Iterator[HookEvent]]:
12+
return len(self.events_received), iter(self.events_received)
1413

1514
def register_hooks(self, registry: HookRegistry) -> None:
1615
for event_type in self.events_types:
17-
registry.add_callback(event_type, self._add_event)
16+
registry.add_callback(event_type, self.add_event)
1817

19-
def _add_event(self, event: HookEvent) -> None:
18+
def add_event(self, event: HookEvent) -> None:
2019
self.events_received.append(event)

0 commit comments

Comments
 (0)