Skip to content

Commit 824af00

Browse files
committed
Cleanup FusionOptimizer code
1 parent 9baa8a4 commit 824af00

File tree

1 file changed

+76
-88
lines changed

1 file changed

+76
-88
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 76 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections import defaultdict, deque
66
from collections.abc import Generator, Sequence
77
from functools import cache, reduce
8-
from typing import TypeVar
8+
from typing import Literal
99
from warnings import warn
1010

1111
import pytensor.scalar.basic as ps
@@ -568,8 +568,7 @@ def find_next_fuseable_subgraph(
568568
This generator assumes that such subgraph is replaced by a single
569569
Elemwise Composite before being accessed again in the next iteration.
570570
"""
571-
572-
FUSEABLE_MAPPING = defaultdict[Variable, list[Apply]]
571+
FUSEABLE_MAPPING = defaultdict[Variable, set[Apply]]
573572
UNFUSEABLE_MAPPING = defaultdict[Variable, set[Apply]]
574573

575574
def initialize_fuseable_mappings(
@@ -591,35 +590,33 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
591590
# to ensure the rewrite remains deterministic.
592591
# This is not a problem for unfuseable ones, as they can never
593592
# become part of the graph.
594-
fuseable_clients: FUSEABLE_MAPPING = defaultdict(list)
593+
fuseable_clients: FUSEABLE_MAPPING = defaultdict(set)
595594
unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set)
596595
for out, clients in fg.clients.items():
597-
# Old FunctionGraph nodes remain in the clients dictionary
598-
# even after they are removed by rewrites
599-
if not clients:
600-
continue
601-
602596
out_maybe_fuseable = (
603-
out.owner
597+
out.owner is not None
604598
and isinstance(out.owner.op, Elemwise)
605599
# and not isinstance(out.owner.op.scalar_op, ps.Composite)
606600
and len(out.owner.outputs) == 1
607601
and elemwise_scalar_op_has_c_code(out.owner)
608602
)
609-
for client, _ in clients:
610-
if (
611-
out_maybe_fuseable
612-
and isinstance(client.op, Elemwise)
613-
# and not isinstance(client.op.scalar_op, ps.Composite)
614-
and len(client.outputs) == 1
615-
and out.type.broadcastable
616-
== client.outputs[0].type.broadcastable
617-
and elemwise_scalar_op_has_c_code(client)
618-
):
619-
if client not in fuseable_clients[out]:
620-
fuseable_clients[out].append(client)
621-
else:
622-
unfuseable_clients[out].add(client)
603+
if out_maybe_fuseable:
604+
out_bcast = (
605+
out.type.broadcastable if out_maybe_fuseable else None
606+
)
607+
for client, _ in clients:
608+
if (
609+
isinstance(client.op, Elemwise)
610+
# and not isinstance(client.op.scalar_op, ps.Composite)
611+
and len(client.outputs) == 1
612+
and out_bcast == client.outputs[0].type.broadcastable
613+
and elemwise_scalar_op_has_c_code(client)
614+
):
615+
fuseable_clients[out].add(client)
616+
else:
617+
unfuseable_clients[out].add(client)
618+
else:
619+
unfuseable_clients[out] = {client for client, _ in clients}
623620

624621
return fuseable_clients, unfuseable_clients
625622

@@ -630,16 +627,6 @@ def find_fuseable_subgraph(
630627
unfuseable_clients: UNFUSEABLE_MAPPING,
631628
toposort_index: dict[Apply, int],
632629
) -> tuple[list[Variable], list[Variable]]:
633-
KT = TypeVar("KT")
634-
VT = TypeVar("VT", list, set)
635-
636-
def shallow_clone_defaultdict(
637-
d: defaultdict[KT, VT],
638-
) -> defaultdict[KT, VT]:
639-
new_dict: defaultdict[KT, VT] = defaultdict(d.default_factory)
640-
new_dict.update({k: v.copy() for k, v in d.items()})
641-
return new_dict
642-
643630
def variables_depend_on(
644631
variables, depend_on, stop_search_at=None
645632
) -> bool:
@@ -657,17 +644,19 @@ def variables_depend_on(
657644
visited_nodes.add(starting_node)
658645
continue
659646

660-
subgraph_inputs: list[Variable] = []
661-
subgraph_outputs: list[Variable] = []
647+
subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set
648+
subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set
662649
unfuseable_clients_subgraph: set[Variable] = set()
663650

664651
# Shallow cloning of maps so that they can be manipulated in place
665-
fuseable_clients_temp = shallow_clone_defaultdict(fuseable_clients)
666-
unfuseable_clients_clone = shallow_clone_defaultdict(
667-
unfuseable_clients
652+
fuseable_clients_clone: FUSEABLE_MAPPING = defaultdict(set)
653+
fuseable_clients_clone.update(
654+
{k: v.copy() for k, v in fuseable_clients.items()}
655+
)
656+
unfuseable_clients_clone: UNFUSEABLE_MAPPING = defaultdict(set)
657+
unfuseable_clients_clone.update(
658+
{k: v.copy() for k, v in unfuseable_clients.items()}
668659
)
669-
670-
fuseable_nodes_to_visit = deque([starting_node])
671660

672661
# We now try to expand as much as possible towards the potentially
673662
# fuseable clients and ancestors to detect the largest possible
@@ -676,6 +665,7 @@ def variables_depend_on(
676665
# some inputs or clients may depend on other nodes of the same
677666
# subgraph via a path that cannot be included in the Composite
678667
# (unfuseable)
668+
fuseable_nodes_to_visit = deque([starting_node])
679669
while fuseable_nodes_to_visit:
680670
next_node = fuseable_nodes_to_visit.popleft()
681671
visited_nodes.add(next_node)
@@ -684,15 +674,14 @@ def variables_depend_on(
684674
# If the output variable of next_node has no fuseable clients
685675
# or has unfuseable clients, then next_node must become an output
686676
# if it is to be fused.
687-
must_become_output = (
688-
next_out not in fuseable_clients_temp
689-
or next_out in unfuseable_clients_clone
690-
)
677+
must_become_output = not fuseable_clients_clone.get(
678+
next_out
679+
) or unfuseable_clients_clone.get(next_out)
691680

692681
# We have backtracked to this node, and it may no longer be a viable output,
693682
# so we remove it and check again as if we had never seen this node
694-
if must_become_output and next_out in subgraph_outputs:
695-
subgraph_outputs.remove(next_out)
683+
if must_become_output:
684+
subgraph_outputs.pop(next_out, None)
696685

697686
required_unfuseable_inputs = [
698687
inp
@@ -744,18 +733,19 @@ def variables_depend_on(
744733
if (
745734
inp.owner in visited_nodes
746735
# next_node could have the same input repeated
747-
and next_node in fuseable_clients_temp[inp]
736+
and next_node in fuseable_clients_clone[inp]
748737
):
749-
fuseable_clients_temp[inp].remove(next_node)
738+
fuseable_clients_clone[inp].remove(next_node)
750739
unfuseable_clients_clone[inp].add(next_node)
751740
# This input must become an output of the subgraph,
752741
# because it can't be merged with next_node.
753742
# We will revisit it to make sure this is safe.
754743
fuseable_nodes_to_visit.appendleft(inp.owner)
755744

756-
for client in fuseable_clients_temp[next_out]:
745+
# Iterate over a tuple snapshot, because the set is mutated inside the loop
746+
for client in tuple(fuseable_clients_clone[next_out]):
757747
if client in visited_nodes:
758-
fuseable_clients_temp[next_out].remove(client)
748+
fuseable_clients_clone[next_out].remove(client)
759749
unfuseable_clients_clone[next_out].add(client)
760750
# next_out must become an input of the subgraph.
761751
# We will revisit any of its clients currently
@@ -771,74 +761,72 @@ def variables_depend_on(
771761
# mappings as if next_node was part of it.
772762
# Useless inputs will be removed by the useless Composite rewrite
773763
for inp in new_required_unfuseable_inputs:
774-
if inp not in subgraph_inputs:
775-
subgraph_inputs.append(inp)
764+
subgraph_inputs[inp] = None
776765

777766
if must_become_output:
778-
subgraph_outputs.append(next_out)
767+
subgraph_outputs[next_out] = None
779768
unfuseable_clients_subgraph.update(
780769
new_implied_unfuseable_clients
781770
)
782771

783772
# Expand through unvisited fuseable ancestors
784-
for inp in sorted(
785-
(
786-
inp
787-
for inp in next_node.inputs
788-
if (
789-
inp not in required_unfuseable_inputs
790-
and inp.owner not in visited_nodes
791-
)
792-
),
793-
key=lambda inp: toposort_index[inp.owner],
794-
reverse=True,
795-
):
796-
fuseable_nodes_to_visit.appendleft(inp.owner)
773+
fuseable_nodes_to_visit.extendleft(
774+
sorted(
775+
(
776+
inp.owner
777+
for inp in next_node.inputs
778+
if (
779+
inp not in required_unfuseable_inputs
780+
and inp.owner not in visited_nodes
781+
)
782+
),
783+
key=toposort_index.get, # type: ignore[arg-type]
784+
)
785+
)
797786

798787
# Expand through unvisited fuseable clients
799-
for next_node in sorted(
800-
(
801-
node
802-
for node in fuseable_clients_temp.get(next_out, ())
803-
if node not in visited_nodes
804-
),
805-
key=lambda node: toposort_index[node],
806-
):
807-
fuseable_nodes_to_visit.append(next_node)
788+
fuseable_nodes_to_visit.extend(
789+
sorted(
790+
(
791+
node
792+
for node in fuseable_clients_clone.get(next_out, ())
793+
if node not in visited_nodes
794+
),
795+
key=toposort_index.get, # type: ignore[arg-type]
796+
)
797+
)
808798

809799
# Don't return if final subgraph is just the original Elemwise
810800
if len(subgraph_outputs) == 1 and set(
811-
subgraph_outputs[0].owner.inputs
801+
next(iter(subgraph_outputs)).owner.inputs
812802
) == set(subgraph_inputs):
813803
# Update global fuseable mappings
814804
# No input was actually fuseable
815805
for inp in starting_node.inputs:
816-
if starting_node in fuseable_clients.get(inp, ()):
817-
fuseable_clients[inp].remove(starting_node)
818-
unfuseable_clients[inp].add(starting_node)
806+
fuseable_clients[inp].discard(starting_node)
807+
unfuseable_clients[inp].add(starting_node)
819808
# No client was actually fuseable
820809
unfuseable_clients[starting_out].update(
821810
fuseable_clients.pop(starting_out, ())
822811
)
823812
continue
824813

825-
return subgraph_inputs, subgraph_outputs
814+
return list(subgraph_inputs), list(subgraph_outputs)
826815
raise ValueError
827816

828817
def update_fuseable_mappings_after_fg_replace(
829818
*,
830-
fg: FunctionGraph,
831819
visited_nodes: set[Apply],
832820
fuseable_clients: FUSEABLE_MAPPING,
833821
unfuseable_clients: UNFUSEABLE_MAPPING,
834822
starting_nodes: set[Apply],
823+
updated_nodes: set[Apply],
835824
) -> None:
836825
# Find new composite node and dropped intermediate nodes
837826
# by comparing `updated_nodes` (the current fg.apply_nodes) with the cached
838827
# original nodes
839-
next_nodes = fg.apply_nodes
840-
(new_composite_node,) = next_nodes - starting_nodes
841-
dropped_nodes = starting_nodes - next_nodes
828+
(new_composite_node,) = updated_nodes - starting_nodes
829+
dropped_nodes = starting_nodes - updated_nodes
842830

843831
# Remove intermediate Composite nodes from mappings
844832
for dropped_node in dropped_nodes:
@@ -850,11 +838,11 @@ def update_fuseable_mappings_after_fg_replace(
850838
# Update fuseable information for subgraph inputs
851839
for inp in subgraph_inputs:
852840
if inp in fuseable_clients:
853-
new_fuseable_clients = [
841+
new_fuseable_clients = {
854842
client
855843
for client in fuseable_clients[inp]
856844
if client not in dropped_nodes
857-
]
845+
}
858846
if new_fuseable_clients:
859847
fuseable_clients[inp] = new_fuseable_clients
860848
else:
@@ -898,11 +886,11 @@ def update_fuseable_mappings_after_fg_replace(
898886
# generator. For large models (as in `TestFusion.test_big_fusion`)
899887
# this can provide huge speedups
900888
update_fuseable_mappings_after_fg_replace(
901-
fg=fg,
902889
visited_nodes=visited_nodes,
903890
fuseable_clients=fuseable_clients,
904891
unfuseable_clients=unfuseable_clients,
905892
starting_nodes=starting_nodes,
893+
updated_nodes=fg.apply_nodes,
906894
)
907895

908896
nb_fused = 0

0 commit comments

Comments
 (0)