55from collections import defaultdict , deque
66from collections .abc import Generator , Sequence
77from functools import cache , reduce
8- from typing import TypeVar
8+ from typing import Literal
99from warnings import warn
1010
1111import pytensor .scalar .basic as ps
@@ -555,8 +555,6 @@ def apply(self, fgraph):
555555 callbacks_before = fgraph .execute_callbacks_times .copy ()
556556 callback_before = fgraph .execute_callbacks_time
557557
558- max_operands = elemwise_max_operands_fct (None )
559-
560558 def find_next_fuseable_subgraph (
561559 fg : FunctionGraph ,
562560 ) -> Generator [tuple [list [Variable ], list [Variable ]], None , None ]:
@@ -568,8 +566,7 @@ def find_next_fuseable_subgraph(
568566 This generator assumes that such subgraph is replaced by a single
569567 Elemwise Composite before being accessed again in the next iteration.
570568 """
571-
572- FUSEABLE_MAPPING = defaultdict [Variable , list [Apply ]]
569+ FUSEABLE_MAPPING = defaultdict [Variable , set [Apply ]]
573570 UNFUSEABLE_MAPPING = defaultdict [Variable , set [Apply ]]
574571
575572 def initialize_fuseable_mappings (
@@ -591,35 +588,31 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
591588 # to ensure the rewrite remains deterministic.
592589 # This is not a problem from unfuseable ones, as they can never
593590 # become part of the graph.
594- fuseable_clients : FUSEABLE_MAPPING = defaultdict (list )
591+ fuseable_clients : FUSEABLE_MAPPING = defaultdict (set )
595592 unfuseable_clients : UNFUSEABLE_MAPPING = defaultdict (set )
596593 for out , clients in fg .clients .items ():
597- # Old FunctionGraph nodes remain in the clients dictionary
598- # even after they are removed by rewrites
599- if not clients :
600- continue
601-
602594 out_maybe_fuseable = (
603- out .owner
595+ out .owner is not None
604596 and isinstance (out .owner .op , Elemwise )
605597 # and not isinstance(out.owner.op.scalar_op, ps.Composite)
606598 and len (out .owner .outputs ) == 1
607599 and elemwise_scalar_op_has_c_code (out .owner )
608600 )
609- for client , _ in clients :
610- if (
611- out_maybe_fuseable
612- and isinstance (client .op , Elemwise )
613- # and not isinstance(client.op.scalar_op, ps.Composite)
614- and len (client .outputs ) == 1
615- and out .type .broadcastable
616- == client .outputs [0 ].type .broadcastable
617- and elemwise_scalar_op_has_c_code (client )
618- ):
619- if client not in fuseable_clients [out ]:
620- fuseable_clients [out ].append (client )
621- else :
622- unfuseable_clients [out ].add (client )
601+ if out_maybe_fuseable :
602+ out_bcast = out .type .broadcastable
603+ for client , _ in clients :
604+ if (
605+ isinstance (client .op , Elemwise )
606+ # and not isinstance(client.op.scalar_op, ps.Composite)
607+ and len (client .outputs ) == 1
608+ and out_bcast == client .outputs [0 ].type .broadcastable
609+ and elemwise_scalar_op_has_c_code (client )
610+ ):
611+ fuseable_clients [out ].add (client )
612+ else :
613+ unfuseable_clients [out ].add (client )
614+ else :
615+ unfuseable_clients [out ] = {client for client , _ in clients }
623616
624617 return fuseable_clients , unfuseable_clients
625618
@@ -630,16 +623,6 @@ def find_fuseable_subgraph(
630623 unfuseable_clients : UNFUSEABLE_MAPPING ,
631624 toposort_index : dict [Apply , int ],
632625 ) -> tuple [list [Variable ], list [Variable ]]:
633- KT = TypeVar ("KT" )
634- VT = TypeVar ("VT" , list , set )
635-
636- def shallow_clone_defaultdict (
637- d : defaultdict [KT , VT ],
638- ) -> defaultdict [KT , VT ]:
639- new_dict : defaultdict [KT , VT ] = defaultdict (d .default_factory )
640- new_dict .update ({k : v .copy () for k , v in d .items ()})
641- return new_dict
642-
643626 def variables_depend_on (
644627 variables , depend_on , stop_search_at = None
645628 ) -> bool :
@@ -657,17 +640,19 @@ def variables_depend_on(
657640 visited_nodes .add (starting_node )
658641 continue
659642
660- subgraph_inputs : list [Variable ] = []
661- subgraph_outputs : list [Variable ] = []
643+ subgraph_inputs : dict [Variable , Literal [ None ]] = {} # ordered set
644+ subgraph_outputs : dict [Variable , Literal [ None ]] = {} # ordered set
662645 unfuseable_clients_subgraph : set [Variable ] = set ()
663646
664647 # Shallow cloning of maps so that they can be manipulated in place
665- fuseable_clients_temp = shallow_clone_defaultdict (fuseable_clients )
666- unfuseable_clients_clone = shallow_clone_defaultdict (
667- unfuseable_clients
648+ fuseable_clients_clone : FUSEABLE_MAPPING = defaultdict (set )
649+ fuseable_clients_clone .update (
650+ {k : v .copy () for k , v in fuseable_clients .items ()}
651+ )
652+ unfuseable_clients_clone : UNFUSEABLE_MAPPING = defaultdict (set )
653+ unfuseable_clients_clone .update (
654+ {k : v .copy () for k , v in unfuseable_clients .items ()}
668655 )
669-
670- fuseable_nodes_to_visit = deque ([starting_node ])
671656
672657 # We now try to expand as much as possible towards the potentially
673658 # fuseable clients and ancestors to detect the largest possible
@@ -676,6 +661,7 @@ def variables_depend_on(
676661 # some inputs or clients may depend on other nodes of the same
677662 # subgraph via a path that cannot be included in the Composite
678663 # (unfuseable)
664+ fuseable_nodes_to_visit = deque ([starting_node ])
679665 while fuseable_nodes_to_visit :
680666 next_node = fuseable_nodes_to_visit .popleft ()
681667 visited_nodes .add (next_node )
@@ -684,15 +670,14 @@ def variables_depend_on(
684670 # If the output variable of next_node has no fuseable clients
685671 # or has unfuseable clients, then next_node must become an output
686672 # if it is to be fused.
687- must_become_output = (
688- next_out not in fuseable_clients_temp
689- or next_out in unfuseable_clients_clone
690- )
673+ must_become_output = not fuseable_clients_clone .get (
674+ next_out
675+ ) or unfuseable_clients_clone .get (next_out )
691676
692677 # We have backtracked to this node, and it may no longer be a viable output,
693678 # so we remove it and check again as if we had never seen this node
694- if must_become_output and next_out in subgraph_outputs :
695- subgraph_outputs .remove (next_out )
679+ if must_become_output :
680+ subgraph_outputs .pop (next_out , None )
696681
697682 required_unfuseable_inputs = [
698683 inp
@@ -744,18 +729,19 @@ def variables_depend_on(
744729 if (
745730 inp .owner in visited_nodes
746731 # next_node could have the same input repeated
747- and next_node in fuseable_clients_temp [inp ]
732+ and next_node in fuseable_clients_clone [inp ]
748733 ):
749- fuseable_clients_temp [inp ].remove (next_node )
734+ fuseable_clients_clone [inp ].remove (next_node )
750735 unfuseable_clients_clone [inp ].add (next_node )
751736 # This input must become an output of the subgraph,
752737 # because it can't be merged with next_node.
753738 # We will revisit it to make sure this is safe.
754739 fuseable_nodes_to_visit .appendleft (inp .owner )
755740
756- for client in fuseable_clients_temp [next_out ]:
741+ # need to convert to tuple not to change set size during iteration
742+ for client in tuple (fuseable_clients_clone [next_out ]):
757743 if client in visited_nodes :
758- fuseable_clients_temp [next_out ].remove (client )
744+ fuseable_clients_clone [next_out ].remove (client )
759745 unfuseable_clients_clone [next_out ].add (client )
760746 # next_out must become an input of the subgraph.
761747 # We will revisit any of its clients currently
@@ -771,74 +757,72 @@ def variables_depend_on(
771757 # mappings as if it next_node was part of it.
772758 # Useless inputs will be removed by the useless Composite rewrite
773759 for inp in new_required_unfuseable_inputs :
774- if inp not in subgraph_inputs :
775- subgraph_inputs .append (inp )
760+ subgraph_inputs [inp ] = None
776761
777762 if must_become_output :
778- subgraph_outputs . append ( next_out )
763+ subgraph_outputs [ next_out ] = None
779764 unfuseable_clients_subgraph .update (
780765 new_implied_unfuseable_clients
781766 )
782767
783768 # Expand through unvisited fuseable ancestors
784- for inp in sorted (
785- (
786- inp
787- for inp in next_node . inputs
788- if (
789- inp not in required_unfuseable_inputs
790- and inp . owner not in visited_nodes
791- )
792- ),
793- key = lambda inp : toposort_index [ inp . owner ] ,
794- reverse = True ,
795- ):
796- fuseable_nodes_to_visit . appendleft ( inp . owner )
769+ fuseable_nodes_to_visit . extendleft (
770+ sorted (
771+ (
772+ inp . owner
773+ for inp in next_node . inputs
774+ if (
775+ inp not in required_unfuseable_inputs
776+ and inp . owner not in visited_nodes
777+ )
778+ ) ,
779+ key = toposort_index . get , # type: ignore[arg-type]
780+ )
781+ )
797782
798783 # Expand through unvisited fuseable clients
799- for next_node in sorted (
800- (
801- node
802- for node in fuseable_clients_temp .get (next_out , ())
803- if node not in visited_nodes
804- ),
805- key = lambda node : toposort_index [node ],
806- ):
807- fuseable_nodes_to_visit .append (next_node )
784+ fuseable_nodes_to_visit .extend (
785+ sorted (
786+ (
787+ node
788+ for node in fuseable_clients_clone .get (next_out , ())
789+ if node not in visited_nodes
790+ ),
791+ key = toposort_index .get , # type: ignore[arg-type]
792+ )
793+ )
808794
809795 # Don't return if final subgraph is just the original Elemwise
810796 if len (subgraph_outputs ) == 1 and set (
811- subgraph_outputs [ 0 ] .owner .inputs
797+ next ( iter ( subgraph_outputs )) .owner .inputs
812798 ) == set (subgraph_inputs ):
813799 # Update global fuseable mappings
814800 # No input was actually fuseable
815801 for inp in starting_node .inputs :
816- if starting_node in fuseable_clients .get (inp , ()):
817- fuseable_clients [inp ].remove (starting_node )
818- unfuseable_clients [inp ].add (starting_node )
802+ fuseable_clients [inp ].discard (starting_node )
803+ unfuseable_clients [inp ].add (starting_node )
819804 # No client was actually fuseable
820805 unfuseable_clients [starting_out ].update (
821806 fuseable_clients .pop (starting_out , ())
822807 )
823808 continue
824809
825- return subgraph_inputs , subgraph_outputs
810+ return list ( subgraph_inputs ), list ( subgraph_outputs )
826811 raise ValueError
827812
828813 def update_fuseable_mappings_after_fg_replace (
829814 * ,
830- fg : FunctionGraph ,
831815 visited_nodes : set [Apply ],
832816 fuseable_clients : FUSEABLE_MAPPING ,
833817 unfuseable_clients : UNFUSEABLE_MAPPING ,
834818 starting_nodes : set [Apply ],
819+ updated_nodes : set [Apply ],
835820 ) -> None :
836821 # Find new composite node and dropped intermediate nodes
837822 # by comparing the current fg.apply nodes with the cached
838823 # original nodes
839- next_nodes = fg .apply_nodes
840- (new_composite_node ,) = next_nodes - starting_nodes
841- dropped_nodes = starting_nodes - next_nodes
824+ (new_composite_node ,) = updated_nodes - starting_nodes
825+ dropped_nodes = starting_nodes - updated_nodes
842826
843827 # Remove intermediate Composite nodes from mappings
844828 for dropped_node in dropped_nodes :
@@ -850,11 +834,11 @@ def update_fuseable_mappings_after_fg_replace(
850834 # Update fuseable information for subgraph inputs
851835 for inp in subgraph_inputs :
852836 if inp in fuseable_clients :
853- new_fuseable_clients = [
837+ new_fuseable_clients = {
854838 client
855839 for client in fuseable_clients [inp ]
856840 if client not in dropped_nodes
857- ]
841+ }
858842 if new_fuseable_clients :
859843 fuseable_clients [inp ] = new_fuseable_clients
860844 else :
@@ -898,13 +882,15 @@ def update_fuseable_mappings_after_fg_replace(
898882 # generator. For large models (as in `TestFusion.test_big_fusion`)
899883 # this can provide huge speedups
900884 update_fuseable_mappings_after_fg_replace (
901- fg = fg ,
902885 visited_nodes = visited_nodes ,
903886 fuseable_clients = fuseable_clients ,
904887 unfuseable_clients = unfuseable_clients ,
905888 starting_nodes = starting_nodes ,
889+ updated_nodes = fg .apply_nodes ,
906890 )
907891
892+ max_operands = elemwise_max_operands_fct (None )
893+ reason = self .__class__ .__name__
908894 nb_fused = 0
909895 nb_replacement = 0
910896 for inputs , outputs in find_next_fuseable_subgraph (fgraph ):
@@ -923,13 +909,12 @@ def update_fuseable_mappings_after_fg_replace(
923909 assert len (outputs ) == len (composite_outputs )
924910 for old_out , composite_out in zip (outputs , composite_outputs ):
925911 # Preserve any names on the original outputs
926- if old_out .name :
927- composite_out .name = old_out . name
912+ if old_name := old_out .name :
913+ composite_out .name = old_name
928914
929915 starting_nodes = len (fgraph .apply_nodes )
930916 fgraph .replace_all_validate (
931- list (zip (outputs , composite_outputs , strict = True )),
932- reason = self .__class__ .__name__ ,
917+ tuple (zip (outputs , composite_outputs )), reason = reason
933918 )
934919 nb_fused += 1
935920 nb_replacement += (starting_nodes - len (fgraph .apply_nodes )) + 1
0 commit comments