From 1e58e59e02b5087ec69db1f5a300d1893ed9ce3c Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 6 Nov 2025 21:29:26 -0500 Subject: [PATCH] swarm - switch to handoff node only after current node stops --- src/strands/multiagent/swarm.py | 41 ++++++++++++-------------- src/strands/session/session_manager.py | 4 +-- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index accd56463..833cd240a 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -156,6 +156,7 @@ class SwarmState: # Total metrics across all agents accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) execution_time: int = 0 # Total execution time in milliseconds + handoff_node: SwarmNode | None = None # The agent to execute next handoff_message: str | None = None # Message passed during agent handoff def should_continue( @@ -537,7 +538,7 @@ def handoff_to_agent(agent_name: str, message: str, context: dict[str, Any] | No # Execute handoff swarm_ref._handle_handoff(target_node, message, context) - return {"status": "success", "content": [{"text": f"Handed off to {agent_name}: {message}"}]} + return {"status": "success", "content": [{"text": f"Handing off to {agent_name}: {message}"}]} except Exception as e: return {"status": "error", "content": [{"text": f"Error in handoff: {str(e)}"}]} @@ -553,21 +554,19 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st ) return - # Update swarm state - previous_agent = cast(SwarmNode, self.state.current_node) - self.state.current_node = target_node + current_node = cast(SwarmNode, self.state.current_node) - # Store handoff message for the target agent + self.state.handoff_node = target_node self.state.handoff_message = message # Store handoff context as shared context if context: for key, value in context.items(): - self.shared_context.add_context(previous_agent, key, value) + self.shared_context.add_context(current_node, key, value) logger.debug( - "from_node=<%s>, to_node=<%s> | handed off from agent to agent", - previous_agent.node_id, + "from_node=<%s>, to_node=<%s> | handing off from agent to agent", + current_node.node_id, target_node.node_id, ) @@ -667,7 +666,6 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato logger.debug("reason=<%s> | stopping execution", reason) break - # Get current node current_node = self.state.current_node if not current_node or current_node.node_id not in self.nodes: logger.error("node=<%s> | node not found", current_node.node_id if current_node else "None") @@ -680,14 +678,10 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato len(self.state.node_history) + 1, ) - # Store the current node before execution to detect handoffs - previous_node = current_node - - # Execute node with timeout protection # TODO: Implement cancellation token to stop _execute_node from continuing try: - # Execute with timeout wrapper for async generator streaming self.hooks.invoke_callbacks(BeforeNodeCallEvent(self, current_node.node_id, invocation_state)) + node_stream = self._stream_with_timeout( self._execute_node(current_node, self.state.task, invocation_state), self.node_timeout, @@ -697,28 +691,31 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato yield event self.state.node_history.append(current_node) - - # After self.state add current node, swarm state finish updating, we persist here self.hooks.invoke_callbacks(AfterNodeCallEvent(self, current_node.node_id, invocation_state)) logger.debug("node=<%s> | node execution completed", current_node.node_id) - # Check if handoff occurred during execution - if self.state.current_node is not None and self.state.current_node != previous_node: - # Emit handoff event (single node transition in Swarm) + # Check if handoff requested during execution + if self.state.handoff_node: + previous_node = current_node + current_node = self.state.handoff_node + + self.state.handoff_node = None + self.state.current_node = current_node + handoff_event = MultiAgentHandoffEvent( from_node_ids=[previous_node.node_id], - to_node_ids=[self.state.current_node.node_id], + to_node_ids=[current_node.node_id], message=self.state.handoff_message or "Agent handoff occurred", ) yield handoff_event logger.debug( "from_node=<%s>, to_node=<%s> | handoff detected", previous_node.node_id, - self.state.current_node.node_id, + current_node.node_id, ) + else: - # No handoff occurred, mark swarm as complete logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id) self.state.completion_status = Status.COMPLETED break diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index fb9132828..d4bc72c80 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -6,7 +6,7 @@ from ..experimental.hooks.multiagent.events import ( AfterMultiAgentInvocationEvent, - AfterNodeCallEvent, + BeforeNodeCallEvent, MultiAgentInitializedEvent, ) from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent @@ -44,7 +44,7 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent)) registry.add_callback(MultiAgentInitializedEvent, lambda event: self.initialize_multi_agent(event.source)) - registry.add_callback(AfterNodeCallEvent, lambda event: self.sync_multi_agent(event.source)) + registry.add_callback(BeforeNodeCallEvent, lambda event: self.sync_multi_agent(event.source)) registry.add_callback(AfterMultiAgentInvocationEvent, lambda event: self.sync_multi_agent(event.source)) @abstractmethod