diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 9f28876bf..c264533be 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -19,9 +19,10 @@ import logging import time from dataclasses import dataclass, field -from typing import Any, AsyncIterator, Callable, Optional, Tuple, cast +from typing import Any, AsyncIterator, Callable, Optional, Tuple, Type, cast from opentelemetry import trace as trace_api +from pydantic import BaseModel from .._async import run_async from ..agent import Agent @@ -456,7 +457,11 @@ def __init__( run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) def __call__( - self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + task: str | list[ContentBlock], + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> GraphResult: """Invoke the graph synchronously. @@ -464,15 +469,22 @@ def __call__( task: The task to execute invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues. + structured_output_model: Pydantic model type for structured output. **kwargs: Keyword arguments allowing backward compatible future changes. """ if invocation_state is None: invocation_state = {} - return run_async(lambda: self.invoke_async(task, invocation_state)) + return run_async( + lambda: self.invoke_async(task, invocation_state, structured_output_model=structured_output_model) + ) async def invoke_async( - self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + task: str | list[ContentBlock], + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> GraphResult: """Invoke the graph asynchronously. @@ -483,9 +495,10 @@ async def invoke_async( task: The task to execute invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues. + structured_output_model: Pydantic model type for structured output. **kwargs: Keyword arguments allowing backward compatible future changes. """ - events = self.stream_async(task, invocation_state, **kwargs) + events = self.stream_async(task, invocation_state, structured_output_model=structured_output_model, **kwargs) final_event = None async for event in events: final_event = event @@ -496,7 +509,11 @@ async def invoke_async( return cast(GraphResult, final_event["result"]) async def stream_async( - self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + task: str | list[ContentBlock], + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> AsyncIterator[dict[str, Any]]: """Stream events during graph execution. @@ -504,6 +521,7 @@ async def stream_async( task: The task to execute invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues. + structured_output_model: Pydantic model type for structured output. **kwargs: Keyword arguments allowing backward compatible future changes. Yields: @@ -546,7 +564,9 @@ async def stream_async( self.node_timeout or "None", ) - async for event in self._execute_graph(invocation_state): + async for event in self._execute_graph( + invocation_state, structured_output_model=structured_output_model + ): yield event.as_dict() # Set final status based on execution results @@ -585,7 +605,9 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: # Validate Agent-specific constraints for each node _validate_node_executor(node.executor) - async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: + async def _execute_graph( + self, invocation_state: dict[str, Any], structured_output_model: Type[BaseModel] | None = None + ) -> AsyncIterator[Any]: """Execute graph and yield TypedEvent objects.""" ready_nodes = self._resume_next_nodes if self._resume_from_session else list(self.entry_points) @@ -604,7 +626,9 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato ready_nodes.clear() # Execute current batch - async for event in self._execute_nodes_parallel(current_batch, invocation_state): + async for event in self._execute_nodes_parallel( + current_batch, invocation_state, structured_output_model=structured_output_model + ): yield event # Find newly ready nodes after batch execution @@ -628,7 +652,10 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato ready_nodes.extend(newly_ready) async def _execute_nodes_parallel( - self, nodes: list["GraphNode"], invocation_state: dict[str, Any] + self, + nodes: list["GraphNode"], + invocation_state: dict[str, Any], + structured_output_model: Type[BaseModel] | None = None, ) -> AsyncIterator[Any]: """Execute multiple nodes in parallel and merge their event streams in real-time. @@ -638,7 +665,14 @@ async def _execute_nodes_parallel( event_queue: asyncio.Queue[Any | None | Exception] = asyncio.Queue() # Start all node streams as independent tasks - tasks = [asyncio.create_task(self._stream_node_to_queue(node, event_queue, invocation_state)) for node in nodes] + tasks = [ + asyncio.create_task( + self._stream_node_to_queue( + node, event_queue, invocation_state, structured_output_model=structured_output_model + ) + ) + for node in nodes + ] try: # Consume events from the queue as they arrive @@ -689,6 +723,7 @@ async def _stream_node_to_queue( node: GraphNode, event_queue: asyncio.Queue[Any | None | Exception], invocation_state: dict[str, Any], + structured_output_model: Type[BaseModel] | None = None, ) -> None: """Stream events from a node to the shared queue with optional timeout.""" try: @@ -696,7 +731,9 @@ async def _stream_node_to_queue( if self.node_timeout is not None: async def stream_node() -> None: - async for event in self._execute_node(node, invocation_state): + async for event in self._execute_node( + node, invocation_state, structured_output_model=structured_output_model + ): await event_queue.put(event) try: @@ -707,7 +744,9 @@ async def stream_node() -> None: await event_queue.put(timeout_exc) else: # No timeout - stream normally - async for event in self._execute_node(node, invocation_state): + async for event in self._execute_node( + node, invocation_state, structured_output_model=structured_output_model + ): await event_queue.put(event) except Exception as e: # Send exception through queue for fail-fast behavior @@ -774,7 +813,12 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[ ) return False - async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: + async def _execute_node( + self, + node: GraphNode, + invocation_state: dict[str, Any], + structured_output_model: Type[BaseModel] | None = None, + ) -> AsyncIterator[Any]: """Execute a single node and yield TypedEvent objects.""" await self.hooks.invoke_callbacks_async(BeforeNodeCallEvent(self, node.node_id, invocation_state)) @@ -802,7 +846,9 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) if isinstance(node.executor, MultiAgentBase): # For nested multi-agent systems, stream their events and collect result multi_agent_result = None - async for event in node.executor.stream_async(node_input, invocation_state): + async for event in node.executor.stream_async( + node_input, invocation_state, structured_output_model=structured_output_model + ): # Forward nested multi-agent events with node context wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) yield wrapped_event @@ -824,9 +870,20 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) ) elif isinstance(node.executor, Agent): + # For agents, use agent's default structured_output_model if available, + # otherwise use the graph-level one + agent_structured_output_model = structured_output_model + if ( + hasattr(node.executor, "_default_structured_output_model") + and node.executor._default_structured_output_model is not None + ): + agent_structured_output_model = node.executor._default_structured_output_model + # For agents, stream their events and collect result agent_response = None - async for event in node.executor.stream_async(node_input, invocation_state=invocation_state): + async for event in node.executor.stream_async( + node_input, invocation_state=invocation_state, structured_output_model=agent_structured_output_model + ): # Forward agent events with node context wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) yield wrapped_event diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index b32356cb4..215fa9038 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -2033,3 +2033,184 @@ async def test_graph_persisted(mock_strands_tracer, mock_use_span): assert final_state["status"] == "completed" assert len(final_state["completed_nodes"]) == 1 assert "test_node" in final_state["node_results"] + + +class TestGraphStructuredOutput: + """Test Graph structured output functionality.""" + + @pytest.mark.asyncio + async def test_graph_passes_structured_output_model_to_agents(self, mock_strands_tracer, mock_use_span): + """Test that Graph passes structured_output_model to agent nodes.""" + from pydantic import BaseModel, Field + + class SimpleOutput(BaseModel): + """Simple output structure for testing.""" + + message: str = Field(description="A simple message") + number: int = Field(description="A number between 1 and 10") + + # Create a mock agent that captures structured_output_model + captured_structured_output_model = None + + async def mock_stream_async(*args, **kwargs): + nonlocal captured_structured_output_model + # structured_output_model is passed as a keyword argument + captured_structured_output_model = kwargs.get("structured_output_model") + + # Create a result with structured output + structured_result = SimpleOutput(message="Hello", number=7) + mock_result = AgentResult( + message={"role": "assistant", "content": [{"text": "Hello"}]}, + stop_reason="end_turn", + state={}, + metrics=Mock( + accumulated_usage={"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, + accumulated_metrics={"latencyMs": 100.0}, + ), + structured_output=structured_result, + ) + yield {"result": mock_result} + + mock_agent = create_mock_agent("test_agent", "Test response") + # Assign the async generator function directly + mock_agent.stream_async = mock_stream_async + + # Build graph + builder = GraphBuilder() + builder.add_node(mock_agent, "test_node") + graph = builder.build() + + # Execute graph with structured_output_model + result = await graph.invoke_async("Test task", structured_output_model=SimpleOutput) + + # Verify structured_output_model was passed to agent + assert captured_structured_output_model == SimpleOutput + assert result.status == Status.COMPLETED + assert "test_node" in result.results + + # Verify structured output is accessible in the result + node_result = result.results["test_node"].result + assert isinstance(node_result, AgentResult) + assert node_result.structured_output is not None + assert isinstance(node_result.structured_output, SimpleOutput) + assert node_result.structured_output.message == "Hello" + assert node_result.structured_output.number == 7 + + @pytest.mark.asyncio + async def test_graph_structured_output_sync(self, mock_strands_tracer, mock_use_span): + """Test Graph structured output with synchronous __call__.""" + from pydantic import BaseModel, Field + + class SummaryOutput(BaseModel): + """Summary output structure for testing.""" + + summary: str = Field(description="A summary") + key_points: list[str] = Field(description="List of key points") + + # Create a mock agent that captures structured_output_model + captured_structured_output_model = None + + async def mock_stream_async(*args, **kwargs): + nonlocal captured_structured_output_model + captured_structured_output_model = kwargs.get("structured_output_model") + + structured_result = SummaryOutput(summary="Test summary", key_points=["point1", "point2"]) + mock_result = AgentResult( + message={"role": "assistant", "content": [{"text": "Summary"}]}, + stop_reason="end_turn", + state={}, + metrics=Mock( + accumulated_usage={"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, + accumulated_metrics={"latencyMs": 100.0}, + ), + structured_output=structured_result, + ) + yield {"result": mock_result} + + mock_agent = create_mock_agent("summary_agent", "Summary response") + # Assign the async generator function directly + mock_agent.stream_async = mock_stream_async + + # Build graph + builder = GraphBuilder() + builder.add_node(mock_agent, "summary_node") + graph = builder.build() + + # Execute graph with structured_output_model using __call__ + result = graph("Test task", structured_output_model=SummaryOutput) + + # Verify structured_output_model was passed to agent + assert captured_structured_output_model == SummaryOutput + assert result.status == Status.COMPLETED + assert "summary_node" in result.results + + # Verify structured output is accessible + node_result = result.results["summary_node"].result + assert isinstance(node_result, AgentResult) + assert node_result.structured_output is not None + assert isinstance(node_result.structured_output, SummaryOutput) + assert node_result.structured_output.summary == "Test summary" + assert len(node_result.structured_output.key_points) == 2 + + @pytest.mark.asyncio + async def test_graph_structured_output_multiple_nodes(self, mock_strands_tracer, mock_use_span): + """Test Graph structured output with multiple agent nodes.""" + from pydantic import BaseModel, Field + + class NodeOutput(BaseModel): + """Node output structure for testing.""" + + node_id: str = Field(description="Node identifier") + output: str = Field(description="Node output") + + captured_models = {} + + def create_mock_agent_with_capture(name): + async def mock_stream_async(*args, **kwargs): + captured_models[name] = kwargs.get("structured_output_model") + + structured_result = NodeOutput(node_id=name, output=f"Output from {name}") + mock_result = AgentResult( + message={"role": "assistant", "content": [{"text": f"Output from {name}"}]}, + stop_reason="end_turn", + state={}, + metrics=Mock( + accumulated_usage={"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, + accumulated_metrics={"latencyMs": 100.0}, + ), + structured_output=structured_result, + ) + yield {"result": mock_result} + + agent = create_mock_agent(name, f"Response from {name}") + # Assign the async generator function directly + agent.stream_async = mock_stream_async + return agent + + # Create multiple agents + agent1 = create_mock_agent_with_capture("agent1") + agent2 = create_mock_agent_with_capture("agent2") + + # Build graph with multiple nodes + builder = GraphBuilder() + builder.add_node(agent1, "node1") + builder.add_node(agent2, "node2") + graph = builder.build() + + # Execute graph with structured_output_model + result = await graph.invoke_async("Test task", structured_output_model=NodeOutput) + + # Verify structured_output_model was passed to both agents + assert captured_models["agent1"] == NodeOutput + assert captured_models["agent2"] == NodeOutput + assert result.status == Status.COMPLETED + assert "node1" in result.results + assert "node2" in result.results + + # Verify structured output is accessible for both nodes + node1_result = result.results["node1"].result + node2_result = result.results["node2"].result + assert node1_result.structured_output is not None + assert node2_result.structured_output is not None + assert node1_result.structured_output.node_id == "agent1" + assert node2_result.structured_output.node_id == "agent2"