1- from typing import Optional
1+ from __future__ import annotations
22
33import graphviz # type: ignore
44
@@ -31,7 +31,9 @@ def get_main_graph(agent: Agent) -> str:
3131 return "" .join (parts )
3232
3333
34- def get_all_nodes (agent : Agent , parent : Optional [Agent ] = None ) -> str :
34+ def get_all_nodes (
35+ agent : Agent , parent : Agent | None = None , visited : set [str ] | None = None
36+ ) -> str :
3537 """
3638 Recursively generates the nodes for the given agent and its handoffs in DOT format.
3739
@@ -41,17 +43,23 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
4143 Returns:
4244 str: The DOT format string representing the nodes.
4345 """
46+ if visited is None :
47+ visited = set ()
48+ if agent .name in visited :
49+ return ""
50+ visited .add (agent .name )
51+
4452 parts = []
4553
4654 # Start and end the graph
47- parts .append (
48- '"__start__" [label="__start__", shape=ellipse, style=filled, '
49- "fillcolor=lightblue, width=0.5, height=0.3];"
50- '"__end__" [label="__end__", shape=ellipse, style=filled, '
51- "fillcolor=lightblue, width=0.5, height=0.3];"
52- )
53- # Ensure parent agent node is colored
5455 if not parent :
56+ parts .append (
57+ '"__start__" [label="__start__", shape=ellipse, style=filled, '
58+ "fillcolor=lightblue, width=0.5, height=0.3];"
59+ '"__end__" [label="__end__", shape=ellipse, style=filled, '
60+ "fillcolor=lightblue, width=0.5, height=0.3];"
61+ )
62+ # Ensure parent agent node is colored
5563 parts .append (
5664 f'"{ agent .name } " [label="{ agent .name } ", shape=box, style=filled, '
5765 "fillcolor=lightyellow, width=1.5, height=0.8];"
@@ -71,17 +79,20 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
7179 f"fillcolor=lightyellow, width=1.5, height=0.8];"
7280 )
7381 if isinstance (handoff , Agent ):
74- parts .append (
75- f'"{ handoff .name } " [label="{ handoff .name } ", '
76- f"shape=box, style=filled, style=rounded, "
77- f"fillcolor=lightyellow, width=1.5, height=0.8];"
78- )
79- parts .append (get_all_nodes (handoff ))
82+ if handoff .name not in visited :
83+ parts .append (
84+ f'"{ handoff .name } " [label="{ handoff .name } ", '
85+ f"shape=box, style=filled, style=rounded, "
86+ f"fillcolor=lightyellow, width=1.5, height=0.8];"
87+ )
88+ parts .append (get_all_nodes (handoff , agent , visited ))
8089
8190 return "" .join (parts )
8291
8392
84- def get_all_edges (agent : Agent , parent : Optional [Agent ] = None ) -> str :
93+ def get_all_edges (
94+ agent : Agent , parent : Agent | None = None , visited : set [str ] | None = None
95+ ) -> str :
8596 """
8697 Recursively generates the edges for the given agent and its handoffs in DOT format.
8798
@@ -92,6 +103,12 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
92103 Returns:
93104 str: The DOT format string representing the edges.
94105 """
106+ if visited is None :
107+ visited = set ()
108+ if agent .name in visited :
109+ return ""
110+ visited .add (agent .name )
111+
95112 parts = []
96113
97114 if not parent :
@@ -109,15 +126,15 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
109126 if isinstance (handoff , Agent ):
110127 parts .append (f"""
111128 "{ agent .name } " -> "{ handoff .name } ";""" )
112- parts .append (get_all_edges (handoff , agent ))
129+ parts .append (get_all_edges (handoff , agent , visited ))
113130
114131 if not agent .handoffs and not isinstance (agent , Tool ): # type: ignore
115132 parts .append (f'"{ agent .name } " -> "__end__";' )
116133
117134 return "" .join (parts )
118135
119136
120- def draw_graph (agent : Agent , filename : Optional [ str ] = None ) -> graphviz .Source :
137+ def draw_graph (agent : Agent , filename : str | None = None ) -> graphviz .Source :
121138 """
122139 Draws the graph for the given agent and optionally saves it as a PNG file.
123140
0 commit comments