55from collections .abc import Generator , Sequence
66from functools import cache , reduce
77from heapq import heapify , heappop , heappush
8- from operator import or_
98from warnings import warn
109
1110import pytensor .scalar .basic as ps
1615from pytensor .graph .basic import Apply , Variable
1716from pytensor .graph .destroyhandler import DestroyHandler , inplace_candidates
1817from pytensor .graph .features import ReplaceValidate
19- from pytensor .graph .fg import FunctionGraph , Output
18+ from pytensor .graph .fg import FunctionGraph
2019from pytensor .graph .op import Op
2120from pytensor .graph .rewriting .basic import (
2221 GraphRewriter ,
@@ -621,48 +620,28 @@ def elemwise_scalar_op_has_c_code(
621620 if not fuseable_clients :
622621 return None
623622
624- # Create a bitset of ancestors for each node.
625- # Each node is represented by a bit flag of it's position in the toposort
626- # With two variables {a, b, c} owned by nodes {A, B, C}, where a is an input of b, and b an input of c,
627- # the ancestors bit flags would be {A: 0b001, B: 0b010, C: 0b100}
628- # and the ancestors bitset would be {A: 0b001, B: 0b011, C: 0b111}
629- # This allows us to quickly ask if one or more variables are ancestors of a node by a simple bitwise AND
630- # For example, to ask if B is an ancestor of C we can do `ancestors_bitset[C] & node_bitset[B] != 0`
631- # We can also easily handle multiple nodes at once, for example to ask if A or B are ancestors of C we can do
632- # `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0`
633- nodes_bitflags = {node : 1 << i for i , node in enumerate (fgraph .toposort ())}
634- ancestors_bitset = {
635- None : 0
636- } # Root variables have `None` as owner, which we can handle with a bitset of 0 for `None`
637- for node , node_bitflag in nodes_bitflags .items ():
638- # The bitset of each node is the union of the bitsets of its inputs, plus its own bit
639- ancestors_bitset [node ] = reduce (
640- or_ ,
641- (ancestors_bitset [inp .owner ] for inp in node .inputs ),
642- node_bitflag ,
623+ toposort_idx = {
624+ node : idx for idx , node in enumerate (fg .toposort (), start = 1 )
625+ }
626+ node_ancestors = {None : frozenset ()}
627+ for node in toposort_idx :
628+ node_ancestors [node ] = frozenset .union (
629+ * (node_ancestors [inp .owner ] for inp in node .inputs ), {node }
643630 )
644- # handle root and leaf nodes gracefully
645- nodes_bitflags [None ] = (
646- 0 # Root variables have `None` as owner, which we can handle with a bitflag of 0 for `None`
647- )
648- out_bitflag = 1 << len (
649- nodes_bitflags
650- ) # Nothing ever depends on output nodes, so just use a new bit for all
651- for out in fg .outputs :
652- for client , _ in fg_clients [out ]:
653- if isinstance (client .op , Output ):
654- nodes_bitflags [client ] = out_bitflag
655631
656632 sorted_subgraphs : list [
657633 tuple [int , tuple [tuple [Variable ], tuple [Variable ]]]
658634 ] = []
659- all_subgraphs_bitset = 0
635+ subgraph_set = set ()
636+ unfuseable_ancestors_set = set ()
637+ unfuseable_clients_set = set ()
638+ all_subgraphs_set = set ()
660639 # Start exploring from candidate sink nodes (backwards)
661640 # These are Elemwise nodes with a C-implementation, that are not part of another subgraph
662641 # And have no other fuseable clients (i.e., are sinks)
663- for starting_node , starting_bitflag in reversed (nodes_bitflags .items ()):
642+ for starting_node , starting_index in reversed (toposort_idx .items ()):
664643 if (
665- starting_bitflag & all_subgraphs_bitset
644+ starting_node in all_subgraphs_set
666645 or starting_node not in candidate_nodes
667646 ):
668647 continue
@@ -676,7 +655,7 @@ def elemwise_scalar_op_has_c_code(
676655 # For ancestors, we want to visit the later nodes first (those that have more dependencies)
677656 # whereas for clients we want to visit earlier nodes first (those that have fewer dependencies)
678657 # We negate the bitflag for ancestors to achieve this ordering
679- fuseables_nodes_queue = [(- starting_bitflag , starting_node )]
658+ fuseables_nodes_queue = [(- starting_index , starting_node )]
680659 heapify (fuseables_nodes_queue )
681660
682661 # We keep 3 bitsets during the exploration:
@@ -687,37 +666,35 @@ def elemwise_scalar_op_has_c_code(
687666 # in which case we can't fuse it. If we can fuse it, we then add its unfuseable ancestors/clients to the respective bitsets
688667 # and add its fuseable ancestors/clients to the queue to explore later. This approach requires a visit in the order described above.
689668 # Otherwise, we need to recompute target bitsets in every iteration and/or backtrack.
690- subgraph_nodes = []
691- subgraph_bitset = 0
692- unfuseable_ancestors_bitset = 0
693- unfuseable_clients_bitset = 0
669+ subgraph_set .clear ()
670+ unfuseable_ancestors_set .clear ()
671+ unfuseable_clients_set .clear ()
694672
695673 # print(f"\nStarting new subgraph exploration from {starting_node}")
696674 while fuseables_nodes_queue :
697- node_bitflag , node = heappop (fuseables_nodes_queue )
698- is_ancestor = node_bitflag < 0
675+ node_idx , node = heappop (fuseables_nodes_queue )
676+ is_ancestor = node_idx < 0
699677 if is_ancestor :
700- node_bitflag = - node_bitflag
678+ node_idx = - node_idx
701679 # print(f"\t > Visiting {'ancestor' if is_ancestor else 'client'} {next_node}")
702680
703- if node_bitflag & subgraph_bitset :
681+ if node in subgraph_set :
704682 # Already part of the subgraph
705683 # print("\t - already in subgraph")
706684 continue
707685
708686 if is_ancestor :
709- if node_bitflag & unfuseable_ancestors_bitset :
687+ if node in unfuseable_ancestors_set :
710688 # An unfuseable ancestor depends on this node, can't fuse
711689 # print("\t failed - unfuseable ancestor depends on it")
712690 continue
713- elif ancestors_bitset [node ] & unfuseable_clients_bitset :
691+ elif not node_ancestors [node ]. isdisjoint ( unfuseable_clients_set ) :
714692 # This node depends on an unfuseable client, can't fuse
715693 # print("\t failed - depends on unfuseable client")
716694 continue
717695
718696 # print("\t succeeded - adding to subgraph")
719- subgraph_nodes .append (node )
720- subgraph_bitset |= node_bitflag
697+ subgraph_set .add (node )
721698
722699 # Expand through ancestors and client nodes
723700 # A node can either be:
@@ -726,57 +703,56 @@ def elemwise_scalar_op_has_c_code(
726703 # - unfuseable (add to respective unfuseable bitset)
727704 for ancestor in node .inputs :
728705 ancestor_node = ancestor .owner
729- ancestor_bitflag = nodes_bitflags [ancestor_node ]
730- if ancestor_bitflag & subgraph_bitset :
706+ if ancestor_node in subgraph_set :
731707 continue
732708 if node in fuseable_clients .get (ancestor_node , ()):
733709 heappush (
734710 fuseables_nodes_queue ,
735- (- ancestor_bitflag , ancestor_node ),
711+ (- toposort_idx [ ancestor_node ] , ancestor_node ),
736712 )
737713 else :
738714 # If an ancestor is unfuseable, so are all its ancestors
739- unfuseable_ancestors_bitset |= ancestors_bitset [
740- ancestor_node
741- ]
715+ unfuseable_ancestors_set |= node_ancestors [ancestor_node ]
742716
743717 next_fuseable_clients = fuseable_clients .get (node , ())
744718 for client , _ in fg_clients [node .outputs [0 ]]:
745- client_bitflag = nodes_bitflags [client ]
746- if client_bitflag & subgraph_bitset :
719+ if client in subgraph_set :
747720 continue
748721 if client in next_fuseable_clients :
749- heappush (fuseables_nodes_queue , (client_bitflag , client ))
722+ heappush (
723+ fuseables_nodes_queue , (toposort_idx [client ], client )
724+ )
750725 else :
751726 # If a client is unfuseable, so are all its clients, but we don't need to keep track of those
752727 # Any downstream client will also depend on this unfuseable client and will be rejected when visited
753- unfuseable_clients_bitset |= client_bitflag
728+ unfuseable_clients_set . add ( client )
754729
755730 # Finished exploring this subgraph
756- all_subgraphs_bitset |= subgraph_bitset
731+ all_subgraphs_set |= subgraph_set
757732
758- if subgraph_bitset == starting_bitflag :
733+ if len ( subgraph_set ) == 1 :
759734 # No fusion possible, single node subgraph
760735 continue
761736
762737 # Find out inputs/outputs of subgraph_nodes
763- not_subgraph_bitset = ~ subgraph_bitset
738+ # not_subgraph_bitset = ~subgraph_set
764739 # Use a dict to deduplicate while preserving order
740+ subgraph_nodes = sorted (subgraph_set , key = toposort_idx .get )
765741 subgraph_inputs = tuple (
766742 dict .fromkeys (
767743 inp
768744 for node in subgraph_nodes
769745 for inp in node .inputs
770746 if (ancestor_node := inp .owner ) is None
771- or nodes_bitflags [ ancestor_node ] & not_subgraph_bitset
747+ or ancestor_node not in subgraph_set
772748 )
773749 )
774750
775751 subgraph_outputs = tuple (
776752 node .outputs [0 ]
777753 for node in subgraph_nodes
778754 if any (
779- nodes_bitflags [ client ] & not_subgraph_bitset
755+ client not in subgraph_set
780756 for client , _ in fg_clients [node .outputs [0 ]]
781757 )
782758 )
@@ -786,26 +762,24 @@ def elemwise_scalar_op_has_c_code(
786762
787763 # Usually new subgraphs don't depend on previous subgraphs, so we can just append them at the end
788764 # But in some cases they can, so we need to insert at the right position.
789- if not (unfuseable_ancestors_bitset & all_subgraphs_bitset ):
765+ if not (unfuseable_ancestors_set & all_subgraphs_set ):
790766 sorted_subgraphs .append (
791- (subgraph_bitset , (subgraph_inputs , subgraph_outputs ))
767+ (subgraph_set , (subgraph_inputs , subgraph_outputs ))
792768 )
793769 else :
794770 # Iterate from the end, removing the bitsets of each previous subgraphs until our current subgraph
795771 # no longer depends on what's left. This tells us where to insert the current subgraph.
796- remaining_subgraphs_bitset = all_subgraphs_bitset
797- for index , (other_subgraph_bitset , _ ) in enumerate (
772+ remaining_subgraphs_bitset = all_subgraphs_set . copy ()
773+ for index , (other_subgraph_set , _ ) in enumerate (
798774 reversed (sorted_subgraphs )
799775 ):
800- remaining_subgraphs_bitset &= ~ other_subgraph_bitset
801- if not (
802- unfuseable_ancestors_bitset & remaining_subgraphs_bitset
803- ):
776+ remaining_subgraphs_bitset .difference_update (other_subgraph_set )
777+ if not (unfuseable_ancestors_set & remaining_subgraphs_bitset ):
804778 break
805779
806780 sorted_subgraphs .insert (
807781 - (index + 1 ),
808- (subgraph_bitset , (subgraph_inputs , subgraph_outputs )),
782+ (subgraph_set , (subgraph_inputs , subgraph_outputs )),
809783 )
810784
811785 # Update fuseable clients, inputs can no longer be fused with graph variables
0 commit comments