55from collections .abc import Generator , Sequence
66from functools import cache , reduce
77from heapq import heapify , heappop , heappush
8+ from operator import or_
89from warnings import warn
910
1011import pytensor .scalar .basic as ps
1516from pytensor .graph .basic import Apply , Variable
1617from pytensor .graph .destroyhandler import DestroyHandler , inplace_candidates
1718from pytensor .graph .features import ReplaceValidate
18- from pytensor .graph .fg import FunctionGraph
19+ from pytensor .graph .fg import FunctionGraph , Output
1920from pytensor .graph .op import Op
2021from pytensor .graph .rewriting .basic import (
2122 GraphRewriter ,
@@ -667,14 +668,14 @@ def __init__(self):
667668 self .queue = queue = []
668669 heapify (queue )
669670
670- def push (self , node : Apply , toposort_index : int , is_ancestor : bool ):
671+ def push (self , node : Apply , node_bitflag : int , is_ancestor : bool ):
671672 if is_ancestor :
672- toposort_index = - toposort_index
673- heappush (self .queue , (toposort_index , node ))
673+ node_bitflag = - node_bitflag
674+ heappush (self .queue , (node_bitflag , node ))
674675
675- def pop (self ) -> tuple [Apply , bool ]:
676- toposort_index , node = heappop (self .queue )
677- return node , toposort_index < 0
676+ def pop (self ) -> tuple [Apply , int , bool ]:
677+ node_bitflag , node = heappop (self .queue )
678+ return node , node_bitflag < 0
678679
679680 def __bool__ (self ):
680681 return bool (self .queue )
@@ -684,45 +685,49 @@ class NonConvexError(Exception):
684685
685686 class ConvexSubgraph :
686687 __slots__ = (
687- "node_ancestors" ,
688+ "nodes_bitflags" ,
689+ "ancestors_bitset" ,
688690 "nodes" ,
689- "bitset " ,
690- "unfuseable_ancestors " ,
691- "unfuseable_clients " ,
691+ "nodes_bitset " ,
692+ "unfuseable_ancestors_bitset " ,
693+ "unfuseable_clients_bitset " ,
692694 "inputs_and_outputs" ,
693695 )
694696
695- def __init__ (self , node_ancestors ):
696- self .node_ancestors = node_ancestors
697+ def __init__ (self , nodes_bitflags , ancestors_bitset ):
698+ self .nodes_bitflags = nodes_bitflags
699+ self .ancestors_bitset = ancestors_bitset
697700 self .nodes = {}
698- self .bitset = 0
699- self .unfuseable_ancestors = set ()
700- self .unfuseable_clients = set ()
701+ self .nodes_bitset = 0
702+ self .unfuseable_ancestors_bitset = 0
703+ self .unfuseable_clients_bitset = 0
701704 self .inputs_and_outputs = None
702705
703706 def __len__ (self ):
704707 return len (self .nodes )
705708
706- def __contains__ (self , node : Apply ):
707- return node in self .nodes
709+ def __contains__ (self , node : int ):
710+ return bool ( self . nodes_bitset & self .nodes_bitflags [ node ])
708711
709712 def add_node (self , node : Apply , is_ancestor : bool ):
713+ node_bitflag = self .nodes_bitflags [node ]
710714 if is_ancestor :
711- if node in self .unfuseable_ancestors :
715+ if node_bitflag & self .unfuseable_ancestors_bitset :
712716 raise NonConvexError
713- elif self .node_ancestors [node ] & self .unfuseable_clients :
717+ elif self .ancestors_bitset [node ] & self .unfuseable_clients_bitset :
714718 raise NonConvexError
719+ self .nodes_bitset |= node_bitflag
715720 self .nodes [node ] = None
716721 self .inputs_and_outputs = None # clear cache
717722
718723 def add_unfuseable_ancestor (self , ancestor : Apply ):
719724 # If an ancestor is unfuseable, so are all its ancestors
720- self .unfuseable_ancestors |= self .node_ancestors [ancestor ]
725+ self .unfuseable_ancestors_bitset |= self .ancestors_bitset [ancestor ]
721726
722727 def add_unfuseable_client (self , client : Apply ):
723728 # If a client is unfuseable, so are all its clients, but we don't need to keep track of those
724729 # Any downstream client will also depend on this unfuseable client and will be rejected when visited
725- self .unfuseable_clients . add ( client )
730+ self .unfuseable_clients_bitset |= self . nodes_bitflags [ client ]
726731
727732 def get_inputs_and_outputs (self ):
728733 if self .inputs_and_outputs is not None :
@@ -752,37 +757,40 @@ def get_inputs_and_outputs(self):
752757 return subgraph_inputs , subgraph_outputs
753758
754759 class SortedSubgraphCollection :
755- __slots__ = ("subgraphs" , "nodes " )
760+ __slots__ = ("subgraphs" , "nodes_bitset " )
756761
757762 def __init__ (self ):
758763 self .subgraphs : list [
759764 tuple [int , tuple [tuple [Variable ], tuple [Variable ]]]
760765 ] = []
761- self .nodes = {}
766+ self .nodes_bitset = 0
762767
763- def __contains__ (self , node : Apply ):
764- return node in self .nodes
768+ def __contains__ (self , node_bitflag : int ):
769+ return bool ( node_bitflag & self .nodes_bitset )
765770
766771 def insert_subgraph (self , subgraph : ConvexSubgraph ):
767772 # Usually new subgraphs don't depend on previous subgraphs, so we can just append them at the end
768773 # But in some cases they can, so we need to insert at the right position.
769- subgraph_unfuseable_ancestors = subgraph .unfuseable_ancestors
770- if subgraph_unfuseable_ancestors .isdisjoint (self .nodes ):
774+ subgraph_unfuseable_ancestors_bitset = (
775+ subgraph .unfuseable_ancestors_bitset
776+ )
777+ if not (subgraph_unfuseable_ancestors_bitset & self .nodes_bitset ):
771778 self .subgraphs .append (subgraph )
772779 else :
773780 # Iterate from the end, removing the bitsets of each previous subgraphs until our current subgraph
774781 # no longer depends on what's left. This tells us where to insert the current subgraph.
775- remaining_nodes = set ( self .nodes )
782+ remaining_nodes_bitset = self .nodes_bitset
776783 for index , other_subgraph in enumerate (
777784 reversed (self .subgraphs )
778785 ):
779- remaining_nodes .difference_update (other_subgraph .nodes )
780- if subgraph_unfuseable_ancestors .isdisjoint (
781- remaining_nodes
786+ remaining_nodes_bitset &= ~ other_subgraph .nodes_bitset
787+ if not (
788+ subgraph_unfuseable_ancestors_bitset
789+ & remaining_nodes_bitset
782790 ):
783791 break
784792 self .subgraphs .insert (- (index + 1 ), subgraph )
785- self .nodes |= subgraph .nodes
793+ self .nodes_bitset |= subgraph .nodes_bitset
786794
787795 def __iter__ (self ):
788796 yield from self .subgraphs
@@ -792,32 +800,54 @@ def __iter__(self):
792800 if not fuseable_clients :
793801 return None
794802
795- toposort_idx = {
796- node : idx for idx , node in enumerate (fg .toposort (), start = 1 )
797- }
798- node_ancestors = {None : frozenset ()}
799- for node in toposort_idx :
800- node_ancestors [node ] = frozenset .union (
801- * (node_ancestors [inp .owner ] for inp in node .inputs ), {node }
803+ # Create a bitset of ancestors for each node.
804+ # Each node is represented by a bit flag of it's position in the toposort
805+ # With two variables {a, b, c} owned by nodes {A, B, C}, where a is an input of b, and b an input of c,
806+ # the ancestors bit flags would be {A: 0b001, B: 0b010, C: 0b100}
807+ # and the ancestors bitset would be {A: 0b001, B: 0b011, C: 0b111}
808+ # This allows us to quickly ask if one or more variables are ancestors of a node by a simple bitwise AND
809+ # For example, to ask if B is an ancestor of C we can do `ancestors_bitset[C] & node_bitset[B] != 0`
810+ # We can also easily handle multiple nodes at once, for example to ask if A or B are ancestors of C we can do
811+ # `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0`
812+ fg_clients = fgraph .clients
813+ nodes_bitflags = {node : 1 << i for i , node in enumerate (fgraph .toposort ())}
814+ # Root variables have `None` as owner, which we can handle with a bitset of 0 for `None`
815+ ancestors_bitset = {None : 0 }
816+ for node , node_bitflag in nodes_bitflags .items ():
817+ # The bitset of each node is the union of the bitsets of its inputs, plus its own bit
818+ ancestors_bitset [node ] = reduce (
819+ or_ ,
820+ (ancestors_bitset [inp .owner ] for inp in node .inputs ),
821+ node_bitflag ,
802822 )
823+ # handle root and leaf nodes gracefully
824+ # Root variables have `None` as owner, which we can handle with a bitflag of 0 for `None`
825+ nodes_bitflags [None ] = 0
826+ # Nothing ever depends on output nodes, so just use a new bit for all
827+ out_bitflag = 1 << len (nodes_bitflags )
828+ for out in fg .outputs :
829+ for client , _ in fg_clients [out ]:
830+ if isinstance (client .op , Output ):
831+ nodes_bitflags [client ] = out_bitflag
803832
804- fg_clients = fgraph .clients
805833 sorted_subgraphs = SortedSubgraphCollection ()
806834
807835 # Start exploring from candidate sink nodes (backwards)
808836 # These are Elemwise nodes with a C-implementation, that are not part of another subgraph
809837 # And have no other fuseable clients (i.e., are sinks)
810- for starting_node , starting_idx in reversed (toposort_idx .items ()):
838+ for starting_node , starting_bitflag in reversed (nodes_bitflags .items ()):
811839 if (
812- starting_node in sorted_subgraphs
840+ starting_bitflag in sorted_subgraphs
813841 or not fuseable_clients .is_sink_node (starting_node )
814842 ):
815843 continue
816844
817- subgraph = ConvexSubgraph (node_ancestors )
845+ subgraph = ConvexSubgraph (nodes_bitflags , ancestors_bitset )
818846
819847 fuseable_nodes_queue = SortedFuseableNodesQueue ()
820- fuseable_nodes_queue .push (starting_node , starting_idx , is_ancestor = True )
848+ fuseable_nodes_queue .push (
849+ starting_node , starting_bitflag , is_ancestor = True
850+ )
821851 while fuseable_nodes_queue :
822852 node , is_ancestor = fuseable_nodes_queue .pop ()
823853
@@ -841,7 +871,7 @@ def __iter__(self):
841871 if node in fuseable_clients [ancestor_node ]:
842872 fuseable_nodes_queue .push (
843873 ancestor_node ,
844- toposort_idx [ancestor_node ],
874+ nodes_bitflags [ancestor_node ],
845875 is_ancestor = True ,
846876 )
847877 else :
@@ -854,7 +884,7 @@ def __iter__(self):
854884 if client_node in next_fuseable_clients :
855885 fuseable_nodes_queue .push (
856886 client_node ,
857- toposort_idx [client_node ],
887+ nodes_bitflags [client_node ],
858888 is_ancestor = False ,
859889 )
860890 else :
0 commit comments