diff --git a/tests_integ/test_multiagent_swarm.py b/tests_integ/test_multiagent_swarm.py
index 771030619..8bdb9ff59 100644
--- a/tests_integ/test_multiagent_swarm.py
+++ b/tests_integ/test_multiagent_swarm.py
@@ -1,9 +1,9 @@
-from unittest.mock import patch
 from uuid import uuid4
 
 import pytest
 
 from strands import Agent, tool
+from strands.experimental.hooks.multiagent import BeforeNodeCallEvent
 from strands.hooks import (
     AfterInvocationEvent,
     AfterModelCallEvent,
@@ -13,7 +13,6 @@
     BeforeToolCallEvent,
     MessageAddedEvent,
 )
-from strands.multiagent.base import Status
 from strands.multiagent.swarm import Swarm
 from strands.session.file_session_manager import FileSessionManager
 from strands.types.content import ContentBlock
@@ -82,6 +81,38 @@ def writer_agent(hook_provider):
     )
 
 
+@pytest.fixture
+def exit_hook():
+    class ExitHook:
+        def __init__(self):
+            self.should_exit = True
+
+        def register_hooks(self, registry):
+            registry.add_callback(BeforeNodeCallEvent, self.exit_before_analyst)
+
+        def exit_before_analyst(self, event):
+            if event.node_id == "analyst" and self.should_exit:
+                raise SystemExit("Controlled exit before analyst")
+
+    return ExitHook()
+
+
+@pytest.fixture
+def verify_hook():
+    class VerifyHook:
+        def __init__(self):
+            self.first_node = None
+
+        def register_hooks(self, registry):
+            registry.add_callback(BeforeNodeCallEvent, self.capture_first_node)
+
+        def capture_first_node(self, event):
+            if self.first_node is None:
+                self.first_node = event.node_id
+
+    return VerifyHook()
+
+
 def test_swarm_execution_with_string(researcher_agent, analyst_agent, writer_agent, hook_provider):
     """Test swarm execution with string input."""
     # Create the swarm
@@ -326,53 +357,43 @@ async def test_swarm_get_agent_results_flattening():
     assert agent_results[0].message is not None
 
 
-@pytest.mark.asyncio
-async def test_swarm_interrupt_and_resume(researcher_agent, analyst_agent, writer_agent):
-    """Test swarm interruption after analyst_agent and resume functionality."""
-    session_id = str(uuid4())
-
-    # Create session manager
-    session_manager = FileSessionManager(session_id=session_id)
-
-    # Create swarm with session manager
-    swarm = Swarm([researcher_agent, analyst_agent, writer_agent], session_manager=session_manager)
-
-    # Mock analyst_agent's _invoke method to fail
-    async def failing_invoke(*args, **kwargs):
-        raise Exception("Simulated failure in analyst")
-        yield  # This line is never reached, but makes it an async generator
-
-    with patch.object(analyst_agent, "stream_async", side_effect=failing_invoke):
-        # First execution - should fail at analyst
-        result = await swarm.invoke_async("Research AI trends and create a brief report")
-        try:
-            assert result.status == Status.FAILED
-        except Exception as e:
-            assert "Simulated failure in analyst" in str(e)
-
-    # Verify partial execution was persisted
-    persisted_state = session_manager.read_multi_agent(session_id, swarm.id)
-    assert persisted_state is not None
-    assert persisted_state["type"] == "swarm"
-    assert persisted_state["status"] == "failed"
-    assert len(persisted_state["node_history"]) == 1  # At least researcher executed
+def test_swarm_resume_from_executing_state(tmpdir, exit_hook, verify_hook):
+    """Test swarm resuming from EXECUTING state using BeforeNodeCallEvent hook."""
+    session_id = f"swarm_resume_{uuid4()}"
 
-    # Track execution count before resume
-    initial_execution_count = len(persisted_state["node_history"])
+    # First execution - exit before second node
+    session_manager = FileSessionManager(session_id=session_id, storage_dir=tmpdir)
+    researcher = Agent(name="researcher", system_prompt="you are a researcher.")
+    analyst = Agent(name="analyst", system_prompt="you are an analyst.")
+    writer = Agent(name="writer", system_prompt="you are a writer.")
 
-    # Execute swarm again - should automatically resume from saved state
-    result = await swarm.invoke_async("Research AI trends and create a brief report")
+    swarm = Swarm([researcher, analyst, writer], session_manager=session_manager, hooks=[exit_hook])
 
-    # Verify successful completion
-    assert result.status == Status.COMPLETED
-    assert len(result.results) > 0
+    try:
+        swarm("write AI trends and calculate growth in 100 words")
+    except SystemExit as e:
+        assert "Controlled exit before analyst" in str(e)
 
-    assert len(result.node_history) >= initial_execution_count + 1
+    # Verify state was persisted with EXECUTING status and next node
+    persisted_state = session_manager.read_multi_agent(session_id, swarm.id)
+    assert persisted_state["status"] == "executing"
+    assert len(persisted_state["node_history"]) == 1
+    assert persisted_state["node_history"][0] == "researcher"
+    assert persisted_state["next_nodes_to_execute"] == ["analyst"]
 
-    node_names = [node.node_id for node in result.node_history]
-    assert "researcher" in node_names
-    # Either analyst or writer (or both) should have executed to complete the task
-    assert "analyst" in node_names or "writer" in node_names
+    persisted_state = session_manager.read_multi_agent(session_id, swarm.id)
+    print(f"Saved session state: {persisted_state}")
 
-    # Clean up
-    session_manager.delete_session(session_id)
+    # Create fresh agent instances for the second swarm to avoid tool conflicts
+    exit_hook.should_exit = False
+    researcher2 = Agent(name="researcher", system_prompt="you are a researcher.")
+    analyst2 = Agent(name="analyst", system_prompt="you are an analyst.")
+    writer2 = Agent(name="writer", system_prompt="you are a writer.")
+    new_swarm = Swarm([researcher2, analyst2, writer2], session_manager=session_manager, hooks=[verify_hook])
+    result = new_swarm("write AI trends and calculate growth in 100 words")
+
+    # Verify swarm behavior - should resume from analyst, not restart
+    assert result.status.value == "completed"
+    assert verify_hook.first_node == "analyst"
+    node_ids = [n.node_id for n in result.node_history]
+    assert "analyst" in node_ids