Skip to content

Commit 95064eb

Browse files
committed
smarter variables_depend_on
1 parent 7a7d26f commit 95064eb

File tree

1 file changed

+28
-2
lines changed

1 file changed

+28
-2
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -659,14 +659,30 @@ def shallow_clone_defaultdict(
659659
return new_dict
660660

661661
def variables_depend_on(
662-
variables, depend_on, stop_search_at=None
662+
variables,
663+
depend_on,
664+
stop_search_at=(),
665+
variable_toposort=(),
663666
) -> bool:
667+
# We can stop search at any variable that comes topologically before depend_on
668+
# As those can't logically be dependents anymore
669+
depend_on = frozenset(depend_on)
670+
first_depend_toposort_idx = next(
671+
i for i, var in enumerate(variable_toposort) if var in depend_on
672+
)
664673
return any(
665674
a in depend_on
666-
for a in ancestors(variables, blockers=stop_search_at)
675+
for a in ancestors(
676+
variables,
677+
blockers=(
678+
*stop_search_at,
679+
*variable_toposort[:first_depend_toposort_idx],
680+
),
681+
)
667682
)
668683

669684
toposort = fg.toposort()
685+
variable_toposort = None # build only lazily
670686
for starting_node in toposort:
671687
if starting_node in visited_nodes:
672688
continue
@@ -729,10 +745,15 @@ def variables_depend_on(
729745
# We need to check that any new inputs required by this node
730746
# do not depend on other outputs of the current subgraph,
731747
# via an unfuseable path.
748+
if variable_toposort is None:
749+
variable_toposort = [
750+
o for node in toposort for o in node.outputs
751+
]
732752
if variables_depend_on(
733753
[next_out],
734754
depend_on=unfuseable_clients_subgraph,
735755
stop_search_at=subgraph_outputs,
756+
variable_toposort=variable_toposort,
736757
):
737758
must_backtrack = True
738759

@@ -752,9 +773,14 @@ def variables_depend_on(
752773
# We need to check that any inputs of the current subgraph
753774
# do not depend on other clients of this node,
754775
# via an unfuseable path.
776+
if variable_toposort is None:
777+
variable_toposort = [
778+
o for node in toposort for o in node.outputs
779+
]
755780
if variables_depend_on(
756781
subgraph_inputs,
757782
depend_on=new_implied_unfuseable_clients,
783+
variable_toposort=variable_toposort,
758784
):
759785
must_backtrack = True
760786

0 commit comments

Comments
 (0)