@@ -659,14 +659,30 @@ def shallow_clone_defaultdict(
659659 return new_dict
660660
661661 def variables_depend_on (
662- variables , depend_on , stop_search_at = None
662+ variables ,
663+ depend_on ,
664+ stop_search_at = (),
665+ variable_toposort = (),
663666 ) -> bool :
667+ # We can stop search at any variable that comes topologically before depend_on
668+ # As those can't logically be dependents anymore
669+ depend_on = frozenset (depend_on )
670+ first_depend_toposort_idx = next (
671+ i for i , var in enumerate (variable_toposort ) if var in depend_on
672+ )
664673 return any (
665674 a in depend_on
666- for a in ancestors (variables , blockers = stop_search_at )
675+ for a in ancestors (
676+ variables ,
677+ blockers = (
678+ * stop_search_at ,
679+ * variable_toposort [:first_depend_toposort_idx ],
680+ ),
681+ )
667682 )
668683
669684 toposort = fg .toposort ()
685+ variable_toposort = None # build only lazily
670686 for starting_node in toposort :
671687 if starting_node in visited_nodes :
672688 continue
@@ -729,10 +745,15 @@ def variables_depend_on(
729745 # We need to check that any new inputs required by this node
730746 # do not depend on other outputs of the current subgraph,
731747 # via an unfuseable path.
748+ if variable_toposort is None :
749+ variable_toposort = [
750+ o for node in toposort for o in node .outputs
751+ ]
732752 if variables_depend_on (
733753 [next_out ],
734754 depend_on = unfuseable_clients_subgraph ,
735755 stop_search_at = subgraph_outputs ,
756+ variable_toposort = variable_toposort ,
736757 ):
737758 must_backtrack = True
738759
@@ -752,9 +773,14 @@ def variables_depend_on(
752773 # We need to check that any inputs of the current subgraph
753774 # do not depend on other clients of this node,
754775 # via an unfuseable path.
776+ if variable_toposort is None :
777+ variable_toposort = [
778+ o for node in toposort for o in node .outputs
779+ ]
755780 if variables_depend_on (
756781 subgraph_inputs ,
757782 depend_on = new_implied_unfuseable_clients ,
783+ variable_toposort = variable_toposort ,
758784 ):
759785 must_backtrack = True
760786
0 commit comments