@@ -611,7 +611,8 @@ def _check_inputs_shape(
611611 elif isinstance (input1 , dict ):
612612 if input1 .keys () != input2 .keys ():
613613 return False
614- for (ka , va ), vb in zip (input1 .items (), input2 .values ()):
614+ for ka , va in input1 .items ():
615+ vb = input2 [ka ]
615616 if type (va ) != type (vb ):
616617 return False
617618 if isinstance (va , bool ) and va != vb :
@@ -638,9 +639,9 @@ def _check_inputs_shape(
638639
639640 @staticmethod
640641 def _check_tensor_shapes_with_dynamic_shapes (
641- t1 : torch .tensor , t2 : torch .tensor , dynamic_shape : dict [int , Any ]
642+ input_1 : torch .tensor , input_2 : torch .tensor , dynamic_shape : dict [int , Any ]
642643 ) -> bool :
643- for (i , axis_0 ), axis_1 in zip (enumerate (t1 .shape ), t2 .shape ):
644+ for (i , axis_0 ), axis_1 in zip (enumerate (input_1 .shape ), input_2 .shape ):
644645 if axis_0 != axis_1 :
645646 if i not in dynamic_shape :
646647 logger .warning (
@@ -650,7 +651,7 @@ def _check_tensor_shapes_with_dynamic_shapes(
650651 dyn = dynamic_shape [i ]
651652 if axis_1 > dyn .max or axis_1 < dyn .min :
652653 raise DynamicShapeOutOfRangeException (
653- f"The input size ( { axis_1 } ) of dimension ( { i } ) is not in dynamic shape range [{ dyn .max } , { dyn .max } ]! "
654+ f"Dimension ( { i } ) of new input tensor is not the range of supported shapes (saw: ( { axis_1 } ), expected: [{ dyn .min } , { dyn .max } ]) "
654655 )
655656
656657 return True
0 commit comments