Skip to content

Commit 1db1f92

Browse files
committed
Exit from DestroyHandler orderings faster
1 parent cd213d1 commit 1db1f92

File tree

1 file changed

+95
-95
lines changed

1 file changed

+95
-95
lines changed

pytensor/graph/destroyhandler.py

Lines changed: 95 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -699,106 +699,106 @@ def orderings(self, fgraph, ordered=True):
699699
c) an Apply destroys (illegally) one of its own inputs by aliasing
700700
701701
"""
702+
if not self.destroyers:
703+
return {}
704+
702705
set_type = OrderedSet if ordered else set
703706
rval = {}
704707

705-
if self.destroyers:
706-
# BUILD DATA STRUCTURES
707-
# CHECK for multiple destructions during construction of variables
708-
709-
droot, impact, __ignore = self.refresh_droot_impact()
710-
711-
# check for destruction of constants
712-
illegal_destroy = [
713-
r
714-
for r in droot
715-
if getattr(r.tag, "indestructible", False) or isinstance(r, Constant)
716-
]
717-
if illegal_destroy:
718-
raise InconsistencyError(
719-
f"Attempting to destroy indestructible variables: {illegal_destroy}"
720-
)
708+
# BUILD DATA STRUCTURES
709+
# CHECK for multiple destructions during construction of variables
710+
711+
droot, impact, __ignore = self.refresh_droot_impact()
721712

722-
# add destroyed variable clients as computational dependencies
723-
for app in self.destroyers:
724-
# keep track of clients that should run before the current Apply
725-
root_clients = set_type()
726-
# for each destroyed input...
727-
for input_idx_list in app.op.destroy_map.values():
728-
destroyed_idx = input_idx_list[0]
729-
destroyed_variable = app.inputs[destroyed_idx]
730-
root = droot[destroyed_variable]
731-
root_impact = impact[root]
732-
# we generally want to put all clients of things which depend on root
733-
# as pre-requisites of app.
734-
# But, app is itself one such client!
735-
# App will always be a client of the node we're destroying
736-
# (destroyed_variable, but the tricky thing is when it is also a client of
737-
# *another variable* viewing on the root. Generally this is illegal, (e.g.,
738-
# add_inplace(x, x.T). In some special cases though, the in-place op will
739-
# actually be able to work properly with multiple destroyed inputs (e.g,
740-
# add_inplace(x, x). An Op that can still work in this case should declare
741-
# so via the 'destroyhandler_tolerate_same' attribute or
742-
# 'destroyhandler_tolerate_aliased' attribute.
743-
#
744-
# destroyhandler_tolerate_same should be a list of pairs of the form
745-
# [(idx0, idx1), (idx0, idx2), ...]
746-
# The first element of each pair is the input index of a destroyed
747-
# variable.
748-
# The second element of each pair is the index of a different input where
749-
# we will permit exactly the same variable to appear.
750-
# For example, add_inplace.tolerate_same might be [(0,1)] if the destroyed
751-
# input is also allowed to appear as the second argument.
752-
#
753-
# destroyhandler_tolerate_aliased is the same sort of list of
754-
# pairs.
755-
# op.destroyhandler_tolerate_aliased = [(idx0, idx1)] tells the
756-
# destroyhandler to IGNORE an aliasing between a destroyed
757-
# input idx0 and another input idx1.
758-
# This is generally a bad idea, but it is safe in some
759-
# cases, such as
760-
# - the op reads from the aliased idx1 before modifying idx0
761-
# - the idx0 and idx1 are guaranteed not to overlap (e.g.
762-
# they are pointed at different rows of a matrix).
763-
#
764-
765-
# CHECK FOR INPUT ALIASING
766-
# OPT: pre-compute this on import
767-
tolerate_same = getattr(app.op, "destroyhandler_tolerate_same", [])
768-
assert isinstance(tolerate_same, list)
769-
tolerated = {
770-
idx1 for idx0, idx1 in tolerate_same if idx0 == destroyed_idx
771-
}
772-
tolerated.add(destroyed_idx)
773-
tolerate_aliased = getattr(
774-
app.op, "destroyhandler_tolerate_aliased", []
775-
)
776-
assert isinstance(tolerate_aliased, list)
777-
ignored = {
778-
idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx
779-
}
780-
for i, input in enumerate(app.inputs):
781-
if i in ignored:
782-
continue
783-
if input in root_impact and (
784-
i not in tolerated or input is not destroyed_variable
785-
):
786-
raise InconsistencyError(
787-
f"Input aliasing: {app} ({destroyed_idx}, {i})"
788-
)
789-
790-
# add the rule: app must be preceded by all other Apply instances that
791-
# depend on destroyed_input
792-
for r in root_impact:
793-
assert not [a for a, c in self.clients[r].items() if not c]
794-
root_clients.update(
795-
[a for a, c in self.clients[r].items() if c]
713+
# check for destruction of constants
714+
illegal_destroy = [
715+
r
716+
for r in droot
717+
if getattr(r.tag, "indestructible", False) or isinstance(r, Constant)
718+
]
719+
if illegal_destroy:
720+
raise InconsistencyError(
721+
f"Attempting to destroy indestructible variables: {illegal_destroy}"
722+
)
723+
724+
# add destroyed variable clients as computational dependencies
725+
for app in self.destroyers:
726+
# keep track of clients that should run before the current Apply
727+
root_clients = set_type()
728+
# for each destroyed input...
729+
for input_idx_list in app.op.destroy_map.values():
730+
destroyed_idx = input_idx_list[0]
731+
destroyed_variable = app.inputs[destroyed_idx]
732+
root = droot[destroyed_variable]
733+
root_impact = impact[root]
734+
# we generally want to put all clients of things which depend on root
735+
# as pre-requisites of app.
736+
# But, app is itself one such client!
737+
# App will always be a client of the node we're destroying
738+
# (destroyed_variable, but the tricky thing is when it is also a client of
739+
# *another variable* viewing on the root. Generally this is illegal, (e.g.,
740+
# add_inplace(x, x.T). In some special cases though, the in-place op will
741+
# actually be able to work properly with multiple destroyed inputs (e.g,
742+
# add_inplace(x, x). An Op that can still work in this case should declare
743+
# so via the 'destroyhandler_tolerate_same' attribute or
744+
# 'destroyhandler_tolerate_aliased' attribute.
745+
#
746+
# destroyhandler_tolerate_same should be a list of pairs of the form
747+
# [(idx0, idx1), (idx0, idx2), ...]
748+
# The first element of each pair is the input index of a destroyed
749+
# variable.
750+
# The second element of each pair is the index of a different input where
751+
# we will permit exactly the same variable to appear.
752+
# For example, add_inplace.tolerate_same might be [(0,1)] if the destroyed
753+
# input is also allowed to appear as the second argument.
754+
#
755+
# destroyhandler_tolerate_aliased is the same sort of list of
756+
# pairs.
757+
# op.destroyhandler_tolerate_aliased = [(idx0, idx1)] tells the
758+
# destroyhandler to IGNORE an aliasing between a destroyed
759+
# input idx0 and another input idx1.
760+
# This is generally a bad idea, but it is safe in some
761+
# cases, such as
762+
# - the op reads from the aliased idx1 before modifying idx0
763+
# - the idx0 and idx1 are guaranteed not to overlap (e.g.
764+
# they are pointed at different rows of a matrix).
765+
#
766+
767+
# CHECK FOR INPUT ALIASING
768+
# OPT: pre-compute this on import
769+
tolerate_same = getattr(app.op, "destroyhandler_tolerate_same", [])
770+
assert isinstance(tolerate_same, list)
771+
tolerated = {
772+
idx1 for idx0, idx1 in tolerate_same if idx0 == destroyed_idx
773+
}
774+
tolerated.add(destroyed_idx)
775+
tolerate_aliased = getattr(
776+
app.op, "destroyhandler_tolerate_aliased", []
777+
)
778+
assert isinstance(tolerate_aliased, list)
779+
ignored = {
780+
idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx
781+
}
782+
for i, input in enumerate(app.inputs):
783+
if i in ignored:
784+
continue
785+
if input in root_impact and (
786+
i not in tolerated or input is not destroyed_variable
787+
):
788+
raise InconsistencyError(
789+
f"Input aliasing: {app} ({destroyed_idx}, {i})"
796790
)
797791

798-
# app itself is a client of the destroyed inputs,
799-
# but should not run before itself
800-
root_clients.remove(app)
801-
if root_clients:
802-
rval[app] = root_clients
792+
# add the rule: app must be preceded by all other Apply instances that
793+
# depend on destroyed_input
794+
for r in root_impact:
795+
assert not [a for a, c in self.clients[r].items() if not c]
796+
root_clients.update([a for a, c in self.clients[r].items() if c])
797+
798+
# app itself is a client of the destroyed inputs,
799+
# but should not run before itself
800+
root_clients.remove(app)
801+
if root_clients:
802+
rval[app] = root_clients
803803

804804
return rval

0 commit comments

Comments
 (0)