Skip to content

Commit 9baa8a4

Browse files
committed
Do not recompute toposort in every iteration of FusionOptimizer
It's not really needed as we never expand on the new nodes
1 parent a39dc8b commit 9baa8a4

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -625,10 +625,10 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
625625

626626
def find_fuseable_subgraph(
627627
*,
628-
fg: FunctionGraph,
629628
visited_nodes: set[Apply],
630629
fuseable_clients: FUSEABLE_MAPPING,
631630
unfuseable_clients: UNFUSEABLE_MAPPING,
631+
toposort_index: dict[Apply, int],
632632
) -> tuple[list[Variable], list[Variable]]:
633633
KT = TypeVar("KT")
634634
VT = TypeVar("VT", list, set)
@@ -648,8 +648,7 @@ def variables_depend_on(
648648
for a in ancestors(variables, blockers=stop_search_at)
649649
)
650650

651-
toposort = fg.toposort()
652-
for starting_node in toposort:
651+
for starting_node in toposort_index:
653652
if starting_node in visited_nodes:
654653
continue
655654

@@ -791,7 +790,7 @@ def variables_depend_on(
791790
and inp.owner not in visited_nodes
792791
)
793792
),
794-
key=lambda inp: toposort.index(inp.owner),
793+
key=lambda inp: toposort_index[inp.owner],
795794
reverse=True,
796795
):
797796
fuseable_nodes_to_visit.appendleft(inp.owner)
@@ -803,7 +802,7 @@ def variables_depend_on(
803802
for node in fuseable_clients_temp.get(next_out, ())
804803
if node not in visited_nodes
805804
),
806-
key=lambda node: toposort.index(node),
805+
key=lambda node: toposort_index[node],
807806
):
808807
fuseable_nodes_to_visit.append(next_node)
809808

@@ -877,20 +876,22 @@ def update_fuseable_mappings_after_fg_replace(
877876
# client (those that don't fit into 1))
878877
fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg)
879878
visited_nodes: set[Apply] = set()
879+
toposort_index = {node: i for i, node in enumerate(fgraph.toposort())}
880880
while True:
881-
starting_nodes = fg.apply_nodes.copy()
882881
try:
883882
subgraph_inputs, subgraph_outputs = find_fuseable_subgraph(
884-
fg=fg,
885883
visited_nodes=visited_nodes,
886884
fuseable_clients=fuseable_clients,
887885
unfuseable_clients=unfuseable_clients,
886+
toposort_index=toposort_index,
888887
)
889888
except ValueError:
890889
return
891890
else:
892891
# The caller is now expected to update fg in place,
893892
# by replacing the subgraph with a Composite Op
893+
starting_nodes = fg.apply_nodes.copy()
894+
894895
yield subgraph_inputs, subgraph_outputs
895896

896897
# This is where we avoid repeated work by using a stateful

0 commit comments

Comments
 (0)