2222from typing import Any , Callable , Tuple , cast
2323
2424from ..agent import Agent , AgentResult
25+ from ..types .content import ContentBlock
2526from ..types .event_loop import Metrics , Usage
2627from .base import MultiAgentBase , MultiAgentResult , NodeResult , Status
2728
@@ -42,12 +43,14 @@ class GraphState:
4243 Entry point nodes receive this task as their input if they have no dependencies.
4344 """
4445
46+ # Task (with default empty string)
47+ task : str | list [ContentBlock ] = ""
48+
4549 # Execution state
4650 status : Status = Status .PENDING
4751 completed_nodes : set ["GraphNode" ] = field (default_factory = set )
4852 failed_nodes : set ["GraphNode" ] = field (default_factory = set )
4953 execution_order : list ["GraphNode" ] = field (default_factory = list )
50- task : str = ""
5154
5255 # Results
5356 results : dict [str , NodeResult ] = field (default_factory = dict )
@@ -247,7 +250,7 @@ def __init__(self, nodes: dict[str, GraphNode], edges: set[GraphEdge], entry_poi
247250 self .entry_points = entry_points
248251 self .state = GraphState ()
249252
250- def execute (self , task : str ) -> GraphResult :
253+ def execute (self , task : str | list [ ContentBlock ] ) -> GraphResult :
251254 """Execute task synchronously."""
252255
253256 def execute () -> GraphResult :
@@ -257,7 +260,7 @@ def execute() -> GraphResult:
257260 future = executor .submit (execute )
258261 return future .result ()
259262
260- async def execute_async (self , task : str ) -> GraphResult :
263+ async def execute_async (self , task : str | list [ ContentBlock ] ) -> GraphResult :
261264 """Execute the graph asynchronously."""
262265 logger .debug ("task=<%s> | starting graph execution" , task )
263266
@@ -435,8 +438,8 @@ def _accumulate_metrics(self, node_result: NodeResult) -> None:
435438 self .state .accumulated_metrics ["latencyMs" ] += node_result .accumulated_metrics .get ("latencyMs" , 0 )
436439 self .state .execution_count += node_result .execution_count
437440
438- def _build_node_input (self , node : GraphNode ) -> str :
439- """Build input text for a node based on dependency outputs."""
441+ def _build_node_input (self , node : GraphNode ) -> list [ ContentBlock ] :
442+ """Build input for a node based on dependency outputs."""
440443 # Get satisfied dependencies
441444 dependency_results = {}
442445 for edge in self .edges :
@@ -449,21 +452,36 @@ def _build_node_input(self, node: GraphNode) -> str:
449452 dependency_results [edge .from_node .node_id ] = self .state .results [edge .from_node .node_id ]
450453
451454 if not dependency_results :
452- return self .state .task
455+ # No dependencies - return task as ContentBlocks
456+ if isinstance (self .state .task , str ):
457+ return [ContentBlock (text = self .state .task )]
458+ else :
459+ return self .state .task
453460
454461 # Combine task with dependency outputs
455- input_parts = [f"Original Task: { self .state .task } " , "\n Inputs from previous nodes:" ]
462+ node_input = []
463+
464+ # Add original task
465+ if isinstance (self .state .task , str ):
466+ node_input .append (ContentBlock (text = f"Original Task: { self .state .task } " ))
467+ else :
468+ # Add task content blocks with a prefix
469+ node_input .append (ContentBlock (text = "Original Task:" ))
470+ node_input .extend (self .state .task )
471+
472+ # Add dependency outputs
473+ node_input .append (ContentBlock (text = "\n Inputs from previous nodes:" ))
456474
457475 for dep_id , node_result in dependency_results .items ():
458- input_parts .append (f"\n From { dep_id } :" )
476+ node_input .append (ContentBlock ( text = f"\n From { dep_id } :" ) )
459477 # Get all agent results from this node (flattened if nested)
460478 agent_results = node_result .get_agent_results ()
461479 for result in agent_results :
462480 agent_name = getattr (result , "agent_name" , "Agent" )
463481 result_text = str (result )
464- input_parts .append (f" - { agent_name } : { result_text } " )
482+ node_input .append (ContentBlock ( text = f" - { agent_name } : { result_text } " ) )
465483
466- return " \n " . join ( input_parts )
484+ return node_input
467485
468486 def _build_result (self ) -> GraphResult :
469487 """Build graph result from current state."""
0 commit comments