Skip to content

Commit adc74fe

Browse files
committed
Cleanup FusionOptimizer code
1 parent e2d94e8 commit adc74fe

File tree

1 file changed

+79
-94
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 79 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections import defaultdict, deque
66
from collections.abc import Generator, Sequence
77
from functools import cache, reduce
8-
from typing import TypeVar
8+
from typing import Literal
99
from warnings import warn
1010

1111
import pytensor.scalar.basic as ps
@@ -555,8 +555,6 @@ def apply(self, fgraph):
555555
callbacks_before = fgraph.execute_callbacks_times.copy()
556556
callback_before = fgraph.execute_callbacks_time
557557

558-
max_operands = elemwise_max_operands_fct(None)
559-
560558
def find_next_fuseable_subgraph(
561559
fg: FunctionGraph,
562560
) -> Generator[tuple[list[Variable], list[Variable]], None, None]:
@@ -568,8 +566,7 @@ def find_next_fuseable_subgraph(
568566
This generator assumes that such subgraph is replaced by a single
569567
Elemwise Composite before being accessed again in the next iteration.
570568
"""
571-
572-
FUSEABLE_MAPPING = defaultdict[Variable, list[Apply]]
569+
FUSEABLE_MAPPING = defaultdict[Variable, set[Apply]]
573570
UNFUSEABLE_MAPPING = defaultdict[Variable, set[Apply]]
574571

575572
def initialize_fuseable_mappings(
@@ -591,35 +588,31 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
591588
# to ensure the rewrite remains deterministic.
592589
# This is not a problem for unfuseable ones, as they can never
593590
# become part of the graph.
594-
fuseable_clients: FUSEABLE_MAPPING = defaultdict(list)
591+
fuseable_clients: FUSEABLE_MAPPING = defaultdict(set)
595592
unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set)
596593
for out, clients in fg.clients.items():
597-
# Old FunctionGraph nodes remain in the clients dictionary
598-
# even after they are removed by rewrites
599-
if not clients:
600-
continue
601-
602594
out_maybe_fuseable = (
603-
out.owner
595+
out.owner is not None
604596
and isinstance(out.owner.op, Elemwise)
605597
# and not isinstance(out.owner.op.scalar_op, ps.Composite)
606598
and len(out.owner.outputs) == 1
607599
and elemwise_scalar_op_has_c_code(out.owner)
608600
)
609-
for client, _ in clients:
610-
if (
611-
out_maybe_fuseable
612-
and isinstance(client.op, Elemwise)
613-
# and not isinstance(client.op.scalar_op, ps.Composite)
614-
and len(client.outputs) == 1
615-
and out.type.broadcastable
616-
== client.outputs[0].type.broadcastable
617-
and elemwise_scalar_op_has_c_code(client)
618-
):
619-
if client not in fuseable_clients[out]:
620-
fuseable_clients[out].append(client)
621-
else:
622-
unfuseable_clients[out].add(client)
601+
if out_maybe_fuseable:
602+
out_bcast = out.type.broadcastable
603+
for client, _ in clients:
604+
if (
605+
isinstance(client.op, Elemwise)
606+
# and not isinstance(client.op.scalar_op, ps.Composite)
607+
and len(client.outputs) == 1
608+
and out_bcast == client.outputs[0].type.broadcastable
609+
and elemwise_scalar_op_has_c_code(client)
610+
):
611+
fuseable_clients[out].add(client)
612+
else:
613+
unfuseable_clients[out].add(client)
614+
else:
615+
unfuseable_clients[out] = {client for client, _ in clients}
623616

624617
return fuseable_clients, unfuseable_clients
625618

@@ -630,16 +623,6 @@ def find_fuseable_subgraph(
630623
unfuseable_clients: UNFUSEABLE_MAPPING,
631624
toposort_index: dict[Apply, int],
632625
) -> tuple[list[Variable], list[Variable]]:
633-
KT = TypeVar("KT")
634-
VT = TypeVar("VT", list, set)
635-
636-
def shallow_clone_defaultdict(
637-
d: defaultdict[KT, VT],
638-
) -> defaultdict[KT, VT]:
639-
new_dict: defaultdict[KT, VT] = defaultdict(d.default_factory)
640-
new_dict.update({k: v.copy() for k, v in d.items()})
641-
return new_dict
642-
643626
def variables_depend_on(
644627
variables, depend_on, stop_search_at=None
645628
) -> bool:
@@ -657,17 +640,19 @@ def variables_depend_on(
657640
visited_nodes.add(starting_node)
658641
continue
659642

660-
subgraph_inputs: list[Variable] = []
661-
subgraph_outputs: list[Variable] = []
643+
subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set
644+
subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set
662645
unfuseable_clients_subgraph: set[Variable] = set()
663646

664647
# Shallow cloning of maps so that they can be manipulated in place
665-
fuseable_clients_temp = shallow_clone_defaultdict(fuseable_clients)
666-
unfuseable_clients_clone = shallow_clone_defaultdict(
667-
unfuseable_clients
648+
fuseable_clients_clone: FUSEABLE_MAPPING = defaultdict(set)
649+
fuseable_clients_clone.update(
650+
{k: v.copy() for k, v in fuseable_clients.items()}
651+
)
652+
unfuseable_clients_clone: UNFUSEABLE_MAPPING = defaultdict(set)
653+
unfuseable_clients_clone.update(
654+
{k: v.copy() for k, v in unfuseable_clients.items()}
668655
)
669-
670-
fuseable_nodes_to_visit = deque([starting_node])
671656

672657
# We now try to expand as much as possible towards the potentially
673658
# fuseable clients and ancestors to detect the largest possible
@@ -676,6 +661,7 @@ def variables_depend_on(
676661
# some inputs or clients may depend on other nodes of the same
677662
# subgraph via a path that cannot be included in the Composite
678663
# (unfuseable)
664+
fuseable_nodes_to_visit = deque([starting_node])
679665
while fuseable_nodes_to_visit:
680666
next_node = fuseable_nodes_to_visit.popleft()
681667
visited_nodes.add(next_node)
@@ -684,15 +670,14 @@ def variables_depend_on(
684670
# If the output variable of next_node has no fuseable clients
685671
# or has unfuseable clients, then next_node must become an output
686672
# if it is to be fused.
687-
must_become_output = (
688-
next_out not in fuseable_clients_temp
689-
or next_out in unfuseable_clients_clone
690-
)
673+
must_become_output = not fuseable_clients_clone.get(
674+
next_out
675+
) or unfuseable_clients_clone.get(next_out)
691676

692677
# We have backtracked to this node, and it may no longer be a viable output,
693678
# so we remove it and check again as if we had never seen this node
694-
if must_become_output and next_out in subgraph_outputs:
695-
subgraph_outputs.remove(next_out)
679+
if must_become_output:
680+
subgraph_outputs.pop(next_out, None)
696681

697682
required_unfuseable_inputs = [
698683
inp
@@ -744,18 +729,19 @@ def variables_depend_on(
744729
if (
745730
inp.owner in visited_nodes
746731
# next_node could have the same input repeated
747-
and next_node in fuseable_clients_temp[inp]
732+
and next_node in fuseable_clients_clone[inp]
748733
):
749-
fuseable_clients_temp[inp].remove(next_node)
734+
fuseable_clients_clone[inp].remove(next_node)
750735
unfuseable_clients_clone[inp].add(next_node)
751736
# This input must become an output of the subgraph,
752737
# because it can't be merged with next_node.
753738
# We will revisit it to make sure this is safe.
754739
fuseable_nodes_to_visit.appendleft(inp.owner)
755740

756-
for client in fuseable_clients_temp[next_out]:
741+
# Iterate over a tuple copy so the set does not change size during iteration
742+
for client in tuple(fuseable_clients_clone[next_out]):
757743
if client in visited_nodes:
758-
fuseable_clients_temp[next_out].remove(client)
744+
fuseable_clients_clone[next_out].remove(client)
759745
unfuseable_clients_clone[next_out].add(client)
760746
# next_out must become an input of the subgraph.
761747
# We will revisit any of its clients currently
@@ -771,74 +757,72 @@ def variables_depend_on(
771757
# mappings as if next_node was part of it.
772758
# Useless inputs will be removed by the useless Composite rewrite
773759
for inp in new_required_unfuseable_inputs:
774-
if inp not in subgraph_inputs:
775-
subgraph_inputs.append(inp)
760+
subgraph_inputs[inp] = None
776761

777762
if must_become_output:
778-
subgraph_outputs.append(next_out)
763+
subgraph_outputs[next_out] = None
779764
unfuseable_clients_subgraph.update(
780765
new_implied_unfuseable_clients
781766
)
782767

783768
# Expand through unvisited fuseable ancestors
784-
for inp in sorted(
785-
(
786-
inp
787-
for inp in next_node.inputs
788-
if (
789-
inp not in required_unfuseable_inputs
790-
and inp.owner not in visited_nodes
791-
)
792-
),
793-
key=lambda inp: toposort_index[inp.owner],
794-
reverse=True,
795-
):
796-
fuseable_nodes_to_visit.appendleft(inp.owner)
769+
fuseable_nodes_to_visit.extendleft(
770+
sorted(
771+
(
772+
inp.owner
773+
for inp in next_node.inputs
774+
if (
775+
inp not in required_unfuseable_inputs
776+
and inp.owner not in visited_nodes
777+
)
778+
),
779+
key=toposort_index.get, # type: ignore[arg-type]
780+
)
781+
)
797782

798783
# Expand through unvisited fuseable clients
799-
for next_node in sorted(
800-
(
801-
node
802-
for node in fuseable_clients_temp.get(next_out, ())
803-
if node not in visited_nodes
804-
),
805-
key=lambda node: toposort_index[node],
806-
):
807-
fuseable_nodes_to_visit.append(next_node)
784+
fuseable_nodes_to_visit.extend(
785+
sorted(
786+
(
787+
node
788+
for node in fuseable_clients_clone.get(next_out, ())
789+
if node not in visited_nodes
790+
),
791+
key=toposort_index.get, # type: ignore[arg-type]
792+
)
793+
)
808794

809795
# Don't return if final subgraph is just the original Elemwise
810796
if len(subgraph_outputs) == 1 and set(
811-
subgraph_outputs[0].owner.inputs
797+
next(iter(subgraph_outputs)).owner.inputs
812798
) == set(subgraph_inputs):
813799
# Update global fuseable mappings
814800
# No input was actually fuseable
815801
for inp in starting_node.inputs:
816-
if starting_node in fuseable_clients.get(inp, ()):
817-
fuseable_clients[inp].remove(starting_node)
818-
unfuseable_clients[inp].add(starting_node)
802+
fuseable_clients[inp].discard(starting_node)
803+
unfuseable_clients[inp].add(starting_node)
819804
# No client was actually fuseable
820805
unfuseable_clients[starting_out].update(
821806
fuseable_clients.pop(starting_out, ())
822807
)
823808
continue
824809

825-
return subgraph_inputs, subgraph_outputs
810+
return list(subgraph_inputs), list(subgraph_outputs)
826811
raise ValueError
827812

828813
def update_fuseable_mappings_after_fg_replace(
829814
*,
830-
fg: FunctionGraph,
831815
visited_nodes: set[Apply],
832816
fuseable_clients: FUSEABLE_MAPPING,
833817
unfuseable_clients: UNFUSEABLE_MAPPING,
834818
starting_nodes: set[Apply],
819+
updated_nodes: set[Apply],
835820
) -> None:
836821
# Find new composite node and dropped intermediate nodes
837822
# by comparing the current fg.apply nodes with the cached
838823
# original nodes
839-
next_nodes = fg.apply_nodes
840-
(new_composite_node,) = next_nodes - starting_nodes
841-
dropped_nodes = starting_nodes - next_nodes
824+
(new_composite_node,) = updated_nodes - starting_nodes
825+
dropped_nodes = starting_nodes - updated_nodes
842826

843827
# Remove intermediate Composite nodes from mappings
844828
for dropped_node in dropped_nodes:
@@ -850,11 +834,11 @@ def update_fuseable_mappings_after_fg_replace(
850834
# Update fuseable information for subgraph inputs
851835
for inp in subgraph_inputs:
852836
if inp in fuseable_clients:
853-
new_fuseable_clients = [
837+
new_fuseable_clients = {
854838
client
855839
for client in fuseable_clients[inp]
856840
if client not in dropped_nodes
857-
]
841+
}
858842
if new_fuseable_clients:
859843
fuseable_clients[inp] = new_fuseable_clients
860844
else:
@@ -898,13 +882,15 @@ def update_fuseable_mappings_after_fg_replace(
898882
# generator. For large models (as in `TestFusion.test_big_fusion`)
899883
# this can provide huge speedups
900884
update_fuseable_mappings_after_fg_replace(
901-
fg=fg,
902885
visited_nodes=visited_nodes,
903886
fuseable_clients=fuseable_clients,
904887
unfuseable_clients=unfuseable_clients,
905888
starting_nodes=starting_nodes,
889+
updated_nodes=fg.apply_nodes,
906890
)
907891

892+
max_operands = elemwise_max_operands_fct(None)
893+
reason = self.__class__.__name__
908894
nb_fused = 0
909895
nb_replacement = 0
910896
for inputs, outputs in find_next_fuseable_subgraph(fgraph):
@@ -923,13 +909,12 @@ def update_fuseable_mappings_after_fg_replace(
923909
assert len(outputs) == len(composite_outputs)
924910
for old_out, composite_out in zip(outputs, composite_outputs):
925911
# Preserve any names on the original outputs
926-
if old_out.name:
927-
composite_out.name = old_out.name
912+
if old_name := old_out.name:
913+
composite_out.name = old_name
928914

929915
starting_nodes = len(fgraph.apply_nodes)
930916
fgraph.replace_all_validate(
931-
list(zip(outputs, composite_outputs, strict=True)),
932-
reason=self.__class__.__name__,
917+
tuple(zip(outputs, composite_outputs)), reason=reason
933918
)
934919
nb_fused += 1
935920
nb_replacement += (starting_nodes - len(fgraph.apply_nodes)) + 1

0 commit comments

Comments
 (0)