2727 out2in ,
2828)
2929from pytensor .graph .rewriting .db import SequenceDB
30- from pytensor .graph .traversal import ancestors
30+ from pytensor .graph .traversal import ancestors , toposort
3131from pytensor .graph .utils import InconsistencyError , MethodNotDefined
3232from pytensor .scalar .math import Grad2F1Loop , _grad_2f1_loop
3333from pytensor .tensor .basic import (
@@ -647,6 +647,7 @@ def find_fuseable_subgraph(
647647 visited_nodes : set [Apply ],
648648 fuseable_clients : FUSEABLE_MAPPING ,
649649 unfuseable_clients : UNFUSEABLE_MAPPING ,
650+ toposort_index : dict [Apply , int ],
650651 ) -> tuple [list [Variable ], list [Variable ]]:
651652 KT = TypeVar ("KT" )
652653 VT = TypeVar ("VT" , list , set )
@@ -661,29 +662,19 @@ def shallow_clone_defaultdict(
661662 def variables_depend_on (
662663 variables ,
663664 depend_on ,
664- stop_search_at = (),
665- variable_toposort = (),
665+ stop_search_at = None ,
666666 ) -> bool :
667667 # We can stop search at any variable that comes topologically before depend_on
668668 # As those can't logically be dependents anymore
669- depend_on = frozenset (depend_on )
670- first_depend_toposort_idx = next (
671- i for i , var in enumerate (variable_toposort ) if var in depend_on
672- )
673669 return any (
674670 a in depend_on
675671 for a in ancestors (
676672 variables ,
677- blockers = (
678- * stop_search_at ,
679- * variable_toposort [:first_depend_toposort_idx ],
680- ),
673+ blockers = stop_search_at ,
681674 )
682675 )
683676
684- toposort = fg .toposort ()
685- variable_toposort = None # build only lazily
686- for starting_node in toposort :
677+ for starting_node in toposort_index :
687678 if starting_node in visited_nodes :
688679 continue
689680
@@ -745,15 +736,10 @@ def variables_depend_on(
745736 # We need to check that any new inputs required by this node
746737 # do not depend on other outputs of the current subgraph,
747738 # via an unfuseable path.
748- if variable_toposort is None :
749- variable_toposort = [
750- o for node in toposort for o in node .outputs
751- ]
752739 if variables_depend_on (
753740 [next_out ],
754741 depend_on = unfuseable_clients_subgraph ,
755742 stop_search_at = subgraph_outputs ,
756- variable_toposort = variable_toposort ,
757743 ):
758744 must_backtrack = True
759745
@@ -773,14 +759,9 @@ def variables_depend_on(
773759 # We need to check that any inputs of the current subgraph
774760 # do not depend on other clients of this node,
775761 # via an unfuseable path.
776- if variable_toposort is None :
777- variable_toposort = [
778- o for node in toposort for o in node .outputs
779- ]
780762 if variables_depend_on (
781763 subgraph_inputs ,
782764 depend_on = new_implied_unfuseable_clients ,
783- variable_toposort = variable_toposort ,
784765 ):
785766 must_backtrack = True
786767
@@ -835,7 +816,7 @@ def variables_depend_on(
835816 and inp .owner not in visited_nodes
836817 )
837818 ),
838- key = lambda inp : toposort . index ( inp .owner ) ,
819+ key = lambda inp : toposort_index [ inp .owner ] ,
839820 reverse = True ,
840821 ):
841822 fuseable_nodes_to_visit .appendleft (inp .owner )
@@ -847,7 +828,7 @@ def variables_depend_on(
847828 for node in fuseable_clients_temp .get (next_out , ())
848829 if node not in visited_nodes
849830 ),
850- key = lambda node : toposort . index ( node ) ,
831+ key = lambda node : toposort_index [ node ] ,
851832 ):
852833 fuseable_nodes_to_visit .append (next_node )
853834
@@ -921,20 +902,25 @@ def update_fuseable_mappings_after_fg_replace(
921902 # client (those that don't fit into 1))
922903 fuseable_clients , unfuseable_clients = initialize_fuseable_mappings (fg = fg )
923904 visited_nodes : set [Apply ] = set ()
905+ toposort_index = {
906+ node : i for i , node in enumerate (toposort (fgraph .outputs ))
907+ }
924908 while True :
925- starting_nodes = fg .apply_nodes .copy ()
926909 try :
927910 subgraph_inputs , subgraph_outputs = find_fuseable_subgraph (
928911 fg = fg ,
929912 visited_nodes = visited_nodes ,
930913 fuseable_clients = fuseable_clients ,
931914 unfuseable_clients = unfuseable_clients ,
915+ toposort_index = toposort_index ,
932916 )
933917 except ValueError :
934918 return
935919 else :
936920 # The caller is now expected to update fg in place,
937921 # by replacing the subgraph with a Composite Op
922+ starting_nodes = fg .apply_nodes .copy ()
923+
938924 yield subgraph_inputs , subgraph_outputs
939925
940926 # This is where we avoid repeated work by using a stateful
0 commit comments