22import itertools
33import operator
44import sys
5+ import typing
56from collections import defaultdict , deque
67from collections .abc import Generator , Sequence
78from functools import cache , reduce
@@ -522,6 +523,43 @@ def elemwise_max_operands_fct(node) -> int:
522523 return 1024
523524
524525
526+ class CopyOnWriteDictOfSets :
527+ __slots__ = ("d" , "d_copy" )
528+
529+ def __init__ (self , d : dict [typing .Any , set ]):
530+ self .d = d
531+ self .d_copy : dict [typing .Any , set ] = {}
532+
533+ def __getitem__ (self , key ):
534+ try :
535+ return self .d_copy [key ]
536+ except KeyError :
537+ return self .d [key ]
538+
539+ def get (self , key , default = frozenset ()):
540+ try :
541+ return self .d_copy [key ]
542+ except KeyError :
543+ try :
544+ return self .d [key ]
545+ except KeyError :
546+ return default
547+
548+ def remove_from_key (self , key , value ):
549+ try :
550+ self .d_copy [key ].remove (value )
551+ except KeyError :
552+ self .d_copy [key ] = copied_value = self .d [key ].copy ()
553+ copied_value .remove (value )
554+
555+ def add_to_key (self , key , value ):
556+ try :
557+ self .d_copy [key ].add (value )
558+ except KeyError :
559+ self .d_copy [key ] = copied_value = self .d [key ].copy ()
560+ copied_value .add (value )
561+
562+
525563class FusionOptimizer (GraphRewriter ):
526564 """Graph optimizer that fuses consecutive Elemwise operations."""
527565
@@ -648,15 +686,10 @@ def variables_depend_on(
648686 subgraph_outputs : dict [Variable , Literal [None ]] = {} # ordered set
649687 unfuseable_clients_subgraph : set [Variable ] = set ()
650688
651- # Shallow cloning of maps so that they can be manipulated in place
652- fuseable_clients_clone : FUSEABLE_MAPPING = defaultdict (set )
653- fuseable_clients_clone .update (
654- {k : v .copy () for k , v in fuseable_clients .items ()}
655- )
656- unfuseable_clients_clone : UNFUSEABLE_MAPPING = defaultdict (set )
657- unfuseable_clients_clone .update (
658- {k : v .copy () for k , v in unfuseable_clients .items ()}
659- )
689+ # If we need to manipulate the maps in place, we'll do a shallow copy later
690+ # For now we query on the original ones
691+ fuseable_clients_clone = CopyOnWriteDictOfSets (fuseable_clients )
692+ unfuseable_clients_clone = CopyOnWriteDictOfSets (unfuseable_clients )
660693
661694 # We now try to expand as much as possible towards the potentially
662695 # fuseable clients and ancestors to detect the largest possible
@@ -686,7 +719,7 @@ def variables_depend_on(
686719 required_unfuseable_inputs = [
687720 inp
688721 for inp in next_node .inputs
689- if next_node in unfuseable_clients_clone .get (inp , () )
722+ if next_node in unfuseable_clients_clone .get (inp )
690723 ]
691724 new_required_unfuseable_inputs = [
692725 inp
@@ -709,7 +742,7 @@ def variables_depend_on(
709742 if not must_backtrack :
710743 implied_unfuseable_clients = {
711744 c
712- for client in unfuseable_clients_clone .get (next_out , () )
745+ for client in unfuseable_clients_clone .get (next_out )
713746 if not isinstance (client .op , Output )
714747 for c in client .outputs
715748 }
@@ -730,13 +763,15 @@ def variables_depend_on(
730763
731764 if must_backtrack :
732765 for inp in next_node .inputs :
733- if (
734- inp .owner in visited_nodes
735- # next_node could have the same input repeated
736- and next_node in fuseable_clients_clone [inp ]
737- ):
738- fuseable_clients_clone [inp ].remove (next_node )
739- unfuseable_clients_clone [inp ].add (next_node )
766+ if inp .owner in visited_nodes :
767+ if next_node not in fuseable_clients_clone [inp ]:
768+ # This can happen when next node has repeated inputs
769+ continue
770+ fuseable_clients_clone .remove_from_key (
771+ inp , next_node
772+ )
773+ unfuseable_clients_clone .add_to_key (inp , next_node )
774+
740775 # This input must become an output of the subgraph,
741776 # because it can't be merged with next_node.
742777 # We will revisit it to make sure this is safe.
@@ -745,8 +780,13 @@ def variables_depend_on(
745780 # need to convert to tuple not to change set size during iteration
746781 for client in tuple (fuseable_clients_clone [next_out ]):
747782 if client in visited_nodes :
748- fuseable_clients_clone [next_out ].remove (client )
749- unfuseable_clients_clone [next_out ].add (client )
783+ fuseable_clients_clone .remove_from_key (
784+ next_out , client
785+ )
786+ unfuseable_clients_clone .add_to_key (
787+ next_out , client
788+ )
789+
750790 # next_out must become an input of the subgraph.
751791 # We will revisit any of its clients currently
752792 # in the subgraph to make sure this is safe.
@@ -789,7 +829,7 @@ def variables_depend_on(
789829 sorted (
790830 (
791831 node
792- for node in fuseable_clients_clone .get (next_out , () )
832+ for node in fuseable_clients_clone .get (next_out )
793833 if node not in visited_nodes
794834 ),
795835 key = toposort_index .get , # type: ignore[arg-type]
0 commit comments