@@ -59,6 +59,9 @@ def __init__(self, name, base_dir=None):
5959 super (Workflow , self ).__init__ (name , base_dir )
6060 self ._graph = nx .DiGraph ()
6161
62+ self ._nodes_cache = set ()
63+ self ._nested_workflows_cache = set ()
64+
6265 # PUBLIC API
6366 def clone (self , name ):
6467 """Clone a workflow
@@ -269,6 +272,8 @@ def connect(self, *args, **kwargs):
269272 "(%s, %s): new edge data: %s" , srcnode , destnode , str (edge_data )
270273 )
271274
275+ self ._update_node_cache ()
276+
272277 def disconnect (self , * args ):
273278 """Disconnect nodes
274279 See the docstring for connect for format.
@@ -314,6 +319,8 @@ def disconnect(self, *args):
314319 else :
315320 self ._graph .add_edges_from ([(srcnode , dstnode , edge_data )])
316321
322+ self ._update_node_cache ()
323+
317324 def add_nodes (self , nodes ):
318325 """ Add nodes to a workflow
319326
@@ -346,6 +353,7 @@ def add_nodes(self, nodes):
346353 if node ._hierarchy is None :
347354 node ._hierarchy = self .name
348355 self ._graph .add_nodes_from (newnodes )
356+ self ._update_node_cache ()
349357
350358 def remove_nodes (self , nodes ):
351359 """ Remove nodes from a workflow
@@ -356,6 +364,7 @@ def remove_nodes(self, nodes):
356364 A list of EngineBase-based objects
357365 """
358366 self ._graph .remove_nodes_from (nodes )
367+ self ._update_node_cache ()
359368
360369 # Input-Output access
361370 @property
@@ -903,23 +912,32 @@ def _set_node_input(self, node, param, source, sourceinfo):
903912 node .set_input (param , deepcopy (newval ))
904913
905914 def _get_all_nodes (self ):
906- allnodes = []
907- for node in self ._graph .nodes ():
908- if isinstance (node , Workflow ):
909- allnodes .extend (node ._get_all_nodes ())
910- else :
911- allnodes .append (node )
915+ allnodes = [
916+ * self ._nodes_cache .difference (self ._nested_workflows_cache )
917+ ] # all nodes that are not workflows
918+ for node in self ._nested_workflows_cache :
919+ allnodes .extend (node ._get_all_nodes ())
912920 return allnodes
913921
922+ def _update_node_cache (self ):
923+ nodes = set (self ._graph )
924+
925+ added_nodes = nodes .difference (self ._nodes_cache )
926+ removed_nodes = self ._nodes_cache .difference (nodes )
927+
928+ self ._nodes_cache = nodes
929+ self ._nested_workflows_cache .difference_update (removed_nodes )
930+
931+ for node in added_nodes :
932+ if isinstance (node , Workflow ):
933+ self ._nested_workflows_cache .add (node )
934+
914935 def _has_node (self , wanted_node ):
915- if wanted_node in self ._graph :
916- return True # best case scenario
917- for node in self ._graph : # iterate otherwise
918- if wanted_node == node :
936+ if wanted_node in self ._nodes_cache :
937+ return True
938+ for node in self ._nested_workflows_cache :
939+ if node . _has_node ( wanted_node ) :
919940 return True
920- if hasattr (node , "_has_node" ): # hasattr is faster than isinstance
921- if node ._has_node (wanted_node ):
922- return True
923941 return False
924942
925943 def _create_flat_graph (self ):
0 commit comments