@@ -625,10 +625,10 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
625625
626626 def find_fuseable_subgraph (
627627 * ,
628- fg : FunctionGraph ,
629628 visited_nodes : set [Apply ],
630629 fuseable_clients : FUSEABLE_MAPPING ,
631630 unfuseable_clients : UNFUSEABLE_MAPPING ,
631+ toposort_index : dict [Apply , int ],
632632 ) -> tuple [list [Variable ], list [Variable ]]:
633633 KT = TypeVar ("KT" )
634634 VT = TypeVar ("VT" , list , set )
@@ -648,8 +648,7 @@ def variables_depend_on(
648648 for a in ancestors (variables , blockers = stop_search_at )
649649 )
650650
651- toposort = fg .toposort ()
652- for starting_node in toposort :
651+ for starting_node in toposort_index :
653652 if starting_node in visited_nodes :
654653 continue
655654
@@ -791,7 +790,7 @@ def variables_depend_on(
791790 and inp .owner not in visited_nodes
792791 )
793792 ),
794- key = lambda inp : toposort . index ( inp .owner ) ,
793+ key = lambda inp : toposort_index [ inp .owner ] ,
795794 reverse = True ,
796795 ):
797796 fuseable_nodes_to_visit .appendleft (inp .owner )
@@ -803,7 +802,7 @@ def variables_depend_on(
803802 for node in fuseable_clients_temp .get (next_out , ())
804803 if node not in visited_nodes
805804 ),
806- key = lambda node : toposort . index ( node ) ,
805+ key = lambda node : toposort_index [ node ] ,
807806 ):
808807 fuseable_nodes_to_visit .append (next_node )
809808
@@ -877,20 +876,22 @@ def update_fuseable_mappings_after_fg_replace(
877876 # client (those that don't fit into 1))
878877 fuseable_clients , unfuseable_clients = initialize_fuseable_mappings (fg = fg )
879878 visited_nodes : set [Apply ] = set ()
879+ toposort_index = {node : i for i , node in enumerate (fgraph .toposort ())}
880880 while True :
881- starting_nodes = fg .apply_nodes .copy ()
882881 try :
883882 subgraph_inputs , subgraph_outputs = find_fuseable_subgraph (
884- fg = fg ,
885883 visited_nodes = visited_nodes ,
886884 fuseable_clients = fuseable_clients ,
887885 unfuseable_clients = unfuseable_clients ,
886+ toposort_index = toposort_index ,
888887 )
889888 except ValueError :
890889 return
891890 else :
892891 # The caller is now expected to update fg in place,
893892 # by replacing the subgraph with a Composite Op
893+ starting_nodes = fg .apply_nodes .copy ()
894+
894895 yield subgraph_inputs , subgraph_outputs
895896
896897 # This is where we avoid repeated work by using a stateful
0 commit comments