Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 73 additions & 16 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
import logging
import time
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Callable, Optional, Tuple, cast
from typing import Any, AsyncIterator, Callable, Optional, Tuple, Type, cast

from opentelemetry import trace as trace_api
from pydantic import BaseModel

from .._async import run_async
from ..agent import Agent
Expand Down Expand Up @@ -456,23 +457,34 @@ def __init__(
run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))

def __call__(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
self,
task: str | list[ContentBlock],
invocation_state: dict[str, Any] | None = None,
structured_output_model: Type[BaseModel] | None = None,
**kwargs: Any,
) -> GraphResult:
"""Invoke the graph synchronously.
Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues.
structured_output_model: Pydantic model type for structured output.
**kwargs: Keyword arguments allowing backward compatible future changes.
"""
if invocation_state is None:
invocation_state = {}

return run_async(lambda: self.invoke_async(task, invocation_state))
return run_async(
lambda: self.invoke_async(task, invocation_state, structured_output_model=structured_output_model)
)

async def invoke_async(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
self,
task: str | list[ContentBlock],
invocation_state: dict[str, Any] | None = None,
structured_output_model: Type[BaseModel] | None = None,
**kwargs: Any,
) -> GraphResult:
"""Invoke the graph asynchronously.
Expand All @@ -483,9 +495,10 @@ async def invoke_async(
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues.
structured_output_model: Pydantic model type for structured output.
**kwargs: Keyword arguments allowing backward compatible future changes.
"""
events = self.stream_async(task, invocation_state, **kwargs)
events = self.stream_async(task, invocation_state, structured_output_model=structured_output_model, **kwargs)
final_event = None
async for event in events:
final_event = event
Expand All @@ -496,14 +509,19 @@ async def invoke_async(
return cast(GraphResult, final_event["result"])

async def stream_async(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
self,
task: str | list[ContentBlock],
invocation_state: dict[str, Any] | None = None,
structured_output_model: Type[BaseModel] | None = None,
**kwargs: Any,
) -> AsyncIterator[dict[str, Any]]:
"""Stream events during graph execution.
Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues.
structured_output_model: Pydantic model type for structured output.
**kwargs: Keyword arguments allowing backward compatible future changes.
Yields:
Expand Down Expand Up @@ -546,7 +564,9 @@ async def stream_async(
self.node_timeout or "None",
)

async for event in self._execute_graph(invocation_state):
async for event in self._execute_graph(
invocation_state, structured_output_model=structured_output_model
):
yield event.as_dict()

# Set final status based on execution results
Expand Down Expand Up @@ -585,7 +605,9 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
# Validate Agent-specific constraints for each node
_validate_node_executor(node.executor)

async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
async def _execute_graph(
self, invocation_state: dict[str, Any], structured_output_model: Type[BaseModel] | None = None
) -> AsyncIterator[Any]:
"""Execute graph and yield TypedEvent objects."""
ready_nodes = self._resume_next_nodes if self._resume_from_session else list(self.entry_points)

Expand All @@ -604,7 +626,9 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato
ready_nodes.clear()

# Execute current batch
async for event in self._execute_nodes_parallel(current_batch, invocation_state):
async for event in self._execute_nodes_parallel(
current_batch, invocation_state, structured_output_model=structured_output_model
):
yield event

# Find newly ready nodes after batch execution
Expand All @@ -628,7 +652,10 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato
ready_nodes.extend(newly_ready)

async def _execute_nodes_parallel(
self, nodes: list["GraphNode"], invocation_state: dict[str, Any]
self,
nodes: list["GraphNode"],
invocation_state: dict[str, Any],
structured_output_model: Type[BaseModel] | None = None,
) -> AsyncIterator[Any]:
"""Execute multiple nodes in parallel and merge their event streams in real-time.
Expand All @@ -638,7 +665,14 @@ async def _execute_nodes_parallel(
event_queue: asyncio.Queue[Any | None | Exception] = asyncio.Queue()

# Start all node streams as independent tasks
tasks = [asyncio.create_task(self._stream_node_to_queue(node, event_queue, invocation_state)) for node in nodes]
tasks = [
asyncio.create_task(
self._stream_node_to_queue(
node, event_queue, invocation_state, structured_output_model=structured_output_model
)
)
for node in nodes
]

try:
# Consume events from the queue as they arrive
Expand Down Expand Up @@ -689,14 +723,17 @@ async def _stream_node_to_queue(
node: GraphNode,
event_queue: asyncio.Queue[Any | None | Exception],
invocation_state: dict[str, Any],
structured_output_model: Type[BaseModel] | None = None,
) -> None:
"""Stream events from a node to the shared queue with optional timeout."""
try:
# Apply timeout to the entire streaming process if configured
if self.node_timeout is not None:

async def stream_node() -> None:
async for event in self._execute_node(node, invocation_state):
async for event in self._execute_node(
node, invocation_state, structured_output_model=structured_output_model
):
await event_queue.put(event)

try:
Expand All @@ -707,7 +744,9 @@ async def stream_node() -> None:
await event_queue.put(timeout_exc)
else:
# No timeout - stream normally
async for event in self._execute_node(node, invocation_state):
async for event in self._execute_node(
node, invocation_state, structured_output_model=structured_output_model
):
await event_queue.put(event)
except Exception as e:
# Send exception through queue for fail-fast behavior
Expand Down Expand Up @@ -774,7 +813,12 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
)
return False

async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
async def _execute_node(
self,
node: GraphNode,
invocation_state: dict[str, Any],
structured_output_model: Type[BaseModel] | None = None,
) -> AsyncIterator[Any]:
"""Execute a single node and yield TypedEvent objects."""
await self.hooks.invoke_callbacks_async(BeforeNodeCallEvent(self, node.node_id, invocation_state))

Expand Down Expand Up @@ -802,7 +846,9 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
if isinstance(node.executor, MultiAgentBase):
# For nested multi-agent systems, stream their events and collect result
multi_agent_result = None
async for event in node.executor.stream_async(node_input, invocation_state):
async for event in node.executor.stream_async(
node_input, invocation_state, structured_output_model=structured_output_model
):
# Forward nested multi-agent events with node context
wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event)
yield wrapped_event
Expand All @@ -824,9 +870,20 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
)

elif isinstance(node.executor, Agent):
# For agents, use agent's default structured_output_model if available,
# otherwise use the graph-level one
agent_structured_output_model = structured_output_model
if (
hasattr(node.executor, "_default_structured_output_model")
and node.executor._default_structured_output_model is not None
):
agent_structured_output_model = node.executor._default_structured_output_model

# For agents, stream their events and collect result
agent_response = None
async for event in node.executor.stream_async(node_input, invocation_state=invocation_state):
async for event in node.executor.stream_async(
node_input, invocation_state=invocation_state, structured_output_model=agent_structured_output_model
):
# Forward agent events with node context
wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event)
yield wrapped_event
Expand Down
Loading