@@ -836,30 +836,57 @@ def remove_identity_nodes(self):
836836 """
837837 Removes identity nodes.
838838 """
839- # f<irst pass: detect replacements
839+ # first pass: detect replacements
840840 new_nodes = []
841841 input_names = set (i .name for i in self .inputs )
842842 output_names = set (i .name for i in self .outputs )
843843 replacements = {}
844+ replacements_rev = {}
844845 for node in self .nodes :
845846 if node .op_type != "Identity" :
846847 new_nodes .append (node )
847848 continue
848849
849850 if node .output [0 ] not in output_names :
850851 old_name , new_name = node .output [0 ], node .input [0 ]
851- elif node .input [0 ] not in input_names :
852+ elif (
853+ node .input [0 ] not in input_names
854+ and node .input [0 ] not in output_names
855+ and node .input [0 ] not in replacements
856+ ):
852857 old_name , new_name = node .input [0 ], node .output [0 ]
853858 else :
854859 new_nodes .append (node )
855860 continue
856861
857862 # the new name can be set for replacements as well
858- assert old_name not in replacements
859863 if new_name in replacements :
860864 new_name = replacements [new_name ]
861- assert new_name not in replacements
865+ assert new_name not in replacements , (
866+ f"Name { old_name !r} still in { replacements } , node.op_type={ node .op_type !r} , "
867+ f"node.input={ node .input } , node.output={ node .output } , "
868+ f"input_names={ input_names } , output_names={ output_names } "
869+ )
870+ if old_name in replacements_rev :
871+ old_old_name = replacements_rev [old_name ]
872+ replacements [old_old_name ] = new_name
873+ replacements_rev [new_name ] = old_old_name
874+ if old_name in replacements :
875+ replacements [replacements [old_name ]] = new_name
876+ assert new_name not in replacements , (
877+ f"Name { old_name !r} still in { replacements } , node.op_type={ node .op_type !r} , "
878+ f"node.input={ node .input } , node.output={ node .output } , "
879+ f"input_names={ input_names } , output_names={ output_names } "
880+ )
862881 replacements [old_name ] = new_name
882+ replacements_rev [new_name ] = old_name
883+
884+ # verification
885+ for k , v in replacements .items ():
886+ assert v not in replacements , (
887+ f"replacement { k } ->{ v } is not possible because of "
888+ f"{ v } ->{ replacements [v ]} , old_name={ old_name !r} , new_name={ new_name !r} "
889+ )
863890
864891 # second pass: replacements in initializer
865892 for k , v in replacements .items ():
@@ -876,10 +903,12 @@ def remove_identity_nodes(self):
876903 repo = {o for o in node .output if o in replacements }
877904 repi = {o for o in node .input if o in replacements }
878905 if repi or repo :
906+ new_inputs = [replacements .get (i , i ) for i in node .input ]
907+ new_outputs = [replacements .get (i , i ) for i in node .output ]
879908 new_node = oh .make_node (
880909 node .op_type ,
881- [ replacements . get ( i , i ) for i in node . input ] ,
882- [ replacements . get ( i , i ) for i in node . output ] ,
910+ new_inputs ,
911+ new_outputs ,
883912 domain = node .domain ,
884913 name = node .name ,
885914 )
0 commit comments