Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 3 additions & 36 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
HookRegistry,
MessageAddedEvent,
)
from ..interrupt import _InterruptState
from ..models.bedrock import BedrockModel
from ..models.model import Model
from ..session.session_manager import SessionManager
Expand All @@ -60,15 +61,13 @@
from ..types.agent import AgentInput
from ..types.content import ContentBlock, Message, Messages, SystemContentBlock
from ..types.exceptions import ContextWindowOverflowException
from ..types.interrupt import InterruptResponseContent
from ..types.tools import ToolResult, ToolUse
from ..types.traces import AttributeValue
from .agent_result import AgentResult
from .conversation_manager import (
ConversationManager,
SlidingWindowConversationManager,
)
from .interrupt import InterruptState
from .state import AgentState

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -352,7 +351,7 @@ def __init__(

self.hooks = HookRegistry()

self._interrupt_state = InterruptState()
self._interrupt_state = _InterruptState()

# Initialize session management functionality
self._session_manager = session_manager
Expand Down Expand Up @@ -640,7 +639,7 @@ async def stream_async(
yield event["data"]
```
"""
self._resume_interrupt(prompt)
self._interrupt_state.resume(prompt)

merged_state = {}
if kwargs:
Expand Down Expand Up @@ -683,38 +682,6 @@ async def stream_async(
self._end_agent_trace_span(error=e)
raise

def _resume_interrupt(self, prompt: AgentInput) -> None:
"""Configure the interrupt state if resuming from an interrupt event.

Args:
prompt: User responses if resuming from interrupt.

Raises:
TypeError: If in interrupt state but user did not provide responses.
"""
if not self._interrupt_state.activated:
return

if not isinstance(prompt, list):
raise TypeError(f"prompt_type={type(prompt)} | must resume from interrupt with list of interruptResponse's")

invalid_types = [
content_type for content in prompt for content_type in content if content_type != "interruptResponse"
]
if invalid_types:
raise TypeError(
f"content_types=<{invalid_types}> | must resume from interrupt with list of interruptResponse's"
)

for content in cast(list[InterruptResponseContent], prompt):
interrupt_id = content["interruptResponse"]["interruptId"]
interrupt_response = content["interruptResponse"]["response"]

if interrupt_id not in self._interrupt_state.interrupts:
raise KeyError(f"interrupt_id=<{interrupt_id}> | no interrupt found")

self._interrupt_state.interrupts[interrupt_id].response = interrupt_response

async def _run_loop(
self,
messages: Messages,
Expand Down
59 changes: 0 additions & 59 deletions src/strands/agent/interrupt.py

This file was deleted.

94 changes: 92 additions & 2 deletions src/strands/interrupt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
"""Human-in-the-loop interrupt system for agent workflows."""

from dataclasses import asdict, dataclass
from typing import Any
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any, cast

if TYPE_CHECKING:
from .types.agent import AgentInput
from .types.interrupt import InterruptResponseContent


@dataclass
Expand Down Expand Up @@ -31,3 +35,89 @@ class InterruptException(Exception):
def __init__(self, interrupt: Interrupt) -> None:
"""Set the interrupt."""
self.interrupt = interrupt


@dataclass
class _InterruptState:
"""Track the state of interrupt events raised by the user.

Note, interrupt state is cleared after resuming.

Attributes:
interrupts: Interrupts raised by the user.
context: Additional context associated with an interrupt event.
activated: True if agent is in an interrupt state, False otherwise.
"""

interrupts: dict[str, Interrupt] = field(default_factory=dict)
context: dict[str, Any] = field(default_factory=dict)
activated: bool = False

def activate(self, context: dict[str, Any] | None = None) -> None:
"""Activate the interrupt state.

Args:
context: Context associated with the interrupt event.
"""
self.context = context or {}
self.activated = True

def deactivate(self) -> None:
"""Deacitvate the interrupt state.

Interrupts and context are cleared.
"""
self.interrupts = {}
self.context = {}
self.activated = False

def resume(self, prompt: "AgentInput") -> None:
"""Configure the interrupt state if resuming from an interrupt event.

Args:
prompt: User responses if resuming from interrupt.

Raises:
TypeError: If in interrupt state but user did not provide responses.
"""
if not self.activated:
return

if not isinstance(prompt, list):
raise TypeError(f"prompt_type={type(prompt)} | must resume from interrupt with list of interruptResponse's")

invalid_types = [
content_type for content in prompt for content_type in content if content_type != "interruptResponse"
]
if invalid_types:
raise TypeError(
f"content_types=<{invalid_types}> | must resume from interrupt with list of interruptResponse's"
)

contents = cast(list["InterruptResponseContent"], prompt)
for content in contents:
interrupt_id = content["interruptResponse"]["interruptId"]
interrupt_response = content["interruptResponse"]["response"]

if interrupt_id not in self.interrupts:
raise KeyError(f"interrupt_id=<{interrupt_id}> | no interrupt found")

self.interrupts[interrupt_id].response = interrupt_response

def to_dict(self) -> dict[str, Any]:
"""Serialize to dict for session management."""
return asdict(self)

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "_InterruptState":
"""Initiailize interrupt state from serialized interrupt state.

Interrupt state can be serialized with the `to_dict` method.
"""
return cls(
interrupts={
interrupt_id: Interrupt(**interrupt_data) for interrupt_id, interrupt_data in data["interrupts"].items()
},
context=data["context"],
activated=data["activated"],
)
4 changes: 2 additions & 2 deletions src/strands/types/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional

from ..agent.interrupt import InterruptState
from ..interrupt import _InterruptState
from .content import Message

if TYPE_CHECKING:
Expand Down Expand Up @@ -148,7 +148,7 @@ def to_dict(self) -> dict[str, Any]:
def initialize_internal_state(self, agent: "Agent") -> None:
"""Initialize internal state of agent."""
if "interrupt_state" in self._internal_state:
agent._interrupt_state = InterruptState.from_dict(self._internal_state["interrupt_state"])
agent._interrupt_state = _InterruptState.from_dict(self._internal_state["interrupt_state"])


@dataclass
Expand Down
61 changes: 0 additions & 61 deletions tests/strands/agent/test_interrupt.py

This file was deleted.

5 changes: 2 additions & 3 deletions tests/strands/event_loop/test_event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@

import strands
import strands.telemetry
from strands.agent.interrupt import InterruptState
from strands.hooks import (
AfterModelCallEvent,
BeforeModelCallEvent,
BeforeToolCallEvent,
HookRegistry,
MessageAddedEvent,
)
from strands.interrupt import Interrupt
from strands.interrupt import Interrupt, _InterruptState
from strands.telemetry.metrics import EventLoopMetrics
from strands.tools.executors import SequentialToolExecutor
from strands.tools.registry import ToolRegistry
Expand Down Expand Up @@ -143,7 +142,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis
mock.event_loop_metrics = EventLoopMetrics()
mock.hooks = hook_registry
mock.tool_executor = tool_executor
mock._interrupt_state = InterruptState()
mock._interrupt_state = _InterruptState()

return mock

Expand Down
5 changes: 2 additions & 3 deletions tests/strands/hooks/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

import pytest

from strands.agent.interrupt import InterruptState
from strands.hooks import AgentInitializedEvent, BeforeInvocationEvent, BeforeToolCallEvent, HookRegistry
from strands.interrupt import Interrupt
from strands.interrupt import Interrupt, _InterruptState


@pytest.fixture
Expand All @@ -15,7 +14,7 @@ def registry():
@pytest.fixture
def agent():
instance = unittest.mock.Mock()
instance._interrupt_state = InterruptState()
instance._interrupt_state = _InterruptState()
return instance


Expand Down
4 changes: 2 additions & 2 deletions tests/strands/session/test_repository_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from strands.agent.agent import Agent
from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager
from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager
from strands.agent.interrupt import InterruptState
from strands.interrupt import _InterruptState
from strands.session.repository_session_manager import RepositorySessionManager
from strands.types.content import ContentBlock
from strands.types.exceptions import SessionException
Expand Down Expand Up @@ -131,7 +131,7 @@ def test_initialize_restores_existing_agent(session_manager, agent):
assert len(agent.messages) == 1
assert agent.messages[0]["role"] == "user"
assert agent.messages[0]["content"][0]["text"] == "Hello"
assert agent._interrupt_state == InterruptState(interrupts={}, context={"test": "init"}, activated=False)
assert agent._interrupt_state == _InterruptState(interrupts={}, context={"test": "init"}, activated=False)


def test_initialize_restores_existing_agent_with_summarizing_conversation_manager(session_manager):
Expand Down
Loading
Loading