Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 67 additions & 46 deletions tests_integ/test_multiagent_swarm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from unittest.mock import patch
from uuid import uuid4

import pytest

from strands import Agent, tool
from strands.experimental.hooks.multiagent import BeforeNodeCallEvent
from strands.hooks import (
AfterInvocationEvent,
AfterModelCallEvent,
Expand All @@ -13,7 +13,6 @@
BeforeToolCallEvent,
MessageAddedEvent,
)
from strands.multiagent.base import Status
from strands.multiagent.swarm import Swarm
from strands.session.file_session_manager import FileSessionManager
from strands.types.content import ContentBlock
Expand Down Expand Up @@ -82,6 +81,38 @@ def writer_agent(hook_provider):
)


@pytest.fixture
def exit_hook():
    """Provide a hook provider that aborts a run right before the "analyst" node.

    The returned object carries a ``should_exit`` flag (initially ``True``).
    While the flag is truthy, its ``BeforeNodeCallEvent`` callback raises
    ``SystemExit`` as soon as an event with ``node_id == "analyst"`` arrives.
    Tests can set ``should_exit = False`` to turn the callback into a no-op.
    """

    class ExitHook:
        def __init__(self):
            # Read by the callback on every event; flip to False to disable.
            self.should_exit = True

        def register_hooks(self, registry):
            # Subscribe the callback to the before-node event.
            registry.add_callback(BeforeNodeCallEvent, self.exit_before_analyst)

        def exit_before_analyst(self, event):
            # Guard clause: the node id is checked first, then the flag.
            if event.node_id != "analyst" or not self.should_exit:
                return
            raise SystemExit("Controlled exit before analyst")

    hook = ExitHook()
    return hook


@pytest.fixture
def verify_hook():
    """Provide a hook provider that remembers the first node it sees.

    The returned object exposes ``first_node`` (initially ``None``). Its
    ``BeforeNodeCallEvent`` callback stores the ``node_id`` of the first event
    it receives and ignores every later one.
    """

    class VerifyHook:
        def __init__(self):
            # None means no node event has been observed yet.
            self.first_node = None

        def register_hooks(self, registry):
            # Subscribe the callback to the before-node event.
            registry.add_callback(BeforeNodeCallEvent, self.capture_first_node)

        def capture_first_node(self, event):
            # Only the very first event is recorded; later ones are ignored.
            if self.first_node is not None:
                return
            self.first_node = event.node_id

    hook = VerifyHook()
    return hook


def test_swarm_execution_with_string(researcher_agent, analyst_agent, writer_agent, hook_provider):
"""Test swarm execution with string input."""
# Create the swarm
Expand Down Expand Up @@ -326,53 +357,43 @@ async def test_swarm_get_agent_results_flattening():
assert agent_results[0].message is not None


# NOTE(review): in the diff rendering this test and
# test_swarm_resume_from_executing_state were interleaved line by line; they
# are separated here with every statement preserved. This one looks like the
# pre-change version of the resume test (it relies on `patch` and `Status`)
# -- confirm whether it should still be kept.
@pytest.mark.asyncio
async def test_swarm_interrupt_and_resume(researcher_agent, analyst_agent, writer_agent):
    """Test swarm interruption after analyst_agent and resume functionality.

    The analyst's ``stream_async`` is patched to raise, the swarm is invoked
    once (expected to end in a failed state that is persisted), and then the
    swarm is invoked again without the patch and is expected to complete.
    """
    session_id = str(uuid4())

    # Create session manager
    session_manager = FileSessionManager(session_id=session_id)

    # Create swarm with session manager
    swarm = Swarm([researcher_agent, analyst_agent, writer_agent], session_manager=session_manager)

    # Replacement for analyst_agent.stream_async that always fails
    async def failing_invoke(*args, **kwargs):
        raise Exception("Simulated failure in analyst")
        yield  # This line is never reached, but makes it an async generator

    with patch.object(analyst_agent, "stream_async", side_effect=failing_invoke):
        # First execution - should fail at analyst
        result = await swarm.invoke_async("Research AI trends and create a brief report")
        try:
            assert result.status == Status.FAILED
        except Exception as e:
            assert "Simulated failure in analyst" in str(e)

    # Verify partial execution was persisted
    persisted_state = session_manager.read_multi_agent(session_id, swarm.id)
    assert persisted_state is not None
    assert persisted_state["type"] == "swarm"
    assert persisted_state["status"] == "failed"
    assert len(persisted_state["node_history"]) == 1  # Exactly one node recorded before the failure

    # Track execution count before resume
    initial_execution_count = len(persisted_state["node_history"])

    # Execute swarm again - should automatically resume from saved state
    result = await swarm.invoke_async("Research AI trends and create a brief report")

    # Verify successful completion
    assert result.status == Status.COMPLETED
    assert len(result.results) > 0

    assert len(result.node_history) >= initial_execution_count + 1

    node_names = [node.node_id for node in result.node_history]
    assert "researcher" in node_names
    # Either analyst or writer (or both) should have executed to complete the task
    assert "analyst" in node_names or "writer" in node_names

    # Clean up
    session_manager.delete_session(session_id)


def test_swarm_resume_from_executing_state(tmpdir, exit_hook, verify_hook):
    """Test swarm resuming from EXECUTING state using BeforeNodeCallEvent hook.

    A first swarm run is cut short by ``exit_hook`` (which raises
    ``SystemExit`` before the "analyst" node). The persisted state is checked,
    then a second swarm built from fresh agents on the same session manager is
    run with ``verify_hook`` to confirm the first node it executes is
    "analyst".
    """
    session_id = f"swarm_resume_{uuid4()}"

    # First execution - exit before second node
    session_manager = FileSessionManager(session_id=session_id, storage_dir=tmpdir)
    researcher = Agent(name="researcher", system_prompt="you are a researcher.")
    analyst = Agent(name="analyst", system_prompt="you are an analyst.")
    writer = Agent(name="writer", system_prompt="you are a writer.")

    swarm = Swarm([researcher, analyst, writer], session_manager=session_manager, hooks=[exit_hook])

    # The hook's SystemExit is tolerated here; the persisted-state assertions
    # below are what establish that the run stopped after the researcher.
    try:
        swarm("write AI trends and calculate growth in 100 words")
    except SystemExit as e:
        assert "Controlled exit before analyst" in str(e)

    # Verify state was persisted with EXECUTING status and next node
    persisted_state = session_manager.read_multi_agent(session_id, swarm.id)
    assert persisted_state["status"] == "executing"
    assert len(persisted_state["node_history"]) == 1
    assert persisted_state["node_history"][0] == "researcher"
    assert persisted_state["next_nodes_to_execute"] == ["analyst"]

    # Diagnostic output only: re-read and print the saved state
    persisted_state = session_manager.read_multi_agent(session_id, swarm.id)
    print(f"Saved session state: {persisted_state}")

    # Create fresh agent instances for the second swarm to avoid tool conflicts
    exit_hook.should_exit = False
    researcher2 = Agent(name="researcher", system_prompt="you are a researcher.")
    analyst2 = Agent(name="analyst", system_prompt="you are an analyst.")
    writer2 = Agent(name="writer", system_prompt="you are a writer.")
    new_swarm = Swarm([researcher2, analyst2, writer2], session_manager=session_manager, hooks=[verify_hook])
    result = new_swarm("write AI trends and calculate growth in 100 words")

    # Verify swarm behavior - should resume from analyst, not restart
    assert result.status.value == "completed"
    assert verify_hook.first_node == "analyst"
    node_ids = [n.node_id for n in result.node_history]
    assert "analyst" in node_ids
Loading