Skip to content

Commit 6c2fc26

Browse files
committed
.don't compute toposort in every iteration
1 parent b4c7afa commit 6c2fc26

File tree

1 file changed

+13
-27
lines changed

1 file changed

+13
-27
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
out2in,
2828
)
2929
from pytensor.graph.rewriting.db import SequenceDB
30-
from pytensor.graph.traversal import ancestors
30+
from pytensor.graph.traversal import ancestors, toposort
3131
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
3232
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
3333
from pytensor.tensor.basic import (
@@ -647,6 +647,7 @@ def find_fuseable_subgraph(
647647
visited_nodes: set[Apply],
648648
fuseable_clients: FUSEABLE_MAPPING,
649649
unfuseable_clients: UNFUSEABLE_MAPPING,
650+
toposort_index: dict[Apply, int],
650651
) -> tuple[list[Variable], list[Variable]]:
651652
KT = TypeVar("KT")
652653
VT = TypeVar("VT", list, set)
@@ -661,29 +662,19 @@ def shallow_clone_defaultdict(
661662
def variables_depend_on(
662663
variables,
663664
depend_on,
664-
stop_search_at=(),
665-
variable_toposort=(),
665+
stop_search_at=None,
666666
) -> bool:
667667
# We can stop search at any variable that comes topologically before depend_on
668668
# As those can't logically be dependents anymore
669-
depend_on = frozenset(depend_on)
670-
first_depend_toposort_idx = next(
671-
i for i, var in enumerate(variable_toposort) if var in depend_on
672-
)
673669
return any(
674670
a in depend_on
675671
for a in ancestors(
676672
variables,
677-
blockers=(
678-
*stop_search_at,
679-
*variable_toposort[:first_depend_toposort_idx],
680-
),
673+
blockers=stop_search_at,
681674
)
682675
)
683676

684-
toposort = fg.toposort()
685-
variable_toposort = None # build only lazily
686-
for starting_node in toposort:
677+
for starting_node in toposort_index:
687678
if starting_node in visited_nodes:
688679
continue
689680

@@ -745,15 +736,10 @@ def variables_depend_on(
745736
# We need to check that any new inputs required by this node
746737
# do not depend on other outputs of the current subgraph,
747738
# via an unfuseable path.
748-
if variable_toposort is None:
749-
variable_toposort = [
750-
o for node in toposort for o in node.outputs
751-
]
752739
if variables_depend_on(
753740
[next_out],
754741
depend_on=unfuseable_clients_subgraph,
755742
stop_search_at=subgraph_outputs,
756-
variable_toposort=variable_toposort,
757743
):
758744
must_backtrack = True
759745

@@ -773,14 +759,9 @@ def variables_depend_on(
773759
# We need to check that any inputs of the current subgraph
774760
# do not depend on other clients of this node,
775761
# via an unfuseable path.
776-
if variable_toposort is None:
777-
variable_toposort = [
778-
o for node in toposort for o in node.outputs
779-
]
780762
if variables_depend_on(
781763
subgraph_inputs,
782764
depend_on=new_implied_unfuseable_clients,
783-
variable_toposort=variable_toposort,
784765
):
785766
must_backtrack = True
786767

@@ -835,7 +816,7 @@ def variables_depend_on(
835816
and inp.owner not in visited_nodes
836817
)
837818
),
838-
key=lambda inp: toposort.index(inp.owner),
819+
key=lambda inp: toposort_index[inp.owner],
839820
reverse=True,
840821
):
841822
fuseable_nodes_to_visit.appendleft(inp.owner)
@@ -847,7 +828,7 @@ def variables_depend_on(
847828
for node in fuseable_clients_temp.get(next_out, ())
848829
if node not in visited_nodes
849830
),
850-
key=lambda node: toposort.index(node),
831+
key=lambda node: toposort_index[node],
851832
):
852833
fuseable_nodes_to_visit.append(next_node)
853834

@@ -921,20 +902,25 @@ def update_fuseable_mappings_after_fg_replace(
921902
# client (those that don't fit into 1))
922903
fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg)
923904
visited_nodes: set[Apply] = set()
905+
toposort_index = {
906+
node: i for i, node in enumerate(toposort(fgraph.outputs))
907+
}
924908
while True:
925-
starting_nodes = fg.apply_nodes.copy()
926909
try:
927910
subgraph_inputs, subgraph_outputs = find_fuseable_subgraph(
928911
fg=fg,
929912
visited_nodes=visited_nodes,
930913
fuseable_clients=fuseable_clients,
931914
unfuseable_clients=unfuseable_clients,
915+
toposort_index=toposort_index,
932916
)
933917
except ValueError:
934918
return
935919
else:
936920
# The caller is now expected to update fg in place,
937921
# by replacing the subgraph with a Composite Op
922+
starting_nodes = fg.apply_nodes.copy()
923+
938924
yield subgraph_inputs, subgraph_outputs
939925

940926
# This is where we avoid repeated work by using a stateful

0 commit comments

Comments
 (0)