Skip to content

Commit f2683c9

Browse files
committed
Try helper classes with bitset
1 parent 9820416 commit f2683c9

File tree

1 file changed

+77
-47
lines changed

1 file changed

+77
-47
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 77 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from collections.abc import Generator, Sequence
66
from functools import cache, reduce
77
from heapq import heapify, heappop, heappush
8+
from operator import or_
89
from warnings import warn
910

1011
import pytensor.scalar.basic as ps
@@ -15,7 +16,7 @@
1516
from pytensor.graph.basic import Apply, Variable
1617
from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates
1718
from pytensor.graph.features import ReplaceValidate
18-
from pytensor.graph.fg import FunctionGraph
19+
from pytensor.graph.fg import FunctionGraph, Output
1920
from pytensor.graph.op import Op
2021
from pytensor.graph.rewriting.basic import (
2122
GraphRewriter,
@@ -667,14 +668,14 @@ def __init__(self):
667668
self.queue = queue = []
668669
heapify(queue)
669670

670-
def push(self, node: Apply, toposort_index: int, is_ancestor: bool):
671+
def push(self, node: Apply, node_bitflag: int, is_ancestor: bool):
671672
if is_ancestor:
672-
toposort_index = -toposort_index
673-
heappush(self.queue, (toposort_index, node))
673+
node_bitflag = -node_bitflag
674+
heappush(self.queue, (node_bitflag, node))
674675

675-
def pop(self) -> tuple[Apply, bool]:
676-
toposort_index, node = heappop(self.queue)
677-
return node, toposort_index < 0
676+
def pop(self) -> tuple[Apply, int, bool]:
677+
node_bitflag, node = heappop(self.queue)
678+
return node, node_bitflag < 0
678679

679680
def __bool__(self):
680681
return bool(self.queue)
@@ -684,45 +685,49 @@ class NonConvexError(Exception):
684685

685686
class ConvexSubgraph:
686687
__slots__ = (
687-
"node_ancestors",
688+
"nodes_bitflags",
689+
"ancestors_bitset",
688690
"nodes",
689-
"bitset",
690-
"unfuseable_ancestors",
691-
"unfuseable_clients",
691+
"nodes_bitset",
692+
"unfuseable_ancestors_bitset",
693+
"unfuseable_clients_bitset",
692694
"inputs_and_outputs",
693695
)
694696

695-
def __init__(self, node_ancestors):
696-
self.node_ancestors = node_ancestors
697+
def __init__(self, nodes_bitflags, ancestors_bitset):
698+
self.nodes_bitflags = nodes_bitflags
699+
self.ancestors_bitset = ancestors_bitset
697700
self.nodes = {}
698-
self.bitset = 0
699-
self.unfuseable_ancestors = set()
700-
self.unfuseable_clients = set()
701+
self.nodes_bitset = 0
702+
self.unfuseable_ancestors_bitset = 0
703+
self.unfuseable_clients_bitset = 0
701704
self.inputs_and_outputs = None
702705

703706
def __len__(self):
704707
return len(self.nodes)
705708

706-
def __contains__(self, node: Apply):
707-
return node in self.nodes
709+
def __contains__(self, node: int):
710+
return bool(self.nodes_bitset & self.nodes_bitflags[node])
708711

709712
def add_node(self, node: Apply, is_ancestor: bool):
713+
node_bitflag = self.nodes_bitflags[node]
710714
if is_ancestor:
711-
if node in self.unfuseable_ancestors:
715+
if node_bitflag & self.unfuseable_ancestors_bitset:
712716
raise NonConvexError
713-
elif self.node_ancestors[node] & self.unfuseable_clients:
717+
elif self.ancestors_bitset[node] & self.unfuseable_clients_bitset:
714718
raise NonConvexError
719+
self.nodes_bitset |= node_bitflag
715720
self.nodes[node] = None
716721
self.inputs_and_outputs = None # clear cache
717722

718723
def add_unfuseable_ancestor(self, ancestor: Apply):
719724
# If an ancestor is unfuseable, so are all its ancestors
720-
self.unfuseable_ancestors |= self.node_ancestors[ancestor]
725+
self.unfuseable_ancestors_bitset |= self.ancestors_bitset[ancestor]
721726

722727
def add_unfuseable_client(self, client: Apply):
723728
# If a client is unfuseable, so are all its clients, but we don't need to keep track of those
724729
# Any downstream client will also depend on this unfuseable client and will be rejected when visited
725-
self.unfuseable_clients.add(client)
730+
self.unfuseable_clients_bitset |= self.nodes_bitflags[client]
726731

727732
def get_inputs_and_outputs(self):
728733
if self.inputs_and_outputs is not None:
@@ -752,37 +757,40 @@ def get_inputs_and_outputs(self):
752757
return subgraph_inputs, subgraph_outputs
753758

754759
class SortedSubgraphCollection:
755-
__slots__ = ("subgraphs", "nodes")
760+
__slots__ = ("subgraphs", "nodes_bitset")
756761

757762
def __init__(self):
758763
self.subgraphs: list[
759764
tuple[int, tuple[tuple[Variable], tuple[Variable]]]
760765
] = []
761-
self.nodes = {}
766+
self.nodes_bitset = 0
762767

763-
def __contains__(self, node: Apply):
764-
return node in self.nodes
768+
def __contains__(self, node_bitflag: int):
769+
return bool(node_bitflag & self.nodes_bitset)
765770

766771
def insert_subgraph(self, subgraph: ConvexSubgraph):
767772
# Usually new subgraphs don't depend on previous subgraphs, so we can just append them at the end
768773
# But in some cases they can, so we need to insert at the right position.
769-
subgraph_unfuseable_ancestors = subgraph.unfuseable_ancestors
770-
if subgraph_unfuseable_ancestors.isdisjoint(self.nodes):
774+
subgraph_unfuseable_ancestors_bitset = (
775+
subgraph.unfuseable_ancestors_bitset
776+
)
777+
if not (subgraph_unfuseable_ancestors_bitset & self.nodes_bitset):
771778
self.subgraphs.append(subgraph)
772779
else:
773780
# Iterate from the end, removing the bitset of each previous subgraph until our current subgraph
774781
# no longer depends on what's left. This tells us where to insert the current subgraph.
775-
remaining_nodes = set(self.nodes)
782+
remaining_nodes_bitset = self.nodes_bitset
776783
for index, other_subgraph in enumerate(
777784
reversed(self.subgraphs)
778785
):
779-
remaining_nodes.difference_update(other_subgraph.nodes)
780-
if subgraph_unfuseable_ancestors.isdisjoint(
781-
remaining_nodes
786+
remaining_nodes_bitset &= ~other_subgraph.nodes_bitset
787+
if not (
788+
subgraph_unfuseable_ancestors_bitset
789+
& remaining_nodes_bitset
782790
):
783791
break
784792
self.subgraphs.insert(-(index + 1), subgraph)
785-
self.nodes |= subgraph.nodes
793+
self.nodes_bitset |= subgraph.nodes_bitset
786794

787795
def __iter__(self):
788796
yield from self.subgraphs
@@ -792,32 +800,54 @@ def __iter__(self):
792800
if not fuseable_clients:
793801
return None
794802

795-
toposort_idx = {
796-
node: idx for idx, node in enumerate(fg.toposort(), start=1)
797-
}
798-
node_ancestors = {None: frozenset()}
799-
for node in toposort_idx:
800-
node_ancestors[node] = frozenset.union(
801-
*(node_ancestors[inp.owner] for inp in node.inputs), {node}
803+
# Create a bitset of ancestors for each node.
804+
# Each node is represented by a bit flag of its position in the toposort
805+
# With three variables {a, b, c} owned by nodes {A, B, C}, where a is an input of b, and b an input of c,
806+
# the nodes bit flags would be {A: 0b001, B: 0b010, C: 0b100}
807+
# and the ancestors bitset would be {A: 0b001, B: 0b011, C: 0b111}
808+
# This allows us to quickly ask if one or more variables are ancestors of a node by a simple bitwise AND
809+
# For example, to ask if B is an ancestor of C we can do `ancestors_bitset[C] & nodes_bitflags[B] != 0`
810+
# We can also easily handle multiple nodes at once, for example to ask if A or B are ancestors of C we can do
811+
# `ancestors_bitset[C] & (nodes_bitflags[A] | nodes_bitflags[B]) != 0`
812+
fg_clients = fgraph.clients
813+
nodes_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())}
814+
# Root variables have `None` as owner, which we can handle with a bitset of 0 for `None`
815+
ancestors_bitset = {None: 0}
816+
for node, node_bitflag in nodes_bitflags.items():
817+
# The bitset of each node is the union of the bitsets of its inputs, plus its own bit
818+
ancestors_bitset[node] = reduce(
819+
or_,
820+
(ancestors_bitset[inp.owner] for inp in node.inputs),
821+
node_bitflag,
802822
)
823+
# handle root and leaf nodes gracefully
824+
# Root variables have `None` as owner, which we can handle with a bitflag of 0 for `None`
825+
nodes_bitflags[None] = 0
826+
# Nothing ever depends on output nodes, so just use a new bit for all
827+
out_bitflag = 1 << len(nodes_bitflags)
828+
for out in fg.outputs:
829+
for client, _ in fg_clients[out]:
830+
if isinstance(client.op, Output):
831+
nodes_bitflags[client] = out_bitflag
803832

804-
fg_clients = fgraph.clients
805833
sorted_subgraphs = SortedSubgraphCollection()
806834

807835
# Start exploring from candidate sink nodes (backwards)
808836
# These are Elemwise nodes with a C-implementation, that are not part of another subgraph
809837
# And have no other fuseable clients (i.e., are sinks)
810-
for starting_node, starting_idx in reversed(toposort_idx.items()):
838+
for starting_node, starting_bitflag in reversed(nodes_bitflags.items()):
811839
if (
812-
starting_node in sorted_subgraphs
840+
starting_bitflag in sorted_subgraphs
813841
or not fuseable_clients.is_sink_node(starting_node)
814842
):
815843
continue
816844

817-
subgraph = ConvexSubgraph(node_ancestors)
845+
subgraph = ConvexSubgraph(nodes_bitflags, ancestors_bitset)
818846

819847
fuseable_nodes_queue = SortedFuseableNodesQueue()
820-
fuseable_nodes_queue.push(starting_node, starting_idx, is_ancestor=True)
848+
fuseable_nodes_queue.push(
849+
starting_node, starting_bitflag, is_ancestor=True
850+
)
821851
while fuseable_nodes_queue:
822852
node, is_ancestor = fuseable_nodes_queue.pop()
823853

@@ -841,7 +871,7 @@ def __iter__(self):
841871
if node in fuseable_clients[ancestor_node]:
842872
fuseable_nodes_queue.push(
843873
ancestor_node,
844-
toposort_idx[ancestor_node],
874+
nodes_bitflags[ancestor_node],
845875
is_ancestor=True,
846876
)
847877
else:
@@ -854,7 +884,7 @@ def __iter__(self):
854884
if client_node in next_fuseable_clients:
855885
fuseable_nodes_queue.push(
856886
client_node,
857-
toposort_idx[client_node],
887+
nodes_bitflags[client_node],
858888
is_ancestor=False,
859889
)
860890
else:

0 commit comments

Comments
 (0)