Skip to content

Commit 5b9baec

Browse files
committed
Regular sets instead of bitsets
1 parent eb010b7 commit 5b9baec

File tree

1 file changed

+46
-72
lines changed

1 file changed

+46
-72
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 46 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from collections.abc import Generator, Sequence
66
from functools import cache, reduce
77
from heapq import heapify, heappop, heappush
8-
from operator import or_
98
from warnings import warn
109

1110
import pytensor.scalar.basic as ps
@@ -16,7 +15,7 @@
1615
from pytensor.graph.basic import Apply, Variable
1716
from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates
1817
from pytensor.graph.features import ReplaceValidate
19-
from pytensor.graph.fg import FunctionGraph, Output
18+
from pytensor.graph.fg import FunctionGraph
2019
from pytensor.graph.op import Op
2120
from pytensor.graph.rewriting.basic import (
2221
GraphRewriter,
@@ -621,48 +620,28 @@ def elemwise_scalar_op_has_c_code(
621620
if not fuseable_clients:
622621
return None
623622

624-
# Create a bitset of ancestors for each node.
625-
# Each node is represented by a bit flag of its position in the toposort
626-
# With two variables {a, b, c} owned by nodes {A, B, C}, where a is an input of b, and b an input of c,
627-
# the ancestors bit flags would be {A: 0b001, B: 0b010, C: 0b100}
628-
# and the ancestors bitset would be {A: 0b001, B: 0b011, C: 0b111}
629-
# This allows us to quickly ask if one or more variables are ancestors of a node by a simple bitwise AND
630-
# For example, to ask if B is an ancestor of C we can do `ancestors_bitset[C] & node_bitset[B] != 0`
631-
# We can also easily handle multiple nodes at once, for example to ask if A or B are ancestors of C we can do
632-
# `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0`
633-
nodes_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())}
634-
ancestors_bitset = {
635-
None: 0
636-
} # Root variables have `None` as owner, which we can handle with a bitset of 0 for `None`
637-
for node, node_bitflag in nodes_bitflags.items():
638-
# The bitset of each node is the union of the bitsets of its inputs, plus its own bit
639-
ancestors_bitset[node] = reduce(
640-
or_,
641-
(ancestors_bitset[inp.owner] for inp in node.inputs),
642-
node_bitflag,
623+
toposort_idx = {
624+
node: idx for idx, node in enumerate(fg.toposort(), start=1)
625+
}
626+
node_ancestors = {None: frozenset()}
627+
for node in toposort_idx:
628+
node_ancestors[node] = frozenset.union(
629+
*(node_ancestors[inp.owner] for inp in node.inputs), {node}
643630
)
644-
# handle root and leaf nodes gracefully
645-
nodes_bitflags[None] = (
646-
0 # Root variables have `None` as owner, which we can handle with a bitflag of 0 for `None`
647-
)
648-
out_bitflag = 1 << len(
649-
nodes_bitflags
650-
) # Nothing ever depends on output nodes, so just use a new bit for all
651-
for out in fg.outputs:
652-
for client, _ in fg_clients[out]:
653-
if isinstance(client.op, Output):
654-
nodes_bitflags[client] = out_bitflag
655631

656632
sorted_subgraphs: list[
657633
tuple[int, tuple[tuple[Variable], tuple[Variable]]]
658634
] = []
659-
all_subgraphs_bitset = 0
635+
subgraph_set = set()
636+
unfuseable_ancestors_set = set()
637+
unfuseable_clients_set = set()
638+
all_subgraphs_set = set()
660639
# Start exploring from candidate sink nodes (backwards)
661640
# These are Elemwise nodes with a C-implementation, that are not part of another subgraph
662641
# And have no other fuseable clients (i.e., are sinks)
663-
for starting_node, starting_bitflag in reversed(nodes_bitflags.items()):
642+
for starting_node, starting_index in reversed(toposort_idx.items()):
664643
if (
665-
starting_bitflag & all_subgraphs_bitset
644+
starting_node in all_subgraphs_set
666645
or starting_node not in candidate_nodes
667646
):
668647
continue
@@ -676,7 +655,7 @@ def elemwise_scalar_op_has_c_code(
676655
# For ancestors, we want to visit the later nodes first (those that have more dependencies)
677656
# whereas for clients we want to visit earlier nodes first (those that have fewer dependencies)
678657
# We negate the toposort index for ancestors to achieve this ordering
679-
fuseables_nodes_queue = [(-starting_bitflag, starting_node)]
658+
fuseables_nodes_queue = [(-starting_index, starting_node)]
680659
heapify(fuseables_nodes_queue)
681660

682661
# We keep 3 sets during the exploration:
@@ -687,37 +666,35 @@ def elemwise_scalar_op_has_c_code(
687666
# in which case we can't fuse it. If we can fuse it, we then add its unfuseable ancestors/clients to the respective sets
688667
# and add its fuseable ancestors/clients to the queue to explore later. This approach requires a visit in the order described above.
689668
# Otherwise, we need to recompute target sets in every iteration and/or backtrack.
690-
subgraph_nodes = []
691-
subgraph_bitset = 0
692-
unfuseable_ancestors_bitset = 0
693-
unfuseable_clients_bitset = 0
669+
subgraph_set.clear()
670+
unfuseable_ancestors_set.clear()
671+
unfuseable_clients_set.clear()
694672

695673
# print(f"\nStarting new subgraph exploration from {starting_node}")
696674
while fuseables_nodes_queue:
697-
node_bitflag, node = heappop(fuseables_nodes_queue)
698-
is_ancestor = node_bitflag < 0
675+
node_idx, node = heappop(fuseables_nodes_queue)
676+
is_ancestor = node_idx < 0
699677
if is_ancestor:
700-
node_bitflag = -node_bitflag
678+
node_idx = -node_idx
701679
# print(f"\t > Visiting {'ancestor' if is_ancestor else 'client'} {next_node}")
702680

703-
if node_bitflag & subgraph_bitset:
681+
if node in subgraph_set:
704682
# Already part of the subgraph
705683
# print("\t - already in subgraph")
706684
continue
707685

708686
if is_ancestor:
709-
if node_bitflag & unfuseable_ancestors_bitset:
687+
if node in unfuseable_ancestors_set:
710688
# An unfuseable ancestor depends on this node, can't fuse
711689
# print("\t failed - unfuseable ancestor depends on it")
712690
continue
713-
elif ancestors_bitset[node] & unfuseable_clients_bitset:
691+
elif not node_ancestors[node].isdisjoint(unfuseable_clients_set):
714692
# This node depends on an unfuseable client, can't fuse
715693
# print("\t failed - depends on unfuseable client")
716694
continue
717695

718696
# print("\t succeeded - adding to subgraph")
719-
subgraph_nodes.append(node)
720-
subgraph_bitset |= node_bitflag
697+
subgraph_set.add(node)
721698

722699
# Expand through ancestors and client nodes
723700
# A node can either be:
@@ -726,57 +703,56 @@ def elemwise_scalar_op_has_c_code(
726703
# - unfuseable (add to respective unfuseable set)
727704
for ancestor in node.inputs:
728705
ancestor_node = ancestor.owner
729-
ancestor_bitflag = nodes_bitflags[ancestor_node]
730-
if ancestor_bitflag & subgraph_bitset:
706+
if ancestor_node in subgraph_set:
731707
continue
732708
if node in fuseable_clients.get(ancestor_node, ()):
733709
heappush(
734710
fuseables_nodes_queue,
735-
(-ancestor_bitflag, ancestor_node),
711+
(-toposort_idx[ancestor_node], ancestor_node),
736712
)
737713
else:
738714
# If an ancestor is unfuseable, so are all its ancestors
739-
unfuseable_ancestors_bitset |= ancestors_bitset[
740-
ancestor_node
741-
]
715+
unfuseable_ancestors_set |= node_ancestors[ancestor_node]
742716

743717
next_fuseable_clients = fuseable_clients.get(node, ())
744718
for client, _ in fg_clients[node.outputs[0]]:
745-
client_bitflag = nodes_bitflags[client]
746-
if client_bitflag & subgraph_bitset:
719+
if client in subgraph_set:
747720
continue
748721
if client in next_fuseable_clients:
749-
heappush(fuseables_nodes_queue, (client_bitflag, client))
722+
heappush(
723+
fuseables_nodes_queue, (toposort_idx[client], client)
724+
)
750725
else:
751726
# If a client is unfuseable, so are all its clients, but we don't need to keep track of those
752727
# Any downstream client will also depend on this unfuseable client and will be rejected when visited
753-
unfuseable_clients_bitset |= client_bitflag
728+
unfuseable_clients_set.add(client)
754729

755730
# Finished exploring this subgraph
756-
all_subgraphs_bitset |= subgraph_bitset
731+
all_subgraphs_set |= subgraph_set
757732

758-
if subgraph_bitset == starting_bitflag:
733+
if len(subgraph_set) == 1:
759734
# No fusion possible, single node subgraph
760735
continue
761736

762737
# Find out inputs/outputs of subgraph_nodes
763-
not_subgraph_bitset = ~subgraph_bitset
738+
# not_subgraph_bitset = ~subgraph_set
764739
# Use a dict to deduplicate while preserving order
740+
subgraph_nodes = sorted(subgraph_set, key=toposort_idx.get)
765741
subgraph_inputs = tuple(
766742
dict.fromkeys(
767743
inp
768744
for node in subgraph_nodes
769745
for inp in node.inputs
770746
if (ancestor_node := inp.owner) is None
771-
or nodes_bitflags[ancestor_node] & not_subgraph_bitset
747+
or ancestor_node not in subgraph_set
772748
)
773749
)
774750

775751
subgraph_outputs = tuple(
776752
node.outputs[0]
777753
for node in subgraph_nodes
778754
if any(
779-
nodes_bitflags[client] & not_subgraph_bitset
755+
client not in subgraph_set
780756
for client, _ in fg_clients[node.outputs[0]]
781757
)
782758
)
@@ -786,26 +762,24 @@ def elemwise_scalar_op_has_c_code(
786762

787763
# Usually new subgraphs don't depend on previous subgraphs, so we can just append them at the end
788764
# But in some cases they can, so we need to insert at the right position.
789-
if not (unfuseable_ancestors_bitset & all_subgraphs_bitset):
765+
if not (unfuseable_ancestors_set & all_subgraphs_set):
790766
sorted_subgraphs.append(
791-
(subgraph_bitset, (subgraph_inputs, subgraph_outputs))
767+
(subgraph_set, (subgraph_inputs, subgraph_outputs))
792768
)
793769
else:
794770
# Iterate from the end, removing the nodes of each previous subgraph until our current subgraph
795771
# no longer depends on what's left. This tells us where to insert the current subgraph.
796-
remaining_subgraphs_bitset = all_subgraphs_bitset
797-
for index, (other_subgraph_bitset, _) in enumerate(
772+
remaining_subgraphs_bitset = all_subgraphs_set.copy()
773+
for index, (other_subgraph_set, _) in enumerate(
798774
reversed(sorted_subgraphs)
799775
):
800-
remaining_subgraphs_bitset &= ~other_subgraph_bitset
801-
if not (
802-
unfuseable_ancestors_bitset & remaining_subgraphs_bitset
803-
):
776+
remaining_subgraphs_bitset.difference_update(other_subgraph_set)
777+
if not (unfuseable_ancestors_set & remaining_subgraphs_bitset):
804778
break
805779

806780
sorted_subgraphs.insert(
807781
-(index + 1),
808-
(subgraph_bitset, (subgraph_inputs, subgraph_outputs)),
782+
(subgraph_set, (subgraph_inputs, subgraph_outputs)),
809783
)
810784

811785
# Update fuseable clients, inputs can no longer be fused with graph variables

0 commit comments

Comments
 (0)