Skip to content

Commit 33fd2bd

Browse files
committed
Merge branch 'main' of https://github.com/strands-agents/sdk-python into interrupt_state
2 parents 1b46b12 + 2b0c6e6 commit 33fd2bd

File tree

20 files changed

+691
-87
lines changed

20 files changed

+691
-87
lines changed

src/strands/_async.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Private async execution utilities."""
22

33
import asyncio
4+
import contextvars
45
from concurrent.futures import ThreadPoolExecutor
56
from typing import Awaitable, Callable, TypeVar
67

@@ -27,5 +28,6 @@ def execute() -> T:
2728
return asyncio.run(execute_async())
2829

2930
with ThreadPoolExecutor() as executor:
30-
future = executor.submit(execute)
31+
context = contextvars.copy_context()
32+
future = executor.submit(context.run, execute)
3133
return future.result()

src/strands/agent/agent.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -170,22 +170,21 @@ async def acall() -> ToolResult:
170170
self._agent._interrupt_state.deactivate()
171171
raise RuntimeError("cannot raise interrupt in direct tool call")
172172

173-
return tool_results[0]
173+
tool_result = tool_results[0]
174174

175-
tool_result = run_async(acall)
175+
if record_direct_tool_call is not None:
176+
should_record_direct_tool_call = record_direct_tool_call
177+
else:
178+
should_record_direct_tool_call = self._agent.record_direct_tool_call
176179

177-
if record_direct_tool_call is not None:
178-
should_record_direct_tool_call = record_direct_tool_call
179-
else:
180-
should_record_direct_tool_call = self._agent.record_direct_tool_call
180+
if should_record_direct_tool_call:
181+
# Create a record of this tool execution in the message history
182+
await self._agent._record_tool_execution(tool_use, tool_result, user_message_override)
181183

182-
if should_record_direct_tool_call:
183-
# Create a record of this tool execution in the message history
184-
self._agent._record_tool_execution(tool_use, tool_result, user_message_override)
184+
return tool_result
185185

186-
# Apply window management
186+
tool_result = run_async(acall)
187187
self._agent.conversation_manager.apply_management(self._agent)
188-
189188
return tool_result
190189

191190
return caller
@@ -533,15 +532,15 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
533532
category=DeprecationWarning,
534533
stacklevel=2,
535534
)
536-
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
535+
await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self))
537536
with self.tracer.tracer.start_as_current_span(
538537
"execute_structured_output", kind=trace_api.SpanKind.CLIENT
539538
) as structured_output_span:
540539
try:
541540
if not self.messages and not prompt:
542541
raise ValueError("No conversation history or prompt provided")
543542

544-
temp_messages: Messages = self.messages + self._convert_prompt_to_messages(prompt)
543+
temp_messages: Messages = self.messages + await self._convert_prompt_to_messages(prompt)
545544

546545
structured_output_span.set_attributes(
547546
{
@@ -574,7 +573,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
574573
return event["output"]
575574

576575
finally:
577-
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
576+
await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self))
578577

579578
def cleanup(self) -> None:
580579
"""Clean up resources used by the agent.
@@ -657,7 +656,7 @@ async def stream_async(
657656
callback_handler = kwargs.get("callback_handler", self.callback_handler)
658657

659658
# Process input and get message to add (if any)
660-
messages = self._convert_prompt_to_messages(prompt)
659+
messages = await self._convert_prompt_to_messages(prompt)
661660

662661
self.trace_span = self._start_agent_trace_span(messages)
663662

@@ -699,13 +698,13 @@ async def _run_loop(
699698
Yields:
700699
Events from the event loop cycle.
701700
"""
702-
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
701+
await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self))
703702

704703
try:
705704
yield InitEventLoopEvent()
706705

707706
for message in messages:
708-
self._append_message(message)
707+
await self._append_message(message)
709708

710709
structured_output_context = StructuredOutputContext(
711710
structured_output_model or self._default_structured_output_model
@@ -731,7 +730,7 @@ async def _run_loop(
731730

732731
finally:
733732
self.conversation_manager.apply_management(self)
734-
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
733+
await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self))
735734

736735
async def _execute_event_loop_cycle(
737736
self, invocation_state: dict[str, Any], structured_output_context: StructuredOutputContext | None = None
@@ -780,7 +779,7 @@ async def _execute_event_loop_cycle(
780779
if structured_output_context:
781780
structured_output_context.cleanup(self.tool_registry)
782781

783-
def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
782+
async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
784783
if self._interrupt_state.activated:
785784
return []
786785

@@ -795,7 +794,7 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
795794
tool_use_ids = [
796795
content["toolUse"]["toolUseId"] for content in self.messages[-1]["content"] if "toolUse" in content
797796
]
798-
self._append_message(
797+
await self._append_message(
799798
{
800799
"role": "user",
801800
"content": generate_missing_tool_result_content(tool_use_ids),
@@ -826,7 +825,7 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
826825
raise ValueError("Input prompt must be of type: `str | list[Contentblock] | Messages | None`.")
827826
return messages
828827

829-
def _record_tool_execution(
828+
async def _record_tool_execution(
830829
self,
831830
tool: ToolUse,
832831
tool_result: ToolResult,
@@ -886,10 +885,10 @@ def _record_tool_execution(
886885
}
887886

888887
# Add to message history
889-
self._append_message(user_msg)
890-
self._append_message(tool_use_msg)
891-
self._append_message(tool_result_msg)
892-
self._append_message(assistant_msg)
888+
await self._append_message(user_msg)
889+
await self._append_message(tool_use_msg)
890+
await self._append_message(tool_result_msg)
891+
await self._append_message(assistant_msg)
893892

894893
def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span:
895894
"""Starts a trace span for the agent.
@@ -975,10 +974,10 @@ def _initialize_system_prompt(
975974
else:
976975
return None, None
977976

978-
def _append_message(self, message: Message) -> None:
977+
async def _append_message(self, message: Message) -> None:
979978
"""Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent."""
980979
self.messages.append(message)
981-
self.hooks.invoke_callbacks(MessageAddedEvent(agent=self, message=message))
980+
await self.hooks.invoke_callbacks_async(MessageAddedEvent(agent=self, message=message))
982981

983982
def _redact_user_content(self, content: list[ContentBlock], redact_message: str) -> list[ContentBlock]:
984983
"""Redact user content preserving toolResult blocks.

src/strands/event_loop/event_loop.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ async def event_loop_cycle(
227227
)
228228
structured_output_context.set_forced_mode()
229229
logger.debug("Forcing structured output tool")
230-
agent._append_message(
230+
await agent._append_message(
231231
{"role": "user", "content": [{"text": "You must format the previous response as structured output."}]}
232232
)
233233

@@ -322,7 +322,7 @@ async def _handle_model_execution(
322322
model_id=model_id,
323323
)
324324
with trace_api.use_span(model_invoke_span):
325-
agent.hooks.invoke_callbacks(
325+
await agent.hooks.invoke_callbacks_async(
326326
BeforeModelCallEvent(
327327
agent=agent,
328328
)
@@ -347,7 +347,7 @@ async def _handle_model_execution(
347347
stop_reason, message, usage, metrics = event["stop"]
348348
invocation_state.setdefault("request_state", {})
349349

350-
agent.hooks.invoke_callbacks(
350+
await agent.hooks.invoke_callbacks_async(
351351
AfterModelCallEvent(
352352
agent=agent,
353353
stop_response=AfterModelCallEvent.ModelStopResponse(
@@ -368,7 +368,7 @@ async def _handle_model_execution(
368368
if model_invoke_span:
369369
tracer.end_span_with_error(model_invoke_span, str(e), e)
370370

371-
agent.hooks.invoke_callbacks(
371+
await agent.hooks.invoke_callbacks_async(
372372
AfterModelCallEvent(
373373
agent=agent,
374374
exception=e,
@@ -402,7 +402,7 @@ async def _handle_model_execution(
402402

403403
# Add the response message to the conversation
404404
agent.messages.append(message)
405-
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
405+
await agent.hooks.invoke_callbacks_async(MessageAddedEvent(agent=agent, message=message))
406406

407407
# Update metrics
408408
agent.event_loop_metrics.update_usage(usage)
@@ -507,7 +507,7 @@ async def _handle_tool_execution(
507507
}
508508

509509
agent.messages.append(tool_result_message)
510-
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message))
510+
await agent.hooks.invoke_callbacks_async(MessageAddedEvent(agent=agent, message=tool_result_message))
511511

512512
yield ToolResultMessageEvent(message=tool_result_message)
513513

src/strands/hooks/registry.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
via hook provider objects.
88
"""
99

10+
import inspect
1011
import logging
1112
from dataclasses import dataclass
12-
from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar
13+
from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar
1314

1415
from ..interrupt import Interrupt, InterruptException
1516

@@ -122,10 +123,15 @@ class HookCallback(Protocol, Generic[TEvent]):
122123
```python
123124
def my_callback(event: StartRequestEvent) -> None:
124125
print(f"Request started for agent: {event.agent.name}")
126+
127+
# Or
128+
129+
async def my_callback(event: StartRequestEvent) -> None:
130+
# await an async operation
125131
```
126132
"""
127133

128-
def __call__(self, event: TEvent) -> None:
134+
def __call__(self, event: TEvent) -> None | Awaitable[None]:
129135
"""Handle a hook event.
130136
131137
Args:
@@ -164,6 +170,10 @@ def my_handler(event: StartRequestEvent):
164170
registry.add_callback(StartRequestEvent, my_handler)
165171
```
166172
"""
173+
# Related issue: https://github.com/strands-agents/sdk-python/issues/330
174+
if event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback):
175+
raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback")
176+
167177
callbacks = self._registered_callbacks.setdefault(event_type, [])
168178
callbacks.append(callback)
169179

@@ -189,6 +199,52 @@ def register_hooks(self, registry: HookRegistry):
189199
"""
190200
hook.register_hooks(self)
191201

202+
async def invoke_callbacks_async(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Interrupt]]:
203+
"""Invoke all registered callbacks for the given event.
204+
205+
This method finds all callbacks registered for the event's type and
206+
invokes them in the appropriate order. For events with should_reverse_callbacks=True,
207+
callbacks are invoked in reverse registration order. Any exceptions raised by callback
208+
functions will propagate to the caller.
209+
210+
Additionally, this method aggregates interrupts raised by the user to instantiate human-in-the-loop workflows.
211+
212+
Args:
213+
event: The event to dispatch to registered callbacks.
214+
215+
Returns:
216+
The event dispatched to registered callbacks and any interrupts raised by the user.
217+
218+
Raises:
219+
ValueError: If interrupt name is used more than once.
220+
221+
Example:
222+
```python
223+
event = StartRequestEvent(agent=my_agent)
224+
await registry.invoke_callbacks_async(event)
225+
```
226+
"""
227+
interrupts: dict[str, Interrupt] = {}
228+
229+
for callback in self.get_callbacks_for(event):
230+
try:
231+
if inspect.iscoroutinefunction(callback):
232+
await callback(event)
233+
else:
234+
callback(event)
235+
236+
except InterruptException as exception:
237+
interrupt = exception.interrupt
238+
if interrupt.name in interrupts:
239+
message = f"interrupt_name=<{interrupt.name}> | interrupt name used more than once"
240+
logger.error(message)
241+
raise ValueError(message) from exception
242+
243+
# Each callback is allowed to raise their own interrupt.
244+
interrupts[interrupt.name] = interrupt
245+
246+
return event, list(interrupts.values())
247+
192248
def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Interrupt]]:
193249
"""Invoke all registered callbacks for the given event.
194250
@@ -206,6 +262,7 @@ def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Inte
206262
The event dispatched to registered callbacks and any interrupts raised by the user.
207263
208264
Raises:
265+
RuntimeError: If at least one callback is async.
209266
ValueError: If interrupt name is used more than once.
210267
211268
Example:
@@ -214,9 +271,13 @@ def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Inte
214271
registry.invoke_callbacks(event)
215272
```
216273
"""
274+
callbacks = list(self.get_callbacks_for(event))
217275
interrupts: dict[str, Interrupt] = {}
218276

219-
for callback in self.get_callbacks_for(event):
277+
if any(inspect.iscoroutinefunction(callback) for callback in callbacks):
278+
raise RuntimeError(f"event=<{event}> | use invoke_callbacks_async to invoke async callback")
279+
280+
for callback in callbacks:
220281
try:
221282
callback(event)
222283
except InterruptException as exception:

src/strands/multiagent/graph.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def __init__(
453453
self._resume_from_session = False
454454
self.id = id
455455

456-
self.hooks.invoke_callbacks(MultiAgentInitializedEvent(self))
456+
run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))
457457

458458
def __call__(
459459
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
@@ -516,7 +516,7 @@ async def stream_async(
516516
if invocation_state is None:
517517
invocation_state = {}
518518

519-
self.hooks.invoke_callbacks(BeforeMultiAgentInvocationEvent(self, invocation_state))
519+
await self.hooks.invoke_callbacks_async(BeforeMultiAgentInvocationEvent(self, invocation_state))
520520

521521
logger.debug("task=<%s> | starting graph execution", task)
522522

@@ -569,7 +569,7 @@ async def stream_async(
569569
raise
570570
finally:
571571
self.state.execution_time = round((time.time() - start_time) * 1000)
572-
self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(self))
572+
await self.hooks.invoke_callbacks_async(AfterMultiAgentInvocationEvent(self))
573573
self._resume_from_session = False
574574
self._resume_next_nodes.clear()
575575

@@ -776,7 +776,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
776776

777777
async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
778778
"""Execute a single node and yield TypedEvent objects."""
779-
self.hooks.invoke_callbacks(BeforeNodeCallEvent(self, node.node_id, invocation_state))
779+
await self.hooks.invoke_callbacks_async(BeforeNodeCallEvent(self, node.node_id, invocation_state))
780780

781781
# Reset the node's state if reset_on_revisit is enabled, and it's being revisited
782782
if self.reset_on_revisit and node in self.state.completed_nodes:
@@ -920,7 +920,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
920920
raise
921921

922922
finally:
923-
self.hooks.invoke_callbacks(AfterNodeCallEvent(self, node.node_id, invocation_state))
923+
await self.hooks.invoke_callbacks_async(AfterNodeCallEvent(self, node.node_id, invocation_state))
924924

925925
def _accumulate_metrics(self, node_result: NodeResult) -> None:
926926
"""Accumulate metrics from a node result."""

0 commit comments

Comments
 (0)