Skip to content

Commit d73debf

Browse files
committed
Copy on write in FusionOptimizer
1 parent 824af00 commit d73debf

File tree

1 file changed

+61
-21
lines changed

1 file changed

+61
-21
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 61 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import itertools
33
import operator
44
import sys
5+
import typing
56
from collections import defaultdict, deque
67
from collections.abc import Generator, Sequence
78
from functools import cache, reduce
@@ -522,6 +523,43 @@ def elemwise_max_operands_fct(node) -> int:
522523
return 1024
523524

524525

526+
class CopyOnWriteDictOfSets:
    """Per-key copy-on-write view over a mapping whose values are sets.

    Reads are served from the wrapped mapping ``d`` until a key is first
    modified through :meth:`remove_from_key` or :meth:`add_to_key`. At that
    point the key's set is copied into the private overlay ``d_copy`` and all
    later reads and writes of that key use the copy. The sets stored in ``d``
    are never mutated by this class.

    Notes
    -----
    ``__getitem__`` and ``get`` may return the very set object stored in
    ``d``; callers must treat returned sets as read-only and go through
    ``remove_from_key`` / ``add_to_key`` to modify them.

    If ``d`` is a ``defaultdict``, subscripting it with a missing key inserts
    a default entry into ``d`` itself (standard ``defaultdict`` behavior).
    """

    __slots__ = ("d", "d_copy")

    def __init__(self, d: dict[typing.Any, set]):
        """Wrap ``d`` without copying anything yet."""
        self.d = d
        # Holds private copies of only the keys that have been modified.
        self.d_copy: dict[typing.Any, set] = {}

    def __getitem__(self, key):
        """Return the current set for ``key``, preferring the private copy.

        Falls back to ``self.d[key]``, so a missing key behaves exactly as it
        does on the wrapped mapping.
        """
        try:
            return self.d_copy[key]
        except KeyError:
            return self.d[key]

    def get(self, key, default=frozenset()):
        """Return the current set for ``key``, or ``default`` if it is absent.

        The default is an empty (immutable) ``frozenset`` so the result can
        always be iterated or used in membership tests.
        """
        try:
            return self.d_copy[key]
        except KeyError:
            try:
                return self.d[key]
            except KeyError:
                return default

    def _writable_set(self, key, *, create_missing=False):
        """Return the private, mutable set for ``key``, copying it on first use.

        Only the mapping lookups are guarded by ``try``; the caller's set
        mutation happens outside, so a ``KeyError`` raised by ``set.remove``
        can never be mistaken for "key not copied yet".

        When ``key`` is absent from both mappings, an empty set is created if
        ``create_missing`` is true; otherwise the ``KeyError`` propagates.
        """
        try:
            return self.d_copy[key]
        except KeyError:
            pass

        try:
            original = self.d[key]
        except KeyError:
            if not create_missing:
                raise
            private = set()
        else:
            private = original.copy()

        self.d_copy[key] = private
        return private

    def remove_from_key(self, key, value):
        """Remove ``value`` from the set at ``key`` without touching ``d``.

        Raises
        ------
        KeyError
            If ``key`` is unknown, or if ``value`` is not in the current set
            (same contract as ``set.remove``). Earlier modifications of the
            key are preserved in either case.
        """
        self._writable_set(key).remove(value)

    def add_to_key(self, key, value):
        """Add ``value`` to the set at ``key`` without touching ``d``.

        A key that is absent from the wrapped mapping starts from an empty
        set, consistent with :meth:`get` treating missing keys as empty.
        """
        self._writable_set(key, create_missing=True).add(value)
561+
562+
525563
class FusionOptimizer(GraphRewriter):
526564
"""Graph optimizer that fuses consecutive Elemwise operations."""
527565

@@ -648,15 +686,10 @@ def variables_depend_on(
648686
subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set
649687
unfuseable_clients_subgraph: set[Variable] = set()
650688

651-
# Shallow cloning of maps so that they can be manipulated in place
652-
fuseable_clients_clone: FUSEABLE_MAPPING = defaultdict(set)
653-
fuseable_clients_clone.update(
654-
{k: v.copy() for k, v in fuseable_clients.items()}
655-
)
656-
unfuseable_clients_clone: UNFUSEABLE_MAPPING = defaultdict(set)
657-
unfuseable_clients_clone.update(
658-
{k: v.copy() for k, v in unfuseable_clients.items()}
659-
)
689+
# If we need to manipulate the maps in place, we'll do a shallow copy later
690+
# For now we query on the original ones
691+
fuseable_clients_clone = CopyOnWriteDictOfSets(fuseable_clients)
692+
unfuseable_clients_clone = CopyOnWriteDictOfSets(unfuseable_clients)
660693

661694
# We now try to expand as much as possible towards the potentially
662695
# fuseable clients and ancestors to detect the largest possible
@@ -686,7 +719,7 @@ def variables_depend_on(
686719
required_unfuseable_inputs = [
687720
inp
688721
for inp in next_node.inputs
689-
if next_node in unfuseable_clients_clone.get(inp, ())
722+
if next_node in unfuseable_clients_clone.get(inp)
690723
]
691724
new_required_unfuseable_inputs = [
692725
inp
@@ -709,7 +742,7 @@ def variables_depend_on(
709742
if not must_backtrack:
710743
implied_unfuseable_clients = {
711744
c
712-
for client in unfuseable_clients_clone.get(next_out, ())
745+
for client in unfuseable_clients_clone.get(next_out)
713746
if not isinstance(client.op, Output)
714747
for c in client.outputs
715748
}
@@ -730,13 +763,15 @@ def variables_depend_on(
730763

731764
if must_backtrack:
732765
for inp in next_node.inputs:
733-
if (
734-
inp.owner in visited_nodes
735-
# next_node could have the same input repeated
736-
and next_node in fuseable_clients_clone[inp]
737-
):
738-
fuseable_clients_clone[inp].remove(next_node)
739-
unfuseable_clients_clone[inp].add(next_node)
766+
if inp.owner in visited_nodes:
767+
if next_node not in fuseable_clients_clone[inp]:
768+
# This can happen when next node has repeated inputs
769+
continue
770+
fuseable_clients_clone.remove_from_key(
771+
inp, next_node
772+
)
773+
unfuseable_clients_clone.add_to_key(inp, next_node)
774+
740775
# This input must become an output of the subgraph,
741776
# because it can't be merged with next_node.
742777
# We will revisit it to make sure this is safe.
@@ -745,8 +780,13 @@ def variables_depend_on(
745780
# need to convert to tuple not to change set size during iteration
746781
for client in tuple(fuseable_clients_clone[next_out]):
747782
if client in visited_nodes:
748-
fuseable_clients_clone[next_out].remove(client)
749-
unfuseable_clients_clone[next_out].add(client)
783+
fuseable_clients_clone.remove_from_key(
784+
next_out, client
785+
)
786+
unfuseable_clients_clone.add_to_key(
787+
next_out, client
788+
)
789+
750790
# next_out must become an input of the subgraph.
751791
# We will revisit any of its clients currently
752792
# in the subgraph to make sure this is safe.
@@ -789,7 +829,7 @@ def variables_depend_on(
789829
sorted(
790830
(
791831
node
792-
for node in fuseable_clients_clone.get(next_out, ())
832+
for node in fuseable_clients_clone.get(next_out)
793833
if node not in visited_nodes
794834
),
795835
key=toposort_index.get, # type: ignore[arg-type]

0 commit comments

Comments
 (0)