55from collections import defaultdict , deque
66from collections .abc import Generator , Sequence
77from functools import cache , reduce
8- from typing import TypeVar
8+ from typing import Literal
99from warnings import warn
1010
1111import pytensor .scalar .basic as ps
@@ -568,8 +568,7 @@ def find_next_fuseable_subgraph(
568568 This generator assumes that such a subgraph is replaced by a single
569569 Elemwise Composite before being accessed again in the next iteration.
570570 """
571-
572- FUSEABLE_MAPPING = defaultdict [Variable , list [Apply ]]
571+ FUSEABLE_MAPPING = defaultdict [Variable , set [Apply ]]
573572 UNFUSEABLE_MAPPING = defaultdict [Variable , set [Apply ]]
574573
575574 def initialize_fuseable_mappings (
@@ -591,35 +590,33 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
591590 # to ensure the rewrite remains deterministic.
592591 # This is not a problem for unfuseable ones, as they can never
593592 # become part of the graph.
594- fuseable_clients : FUSEABLE_MAPPING = defaultdict (list )
593+ fuseable_clients : FUSEABLE_MAPPING = defaultdict (set )
595594 unfuseable_clients : UNFUSEABLE_MAPPING = defaultdict (set )
596595 for out , clients in fg .clients .items ():
597- # Old FunctionGraph nodes remain in the clients dictionary
598- # even after they are removed by rewrites
599- if not clients :
600- continue
601-
602596 out_maybe_fuseable = (
603- out .owner
597+ out .owner is not None
604598 and isinstance (out .owner .op , Elemwise )
605599 # and not isinstance(out.owner.op.scalar_op, ps.Composite)
606600 and len (out .owner .outputs ) == 1
607601 and elemwise_scalar_op_has_c_code (out .owner )
608602 )
609- for client , _ in clients :
610- if (
611- out_maybe_fuseable
612- and isinstance (client .op , Elemwise )
613- # and not isinstance(client.op.scalar_op, ps.Composite)
614- and len (client .outputs ) == 1
615- and out .type .broadcastable
616- == client .outputs [0 ].type .broadcastable
617- and elemwise_scalar_op_has_c_code (client )
618- ):
619- if client not in fuseable_clients [out ]:
620- fuseable_clients [out ].append (client )
621- else :
622- unfuseable_clients [out ].add (client )
603+ if out_maybe_fuseable :
604+ out_bcast = (
605+ out .type .broadcastable if out_maybe_fuseable else None
606+ )
607+ for client , _ in clients :
608+ if (
609+ isinstance (client .op , Elemwise )
610+ # and not isinstance(client.op.scalar_op, ps.Composite)
611+ and len (client .outputs ) == 1
612+ and out_bcast == client .outputs [0 ].type .broadcastable
613+ and elemwise_scalar_op_has_c_code (client )
614+ ):
615+ fuseable_clients [out ].add (client )
616+ else :
617+ unfuseable_clients [out ].add (client )
618+ else :
619+ unfuseable_clients [out ] = {client for client , _ in clients }
623620
624621 return fuseable_clients , unfuseable_clients
625622
@@ -630,16 +627,6 @@ def find_fuseable_subgraph(
630627 unfuseable_clients : UNFUSEABLE_MAPPING ,
631628 toposort_index : dict [Apply , int ],
632629 ) -> tuple [list [Variable ], list [Variable ]]:
633- KT = TypeVar ("KT" )
634- VT = TypeVar ("VT" , list , set )
635-
636- def shallow_clone_defaultdict (
637- d : defaultdict [KT , VT ],
638- ) -> defaultdict [KT , VT ]:
639- new_dict : defaultdict [KT , VT ] = defaultdict (d .default_factory )
640- new_dict .update ({k : v .copy () for k , v in d .items ()})
641- return new_dict
642-
643630 def variables_depend_on (
644631 variables , depend_on , stop_search_at = None
645632 ) -> bool :
@@ -657,17 +644,19 @@ def variables_depend_on(
657644 visited_nodes .add (starting_node )
658645 continue
659646
660- subgraph_inputs : list [Variable ] = []
661- subgraph_outputs : list [Variable ] = []
647+ subgraph_inputs : dict [Variable , Literal [ None ]] = {} # ordered set
648+ subgraph_outputs : dict [Variable , Literal [ None ]] = {} # ordered set
662649 unfuseable_clients_subgraph : set [Variable ] = set ()
663650
664651 # Shallow cloning of maps so that they can be manipulated in place
665- fuseable_clients_temp = shallow_clone_defaultdict (fuseable_clients )
666- unfuseable_clients_clone = shallow_clone_defaultdict (
667- unfuseable_clients
652+ fuseable_clients_clone : FUSEABLE_MAPPING = defaultdict (set )
653+ fuseable_clients_clone .update (
654+ {k : v .copy () for k , v in fuseable_clients .items ()}
655+ )
656+ unfuseable_clients_clone : UNFUSEABLE_MAPPING = defaultdict (set )
657+ unfuseable_clients_clone .update (
658+ {k : v .copy () for k , v in unfuseable_clients .items ()}
668659 )
669-
670- fuseable_nodes_to_visit = deque ([starting_node ])
671660
672661 # We now try to expand as much as possible towards the potentially
673662 # fuseable clients and ancestors to detect the largest possible
@@ -676,6 +665,7 @@ def variables_depend_on(
676665 # some inputs or clients may depend on other nodes of the same
677666 # subgraph via a path that cannot be included in the Composite
678667 # (unfuseable)
668+ fuseable_nodes_to_visit = deque ([starting_node ])
679669 while fuseable_nodes_to_visit :
680670 next_node = fuseable_nodes_to_visit .popleft ()
681671 visited_nodes .add (next_node )
@@ -684,15 +674,14 @@ def variables_depend_on(
684674 # If the output variable of next_node has no fuseable clients
685675 # or has unfuseable clients, then next_node must become an output
686676 # if it is to be fused.
687- must_become_output = (
688- next_out not in fuseable_clients_temp
689- or next_out in unfuseable_clients_clone
690- )
677+ must_become_output = not fuseable_clients_clone .get (
678+ next_out
679+ ) or unfuseable_clients_clone .get (next_out )
691680
692681 # We have backtracked to this node, and it may no longer be a viable output,
693682 # so we remove it and check again as if we had never seen this node
694- if must_become_output and next_out in subgraph_outputs :
695- subgraph_outputs .remove (next_out )
683+ if must_become_output :
684+ subgraph_outputs .pop (next_out , None )
696685
697686 required_unfuseable_inputs = [
698687 inp
@@ -744,18 +733,19 @@ def variables_depend_on(
744733 if (
745734 inp .owner in visited_nodes
746735 # next_node could have the same input repeated
747- and next_node in fuseable_clients_temp [inp ]
736+ and next_node in fuseable_clients_clone [inp ]
748737 ):
749- fuseable_clients_temp [inp ].remove (next_node )
738+ fuseable_clients_clone [inp ].remove (next_node )
750739 unfuseable_clients_clone [inp ].add (next_node )
751740 # This input must become an output of the subgraph,
752741 # because it can't be merged with next_node.
753742 # We will revisit it to make sure this is safe.
754743 fuseable_nodes_to_visit .appendleft (inp .owner )
755744
756- for client in fuseable_clients_temp [next_out ]:
745+ # Iterate over a tuple snapshot: the set is mutated inside the loop
746+ for client in tuple (fuseable_clients_clone [next_out ]):
757747 if client in visited_nodes :
758- fuseable_clients_temp [next_out ].remove (client )
748+ fuseable_clients_clone [next_out ].remove (client )
759749 unfuseable_clients_clone [next_out ].add (client )
760750 # next_out must become an input of the subgraph.
761751 # We will revisit any of its clients currently
@@ -771,74 +761,72 @@ def variables_depend_on(
771761 # mappings as if next_node was part of it.
772762 # Useless inputs will be removed by the useless Composite rewrite
773763 for inp in new_required_unfuseable_inputs :
774- if inp not in subgraph_inputs :
775- subgraph_inputs .append (inp )
764+ subgraph_inputs [inp ] = None
776765
777766 if must_become_output :
778- subgraph_outputs . append ( next_out )
767+ subgraph_outputs [ next_out ] = None
779768 unfuseable_clients_subgraph .update (
780769 new_implied_unfuseable_clients
781770 )
782771
783772 # Expand through unvisited fuseable ancestors
784- for inp in sorted (
785- (
786- inp
787- for inp in next_node . inputs
788- if (
789- inp not in required_unfuseable_inputs
790- and inp . owner not in visited_nodes
791- )
792- ),
793- key = lambda inp : toposort_index [ inp . owner ] ,
794- reverse = True ,
795- ):
796- fuseable_nodes_to_visit . appendleft ( inp . owner )
773+ fuseable_nodes_to_visit . extendleft (
774+ sorted (
775+ (
776+ inp . owner
777+ for inp in next_node . inputs
778+ if (
779+ inp not in required_unfuseable_inputs
780+ and inp . owner not in visited_nodes
781+ )
782+ ) ,
783+ key = toposort_index . get , # type: ignore[arg-type]
784+ )
785+ )
797786
798787 # Expand through unvisited fuseable clients
799- for next_node in sorted (
800- (
801- node
802- for node in fuseable_clients_temp .get (next_out , ())
803- if node not in visited_nodes
804- ),
805- key = lambda node : toposort_index [node ],
806- ):
807- fuseable_nodes_to_visit .append (next_node )
788+ fuseable_nodes_to_visit .extend (
789+ sorted (
790+ (
791+ node
792+ for node in fuseable_clients_clone .get (next_out , ())
793+ if node not in visited_nodes
794+ ),
795+ key = toposort_index .get , # type: ignore[arg-type]
796+ )
797+ )
808798
809799 # Don't return if final subgraph is just the original Elemwise
810800 if len (subgraph_outputs ) == 1 and set (
811- subgraph_outputs [ 0 ] .owner .inputs
801+ next ( iter ( subgraph_outputs )) .owner .inputs
812802 ) == set (subgraph_inputs ):
813803 # Update global fuseable mappings
814804 # No input was actually fuseable
815805 for inp in starting_node .inputs :
816- if starting_node in fuseable_clients .get (inp , ()):
817- fuseable_clients [inp ].remove (starting_node )
818- unfuseable_clients [inp ].add (starting_node )
806+ fuseable_clients [inp ].discard (starting_node )
807+ unfuseable_clients [inp ].add (starting_node )
819808 # No client was actually fuseable
820809 unfuseable_clients [starting_out ].update (
821810 fuseable_clients .pop (starting_out , ())
822811 )
823812 continue
824813
825- return subgraph_inputs , subgraph_outputs
814+ return list ( subgraph_inputs ), list ( subgraph_outputs )
826815 raise ValueError
827816
828817 def update_fuseable_mappings_after_fg_replace (
829818 * ,
830- fg : FunctionGraph ,
831819 visited_nodes : set [Apply ],
832820 fuseable_clients : FUSEABLE_MAPPING ,
833821 unfuseable_clients : UNFUSEABLE_MAPPING ,
834822 starting_nodes : set [Apply ],
823+ updated_nodes : set [Apply ],
835824 ) -> None :
836825 # Find new composite node and dropped intermediate nodes
837826 # by comparing `updated_nodes` (the current fg.apply_nodes) with the cached
838827 # original nodes
839- next_nodes = fg .apply_nodes
840- (new_composite_node ,) = next_nodes - starting_nodes
841- dropped_nodes = starting_nodes - next_nodes
828+ (new_composite_node ,) = updated_nodes - starting_nodes
829+ dropped_nodes = starting_nodes - updated_nodes
842830
843831 # Remove intermediate Composite nodes from mappings
844832 for dropped_node in dropped_nodes :
@@ -850,11 +838,11 @@ def update_fuseable_mappings_after_fg_replace(
850838 # Update fuseable information for subgraph inputs
851839 for inp in subgraph_inputs :
852840 if inp in fuseable_clients :
853- new_fuseable_clients = [
841+ new_fuseable_clients = {
854842 client
855843 for client in fuseable_clients [inp ]
856844 if client not in dropped_nodes
857- ]
845+ }
858846 if new_fuseable_clients :
859847 fuseable_clients [inp ] = new_fuseable_clients
860848 else :
@@ -898,11 +886,11 @@ def update_fuseable_mappings_after_fg_replace(
898886 # generator. For large models (as in `TestFusion.test_big_fusion`)
899887 # this can provide huge speedups
900888 update_fuseable_mappings_after_fg_replace (
901- fg = fg ,
902889 visited_nodes = visited_nodes ,
903890 fuseable_clients = fuseable_clients ,
904891 unfuseable_clients = unfuseable_clients ,
905892 starting_nodes = starting_nodes ,
893+ updated_nodes = fg .apply_nodes ,
906894 )
907895
908896 nb_fused = 0
0 commit comments