diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index fa4f7051f..b7633d5e8 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -46,6 +46,7 @@ HookRegistry, MessageAddedEvent, ) +from ..interrupt import _InterruptState from ..models.bedrock import BedrockModel from ..models.model import Model from ..session.session_manager import SessionManager @@ -60,7 +61,6 @@ from ..types.agent import AgentInput from ..types.content import ContentBlock, Message, Messages, SystemContentBlock from ..types.exceptions import ContextWindowOverflowException -from ..types.interrupt import InterruptResponseContent from ..types.tools import ToolResult, ToolUse from ..types.traces import AttributeValue from .agent_result import AgentResult @@ -68,7 +68,6 @@ ConversationManager, SlidingWindowConversationManager, ) -from .interrupt import InterruptState from .state import AgentState logger = logging.getLogger(__name__) @@ -352,7 +351,7 @@ def __init__( self.hooks = HookRegistry() - self._interrupt_state = InterruptState() + self._interrupt_state = _InterruptState() # Initialize session management functionality self._session_manager = session_manager @@ -640,7 +639,7 @@ async def stream_async( yield event["data"] ``` """ - self._resume_interrupt(prompt) + self._interrupt_state.resume(prompt) merged_state = {} if kwargs: @@ -683,38 +682,6 @@ async def stream_async( self._end_agent_trace_span(error=e) raise - def _resume_interrupt(self, prompt: AgentInput) -> None: - """Configure the interrupt state if resuming from an interrupt event. - - Args: - prompt: User responses if resuming from interrupt. - - Raises: - TypeError: If in interrupt state but user did not provide responses. - """ - if not self._interrupt_state.activated: - return - - if not isinstance(prompt, list): - raise TypeError(f"prompt_type={type(prompt)} | must resume from interrupt with list of interruptResponse's") - - invalid_types = [ - content_type for content in prompt for content_type in content if content_type != "interruptResponse" - ] - if invalid_types: - raise TypeError( - f"content_types=<{invalid_types}> | must resume from interrupt with list of interruptResponse's" - ) - - for content in cast(list[InterruptResponseContent], prompt): - interrupt_id = content["interruptResponse"]["interruptId"] - interrupt_response = content["interruptResponse"]["response"] - - if interrupt_id not in self._interrupt_state.interrupts: - raise KeyError(f"interrupt_id=<{interrupt_id}> | no interrupt found") - - self._interrupt_state.interrupts[interrupt_id].response = interrupt_response - async def _run_loop( self, messages: Messages, diff --git a/src/strands/agent/interrupt.py b/src/strands/agent/interrupt.py deleted file mode 100644 index 3cec1541b..000000000 --- a/src/strands/agent/interrupt.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Track the state of interrupt events raised by the user for human-in-the-loop workflows.""" - -from dataclasses import asdict, dataclass, field -from typing import Any - -from ..interrupt import Interrupt - - -@dataclass -class InterruptState: - """Track the state of interrupt events raised by the user. - - Note, interrupt state is cleared after resuming. - - Attributes: - interrupts: Interrupts raised by the user. - context: Additional context associated with an interrupt event. - activated: True if agent is in an interrupt state, False otherwise. - """ - - interrupts: dict[str, Interrupt] = field(default_factory=dict) - context: dict[str, Any] = field(default_factory=dict) - activated: bool = False - - def activate(self, context: dict[str, Any] | None = None) -> None: - """Activate the interrupt state. - - Args: - context: Context associated with the interrupt event. - """ - self.context = context or {} - self.activated = True - - def deactivate(self) -> None: - """Deacitvate the interrupt state. - - Interrupts and context are cleared. - """ - self.interrupts = {} - self.context = {} - self.activated = False - - def to_dict(self) -> dict[str, Any]: - """Serialize to dict for session management.""" - return asdict(self) - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> "InterruptState": - """Initiailize interrupt state from serialized interrupt state. - - Interrupt state can be serialized with the `to_dict` method. - """ - return cls( - interrupts={ - interrupt_id: Interrupt(**interrupt_data) for interrupt_id, interrupt_data in data["interrupts"].items() - }, - context=data["context"], - activated=data["activated"], - ) diff --git a/src/strands/interrupt.py b/src/strands/interrupt.py index f0ed52389..919927e1a 100644 --- a/src/strands/interrupt.py +++ b/src/strands/interrupt.py @@ -1,7 +1,11 @@ """Human-in-the-loop interrupt system for agent workflows.""" -from dataclasses import asdict, dataclass -from typing import Any +from dataclasses import asdict, dataclass, field +from typing import TYPE_CHECKING, Any, cast + +if TYPE_CHECKING: + from .types.agent import AgentInput + from .types.interrupt import InterruptResponseContent @dataclass @@ -31,3 +35,89 @@ class InterruptException(Exception): def __init__(self, interrupt: Interrupt) -> None: """Set the interrupt.""" self.interrupt = interrupt + + +@dataclass +class _InterruptState: + """Track the state of interrupt events raised by the user. + + Note, interrupt state is cleared after resuming. + + Attributes: + interrupts: Interrupts raised by the user. + context: Additional context associated with an interrupt event. + activated: True if agent is in an interrupt state, False otherwise. + """ + + interrupts: dict[str, Interrupt] = field(default_factory=dict) + context: dict[str, Any] = field(default_factory=dict) + activated: bool = False + + def activate(self, context: dict[str, Any] | None = None) -> None: + """Activate the interrupt state. + + Args: + context: Context associated with the interrupt event. + """ + self.context = context or {} + self.activated = True + + def deactivate(self) -> None: + """Deacitvate the interrupt state. + + Interrupts and context are cleared. + """ + self.interrupts = {} + self.context = {} + self.activated = False + + def resume(self, prompt: "AgentInput") -> None: + """Configure the interrupt state if resuming from an interrupt event. + + Args: + prompt: User responses if resuming from interrupt. + + Raises: + TypeError: If in interrupt state but user did not provide responses. + """ + if not self.activated: + return + + if not isinstance(prompt, list): + raise TypeError(f"prompt_type={type(prompt)} | must resume from interrupt with list of interruptResponse's") + + invalid_types = [ + content_type for content in prompt for content_type in content if content_type != "interruptResponse" + ] + if invalid_types: + raise TypeError( + f"content_types=<{invalid_types}> | must resume from interrupt with list of interruptResponse's" + ) + + contents = cast(list["InterruptResponseContent"], prompt) + for content in contents: + interrupt_id = content["interruptResponse"]["interruptId"] + interrupt_response = content["interruptResponse"]["response"] + + if interrupt_id not in self.interrupts: + raise KeyError(f"interrupt_id=<{interrupt_id}> | no interrupt found") + + self.interrupts[interrupt_id].response = interrupt_response + + def to_dict(self) -> dict[str, Any]: + """Serialize to dict for session management.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "_InterruptState": + """Initiailize interrupt state from serialized interrupt state. + + Interrupt state can be serialized with the `to_dict` method. + """ + return cls( + interrupts={ + interrupt_id: Interrupt(**interrupt_data) for interrupt_id, interrupt_data in data["interrupts"].items() + }, + context=data["context"], + activated=data["activated"], + ) diff --git a/src/strands/types/session.py b/src/strands/types/session.py index 4e72a1468..8b78ab448 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -7,7 +7,7 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Optional -from ..agent.interrupt import InterruptState +from ..interrupt import _InterruptState from .content import Message if TYPE_CHECKING: @@ -148,7 +148,7 @@ def to_dict(self) -> dict[str, Any]: def initialize_internal_state(self, agent: "Agent") -> None: """Initialize internal state of agent.""" if "interrupt_state" in self._internal_state: - agent._interrupt_state = InterruptState.from_dict(self._internal_state["interrupt_state"]) + agent._interrupt_state = _InterruptState.from_dict(self._internal_state["interrupt_state"]) @dataclass diff --git a/tests/strands/agent/test_interrupt.py b/tests/strands/agent/test_interrupt.py deleted file mode 100644 index e248c29a6..000000000 --- a/tests/strands/agent/test_interrupt.py +++ /dev/null @@ -1,61 +0,0 @@ -import pytest - -from strands.agent.interrupt import InterruptState -from strands.interrupt import Interrupt - - -@pytest.fixture -def interrupt(): - return Interrupt(id="test_id", name="test_name", reason="test reason") - - -def test_interrupt_activate(): - interrupt_state = InterruptState() - - interrupt_state.activate(context={"test": "context"}) - - assert interrupt_state.activated - - tru_context = interrupt_state.context - exp_context = {"test": "context"} - assert tru_context == exp_context - - -def test_interrupt_deactivate(): - interrupt_state = InterruptState(context={"test": "context"}, activated=True) - - interrupt_state.deactivate() - - assert not interrupt_state.activated - - tru_context = interrupt_state.context - exp_context = {} - assert tru_context == exp_context - - -def test_interrupt_state_to_dict(interrupt): - interrupt_state = InterruptState(interrupts={"test_id": interrupt}, context={"test": "context"}, activated=True) - - tru_data = interrupt_state.to_dict() - exp_data = { - "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, - "context": {"test": "context"}, - "activated": True, - } - assert tru_data == exp_data - - -def test_interrupt_state_from_dict(): - data = { - "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, - "context": {"test": "context"}, - "activated": True, - } - - tru_state = InterruptState.from_dict(data) - exp_state = InterruptState( - interrupts={"test_id": Interrupt(id="test_id", name="test_name", reason="test reason")}, - context={"test": "context"}, - activated=True, - ) - assert tru_state == exp_state diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 09bacbcb0..9335f91a8 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -6,7 +6,6 @@ import strands import strands.telemetry -from strands.agent.interrupt import InterruptState from strands.hooks import ( AfterModelCallEvent, BeforeModelCallEvent, @@ -14,7 +13,7 @@ HookRegistry, MessageAddedEvent, ) -from strands.interrupt import Interrupt +from strands.interrupt import Interrupt, _InterruptState from strands.telemetry.metrics import EventLoopMetrics from strands.tools.executors import SequentialToolExecutor from strands.tools.registry import ToolRegistry @@ -143,7 +142,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis mock.event_loop_metrics = EventLoopMetrics() mock.hooks = hook_registry mock.tool_executor = tool_executor - mock._interrupt_state = InterruptState() + mock._interrupt_state = _InterruptState() return mock diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py index 81c3bf2d3..3daf41734 100644 --- a/tests/strands/hooks/test_registry.py +++ b/tests/strands/hooks/test_registry.py @@ -2,9 +2,8 @@ import pytest -from strands.agent.interrupt import InterruptState from strands.hooks import AgentInitializedEvent, BeforeInvocationEvent, BeforeToolCallEvent, HookRegistry -from strands.interrupt import Interrupt +from strands.interrupt import Interrupt, _InterruptState @pytest.fixture @@ -15,7 +14,7 @@ def registry(): @pytest.fixture def agent(): instance = unittest.mock.Mock() - instance._interrupt_state = InterruptState() + instance._interrupt_state = _InterruptState() return instance diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index ed0ec9072..451d0dd09 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -7,7 +7,7 @@ from strands.agent.agent import Agent from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager -from strands.agent.interrupt import InterruptState +from strands.interrupt import _InterruptState from strands.session.repository_session_manager import RepositorySessionManager from strands.types.content import ContentBlock from strands.types.exceptions import SessionException @@ -131,7 +131,7 @@ def test_initialize_restores_existing_agent(session_manager, agent): assert len(agent.messages) == 1 assert agent.messages[0]["role"] == "user" assert agent.messages[0]["content"][0]["text"] == "Hello" - assert agent._interrupt_state == InterruptState(interrupts={}, context={"test": "init"}, activated=False) + assert agent._interrupt_state == _InterruptState(interrupts={}, context={"test": "init"}, activated=False) def test_initialize_restores_existing_agent_with_summarizing_conversation_manager(session_manager): diff --git a/tests/strands/test_interrupt.py b/tests/strands/test_interrupt.py index 8ce972103..a45d524e4 100644 --- a/tests/strands/test_interrupt.py +++ b/tests/strands/test_interrupt.py @@ -1,6 +1,6 @@ import pytest -from strands.interrupt import Interrupt +from strands.interrupt import Interrupt, _InterruptState @pytest.fixture @@ -22,3 +22,109 @@ def test_interrupt_to_dict(interrupt): "response": {"response": "test"}, } assert tru_dict == exp_dict + + +def test_interrupt_state_activate(): + interrupt_state = _InterruptState() + + interrupt_state.activate(context={"test": "context"}) + + assert interrupt_state.activated + + tru_context = interrupt_state.context + exp_context = {"test": "context"} + assert tru_context == exp_context + + +def test_interrupt_state_deactivate(): + interrupt_state = _InterruptState(context={"test": "context"}, activated=True) + + interrupt_state.deactivate() + + assert not interrupt_state.activated + + tru_context = interrupt_state.context + exp_context = {} + assert tru_context == exp_context + + +def test_interrupt_state_to_dict(): + interrupt_state = _InterruptState( + interrupts={"test_id": Interrupt(id="test_id", name="test_name", reason="test reason")}, + context={"test": "context"}, + activated=True, + ) + + tru_data = interrupt_state.to_dict() + exp_data = { + "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, + "context": {"test": "context"}, + "activated": True, + } + assert tru_data == exp_data + + +def test_interrupt_state_from_dict(): + data = { + "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, + "context": {"test": "context"}, + "activated": True, + } + + tru_state = _InterruptState.from_dict(data) + exp_state = _InterruptState( + interrupts={"test_id": Interrupt(id="test_id", name="test_name", reason="test reason")}, + context={"test": "context"}, + activated=True, + ) + assert tru_state == exp_state + + +def test_interrupt_state_resume(): + interrupt_state = _InterruptState( + interrupts={"test_id": Interrupt(id="test_id", name="test_name", reason="test reason")}, + activated=True, + ) + + prompt = [ + { + "interruptResponse": { + "interruptId": "test_id", + "response": "test response", + } + } + ] + interrupt_state.resume(prompt) + + tru_response = interrupt_state.interrupts["test_id"].response + exp_response = "test response" + assert tru_response == exp_response + + +def test_interrupt_state_resumse_deactivated(): + interrupt_state = _InterruptState(activated=False) + interrupt_state.resume([]) + + +def test_interrupt_state_resume_invalid_prompt(): + interrupt_state = _InterruptState(activated=True) + + exp_message = r"prompt_type= \| must resume from interrupt with list of interruptResponse's" + with pytest.raises(TypeError, match=exp_message): + interrupt_state.resume("invalid") + + +def test_interrupt_state_resume_invalid_content(): + interrupt_state = _InterruptState(activated=True) + + exp_message = r"content_types=<\['text'\]> \| must resume from interrupt with list of interruptResponse's" + with pytest.raises(TypeError, match=exp_message): + interrupt_state.resume([{"text": "invalid"}]) + + +def test_interrupt_resume_invalid_id(): + interrupt_state = _InterruptState(activated=True) + + exp_message = r"interrupt_id= \| no interrupt found" + with pytest.raises(KeyError, match=exp_message): + interrupt_state.resume([{"interruptResponse": {"interruptId": "invalid", "response": None}}]) diff --git a/tests/strands/tools/executors/conftest.py b/tests/strands/tools/executors/conftest.py index d25cf14bd..4d299a539 100644 --- a/tests/strands/tools/executors/conftest.py +++ b/tests/strands/tools/executors/conftest.py @@ -4,8 +4,8 @@ import pytest import strands -from strands.agent.interrupt import InterruptState from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent, HookRegistry +from strands.interrupt import _InterruptState from strands.tools.registry import ToolRegistry from strands.types.tools import ToolContext @@ -104,7 +104,7 @@ def agent(tool_registry, hook_registry): mock_agent = unittest.mock.Mock() mock_agent.tool_registry = tool_registry mock_agent.hooks = hook_registry - mock_agent._interrupt_state = InterruptState() + mock_agent._interrupt_state = _InterruptState() return mock_agent diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index f89f1c945..1ab516006 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -10,8 +10,7 @@ import strands from strands import Agent -from strands.agent.interrupt import InterruptState -from strands.interrupt import Interrupt +from strands.interrupt import Interrupt, _InterruptState from strands.types._events import ToolInterruptEvent, ToolResultEvent, ToolStreamEvent from strands.types.tools import AgentTool, ToolContext, ToolUse @@ -151,7 +150,7 @@ async def test_stream_interrupt(alist): tool_use = {"toolUseId": "test_tool_id"} mock_agent = MagicMock() - mock_agent._interrupt_state = InterruptState() + mock_agent._interrupt_state = _InterruptState() invocation_state = {"agent": mock_agent} @@ -178,7 +177,7 @@ async def test_stream_interrupt_resume(alist): tool_use = {"toolUseId": "test_tool_id"} mock_agent = MagicMock() - mock_agent._interrupt_state = InterruptState(interrupts={interrupt.id: interrupt}) + mock_agent._interrupt_state = _InterruptState(interrupts={interrupt.id: interrupt}) invocation_state = {"agent": mock_agent} diff --git a/tests/strands/types/test_interrupt.py b/tests/strands/types/test_interrupt.py index ade0fa5e8..ad31384b6 100644 --- a/tests/strands/types/test_interrupt.py +++ b/tests/strands/types/test_interrupt.py @@ -2,8 +2,7 @@ import pytest -from strands.agent.interrupt import InterruptState -from strands.interrupt import Interrupt, InterruptException +from strands.interrupt import Interrupt, InterruptException, _InterruptState from strands.types.interrupt import _Interruptible @@ -20,7 +19,7 @@ def interrupt(): @pytest.fixture def agent(): instance = unittest.mock.Mock() - instance._interrupt_state = InterruptState() + instance._interrupt_state = _InterruptState() return instance diff --git a/tests/strands/types/test_session.py b/tests/strands/types/test_session.py index 26d4062e4..3e5360742 100644 --- a/tests/strands/types/test_session.py +++ b/tests/strands/types/test_session.py @@ -3,8 +3,8 @@ from uuid import uuid4 from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager -from strands.agent.interrupt import InterruptState from strands.agent.state import AgentState +from strands.interrupt import _InterruptState from strands.types.session import ( Session, SessionAgent, @@ -101,7 +101,7 @@ def test_session_agent_from_agent(): agent.agent_id = "a1" agent.conversation_manager = unittest.mock.Mock(get_state=lambda: {"test": "conversation"}) agent.state = AgentState({"test": "state"}) - agent._interrupt_state = InterruptState(interrupts={}, context={}, activated=False) + agent._interrupt_state = _InterruptState(interrupts={}, context={}, activated=False) tru_session_agent = SessionAgent.from_agent(agent) exp_session_agent = SessionAgent( @@ -127,5 +127,5 @@ def test_session_agent_initialize_internal_state(): session_agent.initialize_internal_state(agent) tru_interrupt_state = agent._interrupt_state - exp_interrupt_state = InterruptState(interrupts={}, context={"test": "init"}, activated=False) + exp_interrupt_state = _InterruptState(interrupts={}, context={"test": "init"}, activated=False) assert tru_interrupt_state == exp_interrupt_state