22import itertools
33import operator
44import sys
5+ import typing
56from collections import defaultdict , deque
67from collections .abc import Generator , Sequence
78from functools import cache , reduce
@@ -522,6 +523,43 @@ def elemwise_max_operands_fct(node) -> int:
522523 return 1024
523524
524525
class CopyOnWriteDictOfSets:
    """Copy-on-write view over a mapping whose values are sets.

    Reads are served from the wrapped mapping ``d`` until a key is modified
    through :meth:`remove_from_key` or :meth:`add_to_key`. The first
    modification of a key stores a private copy of that key's set in
    ``d_copy``; all later reads and writes for the key use the private copy.
    Neither mutating method modifies the sets stored in ``d``.
    """

    __slots__ = ("d", "d_copy")

    def __init__(self, d: dict[typing.Any, set]):
        # Wrapped mapping; treated as read-only by the mutating methods.
        self.d = d
        # Private copies of the sets that were modified through this view.
        self.d_copy: dict[typing.Any, set] = {}

    def __getitem__(self, key):
        """Return the current set for ``key``, preferring the private copy.

        Falls back to ``self.d[key]``, so a missing key behaves exactly as it
        does on the wrapped mapping.
        """
        try:
            return self.d_copy[key]
        except KeyError:
            return self.d[key]

    def get(self, key, default=frozenset()):
        """Return the current set for ``key``, or ``default`` if it is unknown.

        The default is an immutable empty ``frozenset`` so that membership
        tests and iteration work without a per-call allocation.
        """
        try:
            return self.d_copy[key]
        except KeyError:
            try:
                return self.d[key]
            except KeyError:
                return default

    def remove_from_key(self, key, value):
        """Remove ``value`` from the set of ``key`` without touching ``self.d``.

        Raises
        ------
        KeyError
            If ``key`` is unknown or ``value`` is not in its current set.
        """
        # Only the lookup is guarded: a KeyError raised by ``set.remove`` must
        # not be mistaken for "key not copied yet", otherwise the private copy
        # would be rebuilt from the original and earlier edits would be lost.
        try:
            target = self.d_copy[key]
        except KeyError:
            target = self.d_copy[key] = self.d[key].copy()
        target.remove(value)

    def add_to_key(self, key, value):
        """Add ``value`` to the set of ``key`` without touching ``self.d``.

        A key that is absent from the wrapped mapping starts from an empty
        set. ``dict.get`` is used so that a ``defaultdict`` original is not
        populated as a side effect of the lookup.
        """
        try:
            target = self.d_copy[key]
        except KeyError:
            target = self.d_copy[key] = set(self.d.get(key, ()))
        target.add(value)
562+
525563class FusionOptimizer (GraphRewriter ):
526564 """Graph optimizer that fuses consecutive Elemwise operations."""
527565
@@ -644,15 +682,10 @@ def variables_depend_on(
644682 subgraph_outputs : dict [Variable , Literal [None ]] = {} # ordered set
645683 unfuseable_clients_subgraph : set [Variable ] = set ()
646684
647- # Shallow cloning of maps so that they can be manipulated in place
648- fuseable_clients_clone : FUSEABLE_MAPPING = defaultdict (set )
649- fuseable_clients_clone .update (
650- {k : v .copy () for k , v in fuseable_clients .items ()}
651- )
652- unfuseable_clients_clone : UNFUSEABLE_MAPPING = defaultdict (set )
653- unfuseable_clients_clone .update (
654- {k : v .copy () for k , v in unfuseable_clients .items ()}
655- )
685+ # If we need to manipulate the maps in place, we'll do a shallow copy later
686+ # For now we query on the original ones
687+ fuseable_clients_clone = CopyOnWriteDictOfSets (fuseable_clients )
688+ unfuseable_clients_clone = CopyOnWriteDictOfSets (unfuseable_clients )
656689
657690 # We now try to expand as much as possible towards the potentially
658691 # fuseable clients and ancestors to detect the largest possible
@@ -682,7 +715,7 @@ def variables_depend_on(
682715 required_unfuseable_inputs = [
683716 inp
684717 for inp in next_node .inputs
685- if next_node in unfuseable_clients_clone .get (inp , () )
718+ if next_node in unfuseable_clients_clone .get (inp )
686719 ]
687720 new_required_unfuseable_inputs = [
688721 inp
@@ -705,7 +738,7 @@ def variables_depend_on(
705738 if not must_backtrack :
706739 implied_unfuseable_clients = {
707740 c
708- for client in unfuseable_clients_clone .get (next_out , () )
741+ for client in unfuseable_clients_clone .get (next_out )
709742 if not isinstance (client .op , Output )
710743 for c in client .outputs
711744 }
@@ -726,13 +759,15 @@ def variables_depend_on(
726759
727760 if must_backtrack :
728761 for inp in next_node .inputs :
729- if (
730- inp .owner in visited_nodes
731- # next_node could have the same input repeated
732- and next_node in fuseable_clients_clone [inp ]
733- ):
734- fuseable_clients_clone [inp ].remove (next_node )
735- unfuseable_clients_clone [inp ].add (next_node )
762+ if inp .owner in visited_nodes :
763+ if next_node not in fuseable_clients_clone [inp ]:
764+ # This can happen when next node has repeated inputs
765+ continue
766+ fuseable_clients_clone .remove_from_key (
767+ inp , next_node
768+ )
769+ unfuseable_clients_clone .add_to_key (inp , next_node )
770+
736771 # This input must become an output of the subgraph,
737772 # because it can't be merged with next_node.
738773 # We will revisit it to make sure this is safe.
@@ -741,8 +776,13 @@ def variables_depend_on(
741776 # need to convert to tuple not to change set size during iteration
742777 for client in tuple (fuseable_clients_clone [next_out ]):
743778 if client in visited_nodes :
744- fuseable_clients_clone [next_out ].remove (client )
745- unfuseable_clients_clone [next_out ].add (client )
779+ fuseable_clients_clone .remove_from_key (
780+ next_out , client
781+ )
782+ unfuseable_clients_clone .add_to_key (
783+ next_out , client
784+ )
785+
746786 # next_out must become an input of the subgraph.
747787 # We will revisit any of its clients currently
748788 # in the subgraph to make sure this is safe.
@@ -785,7 +825,7 @@ def variables_depend_on(
785825 sorted (
786826 (
787827 node
788- for node in fuseable_clients_clone .get (next_out , () )
828+ for node in fuseable_clients_clone .get (next_out )
789829 if node not in visited_nodes
790830 ),
791831 key = toposort_index .get , # type: ignore[arg-type]
0 commit comments