diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 2cbecdcfc..32dc4404f 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -18,16 +18,21 @@ import torch
 from onnxscript import (
+    BFLOAT16,
     BOOL,
     COMPLEX64,
     COMPLEX128,
     DOUBLE,
     FLOAT,
+    FLOAT16,
     INT8,
     INT16,
     INT32,
     INT64,
     UINT8,
+    UINT16,
+    UINT32,
+    UINT64,
     graph,
     ir,
 )
@@ -54,6 +59,7 @@
 _INT64_MAX = 9223372036854775807
 _INT64_MIN = -9223372036854775808
 _MATH_PI = math.pi
+Rank = common_ops.Rank
 
 
 @torch_op("aten::_local_scalar_dense", trace_only=True)
@@ -71,11 +77,13 @@ def aten__local_scalar_dense(self: TensorType) -> TensorType:
 
 
 @torch_op("aten::_log_softmax", trace_only=True)
-def aten__log_softmax(self: TFloat, dim: int, half_to_float: bool) -> TFloatHighPrecision:
+def aten__log_softmax_half(
+    self: Union[FLOAT16, BFLOAT16], dim: int, half_to_float: bool
+) -> FLOAT:
     """_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor"""
 
     self_is_scalar = len(self.shape) == 0
-    if half_to_float and self.dtype in {ir.DataType.FLOAT16, ir.DataType.BFLOAT16}:
+    if half_to_float:
         self = op.Cast(self, to=FLOAT.dtype)
     if self_is_scalar:
         self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
@@ -85,25 +93,46 @@
     return result
 
 
-@torch_op("aten::_softmax", trace_only=True)
-def aten__softmax(self: TFloat, dim: int, half_to_float: bool) -> TFloatHighPrecision:
-    """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor"""
+@torch_op("aten::_log_softmax", trace_only=True)
+def aten__log_softmax(
+    self: TFloatHighPrecision,
+    dim: int,
+    half_to_float: bool,
+) -> TFloatHighPrecision:
+    """_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor"""
 
     self_is_scalar = len(self.shape) == 0
-
-    if half_to_float and self.dtype in {ir.DataType.FLOAT16, ir.DataType.BFLOAT16}:
-        self = op.Cast(self, to=FLOAT.dtype)
-
     if self_is_scalar:
         self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
 
-    result = op.Softmax(self, axis=dim)
+    result = op.LogSoftmax(self, axis=dim)
     if self_is_scalar:
-        # Convert to scalar when input is scalar
         result = op.Squeeze(result)
-
     return result
 
 
+@torch_op("aten::_softmax", trace_only=True)
+def aten__softmax_half(self: Union[FLOAT16, BFLOAT16], dim: int, half_to_float: bool) -> FLOAT:
+    """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor"""
+
+    # trace_only because we need to cast conditionally based on half_to_float
+    if half_to_float:
+        self = op.Cast(self, to=FLOAT.dtype)
+
+    return aten_softmax_no_dtype(self, dim)
+
+
+@torch_op("aten::_softmax", trace_only=True)
+def aten__softmax(
+    self: TFloatHighPrecision, dim: int, half_to_float: bool
+) -> TFloatHighPrecision:
+    """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor"""
+
+    # trace_only to reuse aten_softmax_no_dtype
+
+    del half_to_float  # Unused
+    return aten_softmax_no_dtype(self, dim)
+
+
 @torch_op(("aten::abs", "_operator::abs"), trace_only=True)
 def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8:
     """abs(Tensor self) -> Tensor"""
@@ -132,35 +161,16 @@ def aten_acosh(self: TFloat) -> TFloat:
     return op.Acosh(self)
 
 
-@torch_op("aten::add.Tensor", trace_only=True)
-def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor:
+@torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True)
+def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
-
-    if self.dtype == ir.DataType.BOOL:
-        # alpha can also be bool
-        if alpha == 0:
-            return op.Identity(self)
-        return op.Or(self, other)
-
+    # TODO(microsoft/onnxruntime#15977): Improve fp16 precision
     if alpha != 1.0:
         alpha = op.CastLike(alpha, other)
         other = op.Mul(other, alpha)
     return op.Add(self, other)
 
 
-@torch_op("aten::add.Scalar", trace_only=True)
-def aten_add_scalar(self: TTensor, other: float, alpha: float = 1.0) -> TTensor:
-    """add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor"""
-
-    other = op.Constant(value=ir.tensor(other, dtype=self.dtype))
-    return aten_add(self, other, alpha=alpha)
-
-
-@torch_op("_operator::add", trace_only=True)
-def operator_add(self: TTensor, other: TTensor) -> TTensor:
-    return op.Add(self, other)
-
-
 @torch_op(("aten::add.Tensor", "aten::add.Scalar"), trace_only=True, complex=True)
 def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
@@ -364,6 +374,7 @@ def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
     return self
 
 
+@torch_op("aten::all.dims", trace_only=True)
 def _aten_all_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL:
     """all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor"""
 
@@ -482,6 +493,7 @@ def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
     return self
 
 
+@torch_op("aten::any.dims", trace_only=True)
 def _aten_any_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL:
     if len(self.shape) == 0:
         result = op.Cast(self, to=BOOL.dtype)
@@ -721,6 +733,7 @@ def aten_argmax(
     return result
 
 
+@torch_op("aten::argmax", private=True, trace_only=True)
 def _aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
     """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""
 
@@ -733,6 +746,7 @@ def _aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
     return result
 
 
+@torch_op("aten::argmax", private=True, trace_only=True)
 def _aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64:
     """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""
 
@@ -760,6 +774,7 @@ def aten_argmin(
     return result
 
 
+@torch_op("aten::argmin", private=True, trace_only=True)
 def _aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
     """argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""
 
@@ -772,6 +787,7 @@ def _aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
     return result
 
 
+@torch_op("aten::argmin", private=True, trace_only=True)
 def _aten_argmin_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64:
     """argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""
 
@@ -909,21 +925,16 @@ def aten_atan(self: TFloat) -> TFloat:
     return op.Atan(self)
 
 
-@torch_op("aten::atan2", trace_only=True)
+@torch_op("aten::atan2")
 def aten_atan2(self: TFloat, other: TFloat) -> TFloat:
     """atan2(Tensor self, Tensor other) -> Tensor"""
 
     # self is y, and other is x on coordinate
     slope = op.Div(self, other)
     atan = op.Atan(slope)
-    zero = common_ops.constant(0.0, dtype=self.dtype)
-    pi = common_ops.constant(_MATH_PI, dtype=self.dtype)
-
-    second_third_quadrant = op.Where(op.Greater(self, zero), atan + pi, atan - pi)
-    result = op.Where(op.Less(other, zero), second_third_quadrant, atan)
-    # Map NaN to 0 to match PyTorch behavior
-    result = op.Where(op.IsNaN(result), zero, result)
+    second_third_quadrant = op.Where(self > 0.0, atan + _MATH_PI, atan - _MATH_PI)
+    result = op.Where(other < 0.0, second_third_quadrant, atan)
 
     return result
 
@@ -959,11 +970,11 @@ def reshape_to_1d(tensor):
     return op.SequenceMap(self, body=reshape_to_1d)
 
 
-@torch_op("aten::atleast_2d", trace_only=True)
+@torch_op("aten::atleast_2d")
 def aten_atleast_2d(self: TTensor) -> TTensor:
     """atleast_2d(Tensor self) -> Tensor"""
 
-    if len(self.shape) <= 1:
+    if Rank(self) <= 1:
         self = op.Reshape(self, op.Constant(value_ints=[1, -1]))
     return op.Identity(self)
 
@@ -987,7 +998,7 @@ def reshape_to_2d(tensor):
 def aten_atleast_3d(self: TTensor) -> TTensor:
     """atleast_3d(Tensor self) -> Tensor"""
 
-    rank = len(self.shape)
+    rank = Rank(self)
     if rank <= 1:
         self = op.Reshape(self, op.Constant(value_ints=[1, -1, 1]))
     elif rank == 2:
@@ -1173,7 +1184,6 @@ def aten_bernoulli_p(self: TTensor, p: float) -> TTensor:
     return op.CastLike(sampled, self)
 
 
-@torch_op("aten::bilinear", trace_only=True)
 def aten_bilinear(
     input1: TensorType,
     input2: TensorType,
@@ -1182,23 +1192,7 @@
 ) -> TensorType:
     """bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor"""
 
-    # Bilinear transformation: y = x1^T A x2 + b
-    # input1 shape: (..., in1_features)
-    # input2 shape: (..., in2_features)
-    # weight shape: (out_features, in1_features, in2_features)
-    # bias shape: (out_features) - optional
-    # output shape: (..., out_features)
-
-    # Use Einsum to compute the bilinear transformation
-    # "...i,oij,...j->...o" means:
-    # - input1[..., i] * weight[o, i, j] * input2[..., j] -> output[..., o]
-    result = op.Einsum(input1, weight, input2, equation="...i,oij,...j->...o")
-
-    # Add bias if provided
-    if bias is not None:
-        result = op.Add(result, bias)
-
-    return result
+    raise NotImplementedError()
 
 
 def aten_binary_cross_entropy_with_logits(
@@ -1232,178 +1226,147 @@ def aten_binomial(
 
 
 @torch_op(
     (
         "aten::bitwise_and.Tensor",
+        "aten::bitwise_and.Scalar",
+        "aten::bitwise_and.Scalar_Tensor",
         "_operator::and_",
     ),
     trace_only=True,
 )
-def aten_bitwise_and(self: TTensor, other: TTensor) -> TTensor:
+def aten_bitwise_and(self: TInt, other: TInt) -> TInt:
     """bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor"""
+    # logical_and implements the BOOL variant
 
-    assert self.dtype == other.dtype or self.dtype is None or other.dtype is None
-    dtype = self.dtype if self.dtype is not None else other.dtype
-    assert dtype is not None
+    return op.BitwiseAnd(self, other)
 
-    if dtype.is_integer():
-        return op.BitwiseAnd(self, other)
-    if dtype == ir.DataType.BOOL:
-        return op.And(self, other)
-    raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")
 
+@torch_op(
+    (
+        "aten::bitwise_left_shift.Tensor",
+        "aten::bitwise_left_shift.Tensor_Scalar",
+        "aten::bitwise_left_shift.Scalar_Tensor",
+        "_operator::__lshift__",
+        "aten::__lshift__.Scalar",
+    ),
+    trace_only=True,
+)
+def aten_bitwise_left_shift_int16(self: INT16, other: INT16) -> INT16:
+    """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
+    # assert other >= 0
+    self = op.Cast(self, to=UINT16.dtype)
+    other = op.Cast(other, to=UINT16.dtype)
+
+    result = op.BitShift(self, other, direction="LEFT")
 
-@torch_op("aten::bitwise_and.Scalar", trace_only=True)
-def aten_bitwise_and_scalar(self: TTensor, other: int) -> TTensor:
-    """bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor"""
+    return op.Cast(result, to=INT16.dtype)
 
-    other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype))
-    return aten_bitwise_and(self, other_tensor)
 
+@torch_op(
+    (
+        "aten::bitwise_left_shift.Tensor",
+        "aten::bitwise_left_shift.Tensor_Scalar",
+        "aten::bitwise_left_shift.Scalar_Tensor",
+        "_operator::__lshift__",
+        "aten::__lshift__.Scalar",
+    ),
+    trace_only=True,
+)
+def aten_bitwise_left_shift_int32(self: INT32, other: INT32) -> INT32:
+    """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
+    # assert other >= 0
+    self = op.Cast(self, to=UINT32.dtype)
+    other = op.Cast(other, to=UINT32.dtype)
 
-@torch_op("aten::bitwise_and.Scalar_Tensor", trace_only=True)
-def aten_bitwise_and_scalar_tensor(self: float, other: TTensor) -> TTensor:
-    """bitwise_and.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
+    result = op.BitShift(self, other, direction="LEFT")
 
-    self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype))
-    return aten_bitwise_and(self_tensor, other)
+    return op.Cast(result, to=INT32.dtype)
 
 
 @torch_op(
     (
         "aten::bitwise_left_shift.Tensor",
+        "aten::bitwise_left_shift.Tensor_Scalar",
+        "aten::bitwise_left_shift.Scalar_Tensor",
         "_operator::__lshift__",
+        "aten::__lshift__.Scalar",
     ),
     trace_only=True,
 )
-def aten_bitwise_left_shift(self: TInt, other: TInt) -> TInt:
+def aten_bitwise_left_shift_int64(self: INT64, other: INT64) -> INT64:
     """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
-    assert self.dtype == other.dtype or self.dtype is None or other.dtype is None
-    dtype = self.dtype if self.dtype is not None else other.dtype
-    assert dtype is not None
-    # assert other >= 0
-    if dtype.bitwidth == 8:
-        unsigned_dtype = ir.DataType.UINT8
-        signed_dtype = ir.DataType.INT8
-    elif dtype.bitwidth == 16:
-        unsigned_dtype = ir.DataType.UINT16
-        signed_dtype = ir.DataType.INT16
-    elif dtype.bitwidth == 32:
-        unsigned_dtype = ir.DataType.UINT32
-        signed_dtype = ir.DataType.INT32
-    elif dtype.bitwidth == 64:
-        unsigned_dtype = ir.DataType.UINT64
-        signed_dtype = ir.DataType.INT64
-    else:
-        raise NotImplementedError(f"Not implemented for type {dtype}")
-
-    self = op.Cast(self, to=unsigned_dtype)
-    other = op.Cast(other, to=unsigned_dtype)
+    # assert other >= 0
+    self = op.Cast(self, to=UINT64.dtype)
+    other = op.Cast(other, to=UINT64.dtype)
 
     result = op.BitShift(self, other, direction="LEFT")
 
-    return op.Cast(result, to=signed_dtype)
+    return op.Cast(result, to=INT64.dtype)
 
 
 @torch_op(
-    ("aten::bitwise_left_shift.Tensor_Scalar", "aten::__lshift__.Scalar"), trace_only=True
+    (
+        "aten::bitwise_left_shift.Tensor",
+        "aten::bitwise_left_shift.Tensor_Scalar",
+        "aten::bitwise_left_shift.Scalar_Tensor",
+        "_operator::__lshift__",
+        "aten::__lshift__.Scalar",
+    ),
+    trace_only=True,
 )
-def aten_bitwise_left_shift_tensor_scalar(self: TInt, other: int) -> TInt:
-    """bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor"""
-    other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype))
-    return aten_bitwise_left_shift(self, other_tensor)
+def aten_bitwise_left_shift_int8(self: INT8, other: INT8) -> INT8:
+    """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
+    # assert other >= 0
+    self = op.Cast(self, to=UINT8.dtype)
+    other = op.Cast(other, to=UINT8.dtype)
 
     result = op.BitShift(self, other, direction="LEFT")
 
-@torch_op("aten::bitwise_left_shift.Scalar_Tensor", trace_only=True)
-def aten_bitwise_left_shift_scalar_tensor(self: int, other: TInt) -> TInt:
-    """bitwise_left_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
-    self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype))
-    return aten_bitwise_left_shift(self_tensor, other)
+    return op.Cast(result, to=INT8.dtype)
 
 
 @torch_op("aten::bitwise_not", trace_only=True)
-def aten_bitwise_not(self: TTensor) -> TTensor:
+def aten_bitwise_not(self: TInt) -> TInt:
     """bitwise_not(Tensor self) -> Tensor"""
+    # logical_not implements the BOOL variant
 
-    if self.dtype == ir.DataType.BOOL:
-        return op.Not(self)
-    if self.dtype.is_integer():
-        return op.BitwiseNot(self)
-    raise NotImplementedError(f"Not implemented for type {self.dtype}")
+    return op.BitwiseNot(self)
 
 
 @torch_op(
     (
         "aten::bitwise_or.Tensor",
+        "aten::bitwise_or.Scalar",
+        "aten::bitwise_or.Scalar_Tensor",
         "_operator::or_",
     ),
     trace_only=True,
 )
-def aten_bitwise_or(self: TTensor, other: TTensor) -> TTensor:
+def aten_bitwise_or(self: TInt, other: TInt) -> TInt:
     """bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor"""
+    # logical_or implements the BOOL variant
 
-    assert self.dtype == other.dtype or self.dtype is None or other.dtype is None
-    dtype = self.dtype if self.dtype is not None else other.dtype
-    assert dtype is not None
-
-    if dtype.is_integer():
-        return op.BitwiseOr(self, other)
-    if dtype == ir.DataType.BOOL:
-        return op.Or(self, other)
-    raise NotImplementedError(f"Not 
implemented for types {self.dtype} and {other.dtype}") - - -@torch_op("aten::bitwise_or.Scalar", trace_only=True) -def aten_bitwise_or_scalar(self: TTensor, other: int) -> TTensor: - """bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor""" - other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype)) - return aten_bitwise_or(self, other_tensor) - - -@torch_op("aten::bitwise_or.Scalar_Tensor", trace_only=True) -def aten_bitwise_or_scalar_tensor(self: int, other: TTensor) -> TTensor: - """bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" - self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype)) - return aten_bitwise_or(self_tensor, other) + return op.BitwiseOr(self, other) @torch_op( ( "aten::bitwise_right_shift.Tensor", + "aten::bitwise_right_shift.Tensor_Scalar", + "aten::bitwise_right_shift.Scalar_Tensor", "_operator::__rshift__", - ), - trace_only=True, + "aten::__rshift__.Scalar", + ) ) -def aten_bitwise_right_shift(self: TInt, other: TInt) -> TInt: +def aten_bitwise_right_shift_int16(self: INT16, other: INT16) -> INT16: """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - assert self.dtype == other.dtype or self.dtype is None or other.dtype is None - dtype = self.dtype if self.dtype is not None else other.dtype - assert dtype is not None - - if dtype.bitwidth == 8: - unsigned_dtype = ir.DataType.UINT8 - signed_dtype = ir.DataType.INT8 - mask = ir.tensor(0xFF, dtype=unsigned_dtype) - elif dtype.bitwidth == 16: - unsigned_dtype = ir.DataType.UINT16 - signed_dtype = ir.DataType.INT16 - mask = ir.tensor(0xFFFF, dtype=unsigned_dtype) - elif dtype.bitwidth == 32: - unsigned_dtype = ir.DataType.UINT32 - signed_dtype = ir.DataType.INT32 - mask = ir.tensor(0xFFFFFFFF, dtype=unsigned_dtype) - elif dtype.bitwidth == 64: - unsigned_dtype = ir.DataType.UINT64 - signed_dtype = ir.DataType.INT64 - mask = ir.tensor(0xFFFFFFFFFFFFFFFF, dtype=unsigned_dtype) # 0xFFFFFFFFFFFFFFFF - else: - raise NotImplementedError(f"Not implemented for type {dtype}") - negative = op.Less(self, 0) - self = op.Cast(self, to=unsigned_dtype) - other = op.Cast(other, to=unsigned_dtype) + self = op.Cast(self, to=UINT16.dtype) + other = op.Cast(other, to=UINT16.dtype) # Simulate arithmetic shift using logical shift # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting - mask = op.BitShift(mask, other, direction="RIGHT") + mask = op.BitShift( + op.Cast(op.Constant(value_int=0xFFFF), to=UINT16.dtype), other, direction="RIGHT" + ) mask = op.BitwiseNot(mask) # Do logical shift shifted = op.BitShift(self, other, direction="RIGHT") @@ -1411,53 +1374,119 @@ def aten_bitwise_right_shift(self: TInt, other: TInt) -> TInt: negative_shifted = op.BitwiseOr(shifted, mask) # Choose the shifted value based on the sign bit return op.Where( - negative, op.Cast(negative_shifted, to=signed_dtype), op.Cast(shifted, to=signed_dtype) + negative, op.Cast(negative_shifted, to=INT16.dtype), op.Cast(shifted, to=INT16.dtype) ) @torch_op( - ("aten::bitwise_right_shift.Tensor_Scalar", "aten::__rshift__.Scalar"), trace_only=True + ( + "aten::bitwise_right_shift.Tensor", + "aten::bitwise_right_shift.Tensor_Scalar", + "aten::bitwise_right_shift.Scalar_Tensor", + "_operator::__rshift__", + "aten::__rshift__.Scalar", + ) ) -def aten_bitwise_right_shift_tensor_scalar(self: TInt, other: int) -> TInt: - """bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor""" - other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype)) - return 
aten_bitwise_right_shift(self, other_tensor) +def aten_bitwise_right_shift_int32(self: INT32, other: INT32) -> INT32: + """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" + negative = op.Less(self, 0) + self = op.Cast(self, to=UINT32.dtype) + other = op.Cast(other, to=UINT32.dtype) + # Simulate arithmetic shift using logical shift + # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting + mask = op.BitShift( + op.Cast(op.Constant(value_int=0xFFFFFFFF), to=UINT32.dtype), other, direction="RIGHT" + ) + mask = op.BitwiseNot(mask) + # Do logical shift + shifted = op.BitShift(self, other, direction="RIGHT") + # Compute the arithmetic shifted value assuming the sign bit was set + negative_shifted = op.BitwiseOr(shifted, mask) + # Choose the shifted value based on the sign bit + return op.Where( + negative, op.Cast(negative_shifted, to=INT32.dtype), op.Cast(shifted, to=INT32.dtype) + ) -@torch_op("aten::bitwise_right_shift.Scalar_Tensor", trace_only=True) -def aten_bitwise_right_shift_scalar_tensor(self: int, other: TInt) -> TInt: - """bitwise_right_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" - self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype)) - return aten_bitwise_right_shift(self_tensor, other) +@torch_op( + ( + "aten::bitwise_right_shift.Tensor", + "aten::bitwise_right_shift.Tensor_Scalar", + "aten::bitwise_right_shift.Scalar_Tensor", + "_operator::__rshift__", + "aten::__rshift__.Scalar", + ) +) +def aten_bitwise_right_shift_int64(self: INT64, other: INT64) -> INT64: + """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" + negative = op.Less(self, 0) + self = op.Cast(self, to=UINT64.dtype) + other = op.Cast(other, to=UINT64.dtype) -@torch_op("aten::bitwise_xor.Tensor", trace_only=True) -def aten_bitwise_xor(self: TTensor, other: TTensor) -> TTensor: - """bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor""" + # Simulate arithmetic shift using logical shift + # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting + mask = op.BitShift( + # 0xFFFFFFFFFFFFFFFF + op.Cast(op.Constant(value_int=-1), to=UINT64.dtype), + other, + direction="RIGHT", + ) + mask = op.BitwiseNot(mask) + # Do logical shift + shifted = op.BitShift(self, other, direction="RIGHT") + # Compute the arithmetic shifted value assuming the sign bit was set + negative_shifted = op.BitwiseOr(shifted, mask) + # Choose the shifted value based on the sign bit + return op.Where( + negative, op.Cast(negative_shifted, to=INT64.dtype), op.Cast(shifted, to=INT64.dtype) + ) - assert self.dtype == other.dtype or self.dtype is None or other.dtype is None - dtype = self.dtype if self.dtype is not None else other.dtype - assert dtype is not None - if dtype.is_integer(): - return op.BitwiseXor(self, other) - if dtype == ir.DataType.BOOL: - return op.Xor(self, other) - raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}") +@torch_op( + ( + "aten::bitwise_right_shift.Tensor", + "aten::bitwise_right_shift.Tensor_Scalar", + "aten::bitwise_right_shift.Scalar_Tensor", + "_operator::__rshift__", + "aten::__rshift__.Scalar", + ) +) +def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8: + """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" + negative = op.Less(self, 0) + self = op.Cast(self, to=UINT8.dtype) + other = op.Cast(other, to=UINT8.dtype) + # Simulate arithmetic shift using logical shift + # Clear the lower bits of an all one mask 
to create the mask to simulate the sign bit shifting + mask = op.BitShift( + op.Cast(op.Constant(value_int=0xFF), to=UINT8.dtype), other, direction="RIGHT" + ) + mask = op.BitwiseNot(mask) + # Do logical shift + shifted = op.BitShift(self, other, direction="RIGHT") + # Compute the arithmetic shifted value assuming the sign bit was set + negative_shifted = op.BitwiseOr(shifted, mask) + # Choose the shifted value based on the sign bit + return op.Where( + negative, op.Cast(negative_shifted, to=INT8.dtype), op.Cast(shifted, to=INT8.dtype) + ) -@torch_op("aten::bitwise_xor.Scalar", trace_only=True) -def aten_bitwise_xor_scalar(self: TTensor, other: int) -> TTensor: - """bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor""" - other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype)) - return aten_bitwise_xor(self, other_tensor) +@torch_op( + ( + "aten::bitwise_xor.Tensor", + "aten::bitwise_xor.Scalar", + "aten::bitwise_xor.Scalar_Tensor", + ), + trace_only=True, +) +def aten_bitwise_xor(self: TInt, other: TInt) -> TInt: + """bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor""" + # logical_xor implements the BOOL variant -@torch_op("aten::bitwise_xor.Scalar_Tensor", trace_only=True) -def aten_bitwise_xor_scalar_tensor(self: int, other: TTensor) -> TTensor: - """bitwise_xor.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" - self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype)) - return aten_bitwise_xor(self_tensor, other) + return op.BitwiseXor(self, other) @torch_op("aten::blackman_window", trace_only=True) @@ -1494,10 +1523,10 @@ def aten_broadcast_tensors(tensors: Sequence[TensorType]) -> TensorType: raise NotImplementedError() -@torch_op("aten::broadcast_to", trace_only=True) -def aten_broadcast_to(self: TTensor, size: Sequence[INT64]) -> TTensor: +@torch_op("aten::broadcast_to") +def aten_broadcast_to(self: TTensor, size: INT64) -> TTensor: """broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)""" - size = common_ops.merge_dims(size) + return op.Expand(self, size) @@ -1521,7 +1550,7 @@ def aten_cartesian_prod(tensors: Sequence[TensorType]) -> TensorType: raise NotImplementedError() -@torch_op(("aten::cat", "aten::concat", "aten::concatenate"), trace_only=True, complex=True) +@torch_op("aten::cat", trace_only=True, complex=True) def aten_cat_complex(tensors: Sequence[TTensor], dim: int = 0) -> TTensor: """cat(Tensor[] tensors, int dim=0) -> Tensor""" # Real representation unsqueezes the last dimension @@ -1534,18 +1563,8 @@ def aten_cat_complex(tensors: Sequence[TTensor], dim: int = 0) -> TTensor: def aten_cat(tensors: Sequence[TTensor], dim: int = 0) -> TTensor: """cat(Tensor[] tensors, int dim=0) -> Tensor""" - filtered_tensors = [] - for tensor in tensors: - # Remove None tensors - if tensor is None: - continue - # Remove empty tensors - if tensor.shape == (0,): - continue - filtered_tensors.append(tensor) - assert filtered_tensors, "aten::cat received all None or empty tensors" - if len(filtered_tensors) == 1: - return op.Identity(filtered_tensors[0]) + # Remove None tensors + tensors = [tensor for tensor in tensors if tensor is not None] return op.Concat(*tensors, axis=dim) @@ -1842,21 +1861,39 @@ def aten_conj_physical(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::constant_pad_nd", trace_only=True) -def aten_constant_pad_nd(self: TTensor, pad: Sequence[INT64], value: float = 0.0) -> TTensor: +@torch_op("aten::constant_pad_nd") +def aten_constant_pad_nd(self: TTensor, pad: INT64, value: float = 0.0) -> 
TTensor: """constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor""" # The desired order of paddings is # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end. # n is the dimension of input. # assume zero-dimensions in the beginning - rank = len(self.shape) - paddings = list(pad) + [0] * (rank * 2 - len(pad)) + # rank = len(self.shape) # rank must be scalar + # paddings = list(pad[:]) + [0] * (rank * 2 - len(pad)) # reverse order and collate first beginnings and then ends - paddings = paddings[-2::-2] + paddings[-1::-2] - constant_value = op.Constant(value=ir.tensor(value, dtype=self.dtype)) + # paddings = paddings[-2::-2] + paddings[-1::-2] + + neg_1 = op.Constant(value_ints=[-1]) + + zero_count = op.Sub(op.Mul(Rank(self), 2), op.Size(pad)) + zero_count = op.Reshape(zero_count, neg_1) + zero = op.Constant(value_ints=[0]) + zeros = op.Expand(zero, zero_count) + torch_paddings = op.Concat(pad, zeros, axis=0) + size_d = op.Size(torch_paddings) + steps = op.Constant(value_ints=[-2]) + + starts = steps + ends = op.Sub(starts, size_d) + odd_elements = op.Slice(torch_paddings, starts, ends, zero, steps) - return op.Pad(self, paddings, constant_value) + starts = neg_1 + ends = op.Sub(starts, size_d) + even_elements = op.Slice(torch_paddings, starts, ends, zero, steps) + + onnx_padding = op.Concat(odd_elements, even_elements, axis=0) + return op.Pad(self, onnx_padding, value) @torch_op("aten::contiguous", trace_only=True) @@ -2091,6 +2128,7 @@ def aten_convolution( return result +@torch_op("aten::convolution", private=True, trace_only=True) def _aten_convolution_onnx( input: TFloat, weight: TFloat, @@ -2193,7 +2231,7 @@ def aten_convolution_overrideable( raise NotImplementedError() -@torch_op("aten::copy", trace_only=True) +@torch_op("aten::copy") def aten_copy( self: TTensor, src: TTensor2, @@ -2562,11 +2600,9 @@ def aten_diagflat(self: TensorType, offset: int = 0) -> TensorType: @torch_op(("aten::diagonal", "aten::diagonal_copy"), trace_only=True) -def aten_diagonal(self: TTensor, offset: int = 0, dim1: int = 0, dim2: int = 1) -> TTensor: +def aten_diagonal(self: TReal, offset: int = 0, dim1: int = 0, dim2: int = 1) -> TReal: """diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)""" - is_bool = self.dtype == BOOL.dtype - # perm is used to transpose the tensor to make dim1 and dim2 as the last 2 dims # [0,1,2] -> [2,0,1] when dim1=0 and dim2=1 # [0,1,2] -> [1,0,2] when dim1=0 and dim2=2 @@ -2592,16 +2628,9 @@ def aten_diagonal(self: TTensor, offset: int = 0, dim1: int = 0, dim2: int = 1) dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col mask_shape = op.Concat(dim1_size, dim2_size, axis=0) mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset) - - if is_bool: - self_int = op.Cast(self, to=INT64.dtype) - mask_int = op.Cast(mask, to=INT64.dtype) - self_int_t = op.Transpose(self_int, perm=perm) - result = op.Mul(self_int_t, mask_int) - else: - mask = op.CastLike(mask, self) - self_t = op.Transpose(self, perm=perm) - result = op.Mul(self_t, mask) + mask = op.CastLike(mask, self) + self_t = op.Transpose(self, perm=perm) + result = op.Mul(self_t, mask) result = op.ReduceSum(result, keepdims=False, axes=axes) # min(row, col) min_dim_size = op.Min(dim1_size, dim2_size) @@ -2639,8 +2668,79 @@ def aten_diagonal(self: TTensor, offset: int = 0, dim1: int = 0, dim2: int = 1) end = op.Add(start, length) result = op.Slice(result, start, end, axes=axes) - if is_bool: - result = op.Cast(result, to=BOOL.dtype) + return result + + 
+@torch_op("aten::diagonal", trace_only=True) +def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1) -> BOOL: + """diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)""" + + # perm is used to transpose the tensor to make dim1 and dim2 as the last 2 dims + # [0,1,2] -> [2,0,1] when dim1=0 and dim2=1 + # [0,1,2] -> [1,0,2] when dim1=0 and dim2=2 + # [0,1,2] -> [0,1,2] when dim1=1 and dim2=2 + if dim1 < 0: + dim1 = dim1 + len(self.shape) + if dim2 < 0: + dim2 = dim2 + len(self.shape) + + self_rank = len(self.shape) + perm = list(range(self_rank)) + perm.remove(dim1) + perm.remove(dim2) + perm.append(dim1) + perm.append(dim2) + + # If rank=2, then axes=[0]; if rank=3, then axes=[1] + # This is because computing diagonal sum is on dim2 after transpose by perm + axes = [self_rank - 2] + + neg_1 = op.Constant(value_ints=[-1]) + dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row + dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col + mask_shape = op.Concat(dim1_size, dim2_size, axis=0) + mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset) + self_int = op.Cast(self, to=INT64.dtype) + mask_int = op.Cast(mask, to=INT64.dtype) + self_int_t = op.Transpose(self_int, perm=perm) + result = op.Mul(self_int_t, mask_int) + result = op.ReduceSum(result, keepdims=False, axes=axes) + # min(row, col) + min_dim_size = op.Min(dim1_size, dim2_size) + # take 2 tensors as example: + # one is 3x5 in size, min_dim_size = 3, dim1_size = 3 + # the other is 5x3 in size, min_dim_size = 3, dim1_size = 5 + # 3 rows x 5 cols 5 rows x 3 cols + # offset diagonal offset diagonal + # ---------------- ---------------- + # -4 0 -6 0 + # -3 0 -5 0 + # -2 1 -4 1 + # -1 2 -3 2 + # 0 3 -2 3 + # 1 3 -1 3 + # 2 3 0 3 + # 3 2 1 2 + # 4 1 2 1 + # 5 0 3 0 + # 6 0 4 0 + + # From above table, we can get the logic below + offset_val = op.Constant(value_ints=[offset]) + if offset < 0: + # row + offset + length = op.Add(dim1_size, offset_val) + start = op.Constant(value_ints=[0]) + else: # offset >= 0 + # col - offset + length = op.Sub(dim2_size, offset_val) + start = offset_val + + # max(min(length, min(row, col)), 0) + length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0])) + end = op.Add(start, length) + result = op.Slice(result, start, end, axes=axes) + result = op.Cast(result, to=BOOL.dtype) return result @@ -2751,37 +2851,45 @@ def aten_div_complex(self: TFloat, other: TFloat) -> TFloat: @torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True) -def aten_div_mode(self: TReal, other: TReal, rounding_mode: Optional[str] = None) -> TReal: +def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: Optional[str] = None) -> TFloat: """div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor""" assert rounding_mode in {"trunc", "floor", None} - if self.dtype.is_integer(): - quotient = op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype)) + if rounding_mode == "trunc": + # Rounds the results of the division towards zero. + # Equivalent to C-style integer division + return aten_trunc(op.Div(self, other)) + if rounding_mode == "floor": + return op.Floor(op.Div(self, other)) + + return op.Div(self, other) - if rounding_mode == "trunc": - # Rounds the results of the division towards zero. 
- # Equivalent to C-style integer division - result = aten_trunc(quotient) - return op.CastLike(result, self) - if rounding_mode == "floor": - result = op.Floor(quotient) - return op.CastLike(result, self) - assert rounding_mode is None - # When rounding_mode is None, the return type is float32 - return quotient +@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True) +def aten_div_mode_int( + self: TInt, other: TInt, rounding_mode: Optional[str] = None +) -> TensorType: + """div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor + + Variant for integer inputs. + """ + assert rounding_mode in {"trunc", "floor", None} - # Float inputs + quotient = op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype)) if rounding_mode == "trunc": # Rounds the results of the division towards zero. # Equivalent to C-style integer division - return aten_trunc(op.Div(self, other)) + result = aten_trunc(quotient) + return op.CastLike(result, self) if rounding_mode == "floor": - return op.Floor(op.Div(self, other)) + result = op.Floor(quotient) + return op.CastLike(result, self) - return op.Div(self, other) + assert rounding_mode is None + # When rounding_mode is None, the return type is float32 + return quotient @torch_op("aten::dot", trace_only=True) @@ -2993,27 +3101,42 @@ def aten_embedding_bag_padding_idx( sparse: bool = False, per_sample_weights: Optional[TFloat] = None, include_last_offset: bool = False, - padding_idx: int = -1, + padding_idx: Optional[int] = None, ) -> Tuple[TFloat, TFloat, TFloat, TFloat]: """embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor) We add default values for the attributes to accommodate _embedding_bag as well: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) """ - assert padding_idx is not None, ( - "padding_idx must not be None. 
This is likely a dispatcher error" - ) if per_sample_weights is None: per_sample_weights = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(indices)) per_sample_weights = op.CastLike(per_sample_weights, weight) - # Change padding_idx to positive value, -1 means the last index - if padding_idx < 0: - padding_idx = weight.shape[0] + padding_idx + if padding_idx is not None: + if padding_idx < 0: + padding_idx = weight.shape[0] + padding_idx + # Call the existing function for handling padding_idx + result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx( + weight, + indices, + offsets, + mode, + per_sample_weights, + include_last_offset, + padding_idx, + ) - result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx( - weight, indices, offsets, mode, per_sample_weights, include_last_offset, padding_idx + return result, offset2bag, bag_size, max_indices + + # When padding_idx is None, use the standard embedding_bag implementation + result, offset2bag, bag_size, max_indices = _aten_embedding_bag_onnx( + weight, + indices, + offsets, + mode, + per_sample_weights, + include_last_offset, ) return result, offset2bag, bag_size, max_indices @@ -3178,20 +3301,20 @@ def aten_embedding_sparse_backward( @torch_op("aten::empty.memory_format", trace_only=True) def aten_empty( - size: Sequence[INT64], + size: IntType, dtype: int = FLOAT.dtype, layout: str = "", device: str = "", pin_memory: bool = False, memory_format: str = "", ) -> TensorType: # type: ignore[type-var] - """empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" + # empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor if dtype == -1: dtype = FLOAT.dtype - - # using Zeros to simulate empty() - zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype))) - size = common_ops.merge_dims(size) + # using Zeros to simulate np.empty() + size = op.Cast(size, to=INT64.dtype) + zero = op.Constant(value_float=0.0) + zero = op.Cast(zero, to=dtype) return op.Expand(zero, size) @@ -3226,18 +3349,17 @@ def aten_empty_quantized( @torch_op("aten::empty_strided", trace_only=True) def aten_empty_strided( - size: Sequence[INT64], + size: INT64, stride: INT64, layout: str = "", - dtype: int = FLOAT.dtype, device: str = "", pin_memory: bool = False, ) -> TTensor: # type: ignore[type-var] # empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor # using Zeros to simulate empty() - zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype))) - size = common_ops.merge_dims(size) + size = op.Cast(size, to=INT64.dtype) + zero = op.Constant(value_float=0.0) return op.Expand(zero, size) @@ -3285,14 +3407,13 @@ def aten_exp2(self: TFloat) -> TFloat: @torch_op("aten::expand", trace_only=True) -def aten_expand(self: TTensor, size: Sequence[INT64], implicit: bool = False) -> TTensor: +def aten_expand(self: TTensor, size: TInt, implicit: bool = False) -> TTensor: """expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)""" + size = op.Cast(size, to=INT64.dtype) # NOTE: PyTorch supports `not changing dim` by -1, but ONNX supports `not changing dim` by 1. # To support -1 dim, we need to convert -1 to 1. 
- # Even though in theory a dynamic dim can still be -1, in practice it is very unlikely - # and isn't expected to appear from correct usages of SymInt. - size = [1 if isinstance(s, int) and s == -1 else s for s in size] - return op.Expand(self, common_ops.merge_dims(size)) + size = op.Abs(size) + return op.Expand(self, size) @torch_op("aten::expand_as", trace_only=True) @@ -3552,27 +3673,23 @@ def python_math_floor(self: TFloat) -> TInt: @torch_op("aten::floor_divide", trace_only=True) -def aten_floor_divide(self: TTensor, other: TTensor) -> TTensor: +def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat: """floor_divide(Tensor self, Tensor other) -> Tensor""" - if self.dtype.is_floating_point(): - return op.Floor(op.Div(self, other)) + return op.Floor(op.Div(self, other)) - assert self.dtype.is_integer() - if not self.dtype.is_signed(): - return op.Div(self, other) +@torch_op("aten::floor_divide", trace_only=True) +def aten_floor_divide_int(self: TInt, other: TInt) -> TInt: + """floor_divide(Tensor self, Tensor other) -> Tensor""" - # Convert truncation to flooring - # Reference: https://github.com/pytorch/pytorch/blob/ffc645c870f0abd368606ba1e2b3b58cacb03046/torch/_refs/__init__.py#L1401C1-L1409C70 - # offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0) - # return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype) - offset = op.And( - op.Not(op.Equal(op.Sign(self), op.Sign(other))), - op.Cast(op.Mod(self, other), to=BOOL.dtype), - ) - offset = op.Cast(offset, to=self.dtype) - return op.Sub(op.Div(self, other), offset) + # TODO(justinchuby): This can be simplified if we can constrain the + # inputs to be positive integers. Consider how we can embed constraints in the model. + dtype = self.dtype + self = op.Cast(self, to=FLOAT.dtype) + other = op.Cast(other, to=FLOAT.dtype) + result = op.Floor(op.Div(self, other)) + return op.Cast(result, to=dtype) @torch_op("_operator::floordiv", trace_only=True) @@ -3705,15 +3822,11 @@ def aten_gather( else: return op.Expand(self, op.Shape(index)) - is_scalar_index = len(index.shape) == 0 - if is_scalar_index: - index = op.Unsqueeze(index, [0]) + if len(index.shape) == 0: + return op.Identity(self) index = op.Cast(index, to=INT64.dtype) result = op.GatherElements(self, index, axis=dim) - - if is_scalar_index: - result = op.Squeeze(result, [0]) return result @@ -3732,27 +3845,29 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType: @torch_op( - ("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor"), + ("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"), trace_only=True, ) -def aten_ge(self: TTensor, other: TTensor) -> BOOL: +def aten_ge(self: TReal, other: TReal) -> BOOL: """ge.Tensor(Tensor self, Tensor other) -> Tensor""" - if self.dtype == ir.DataType.BOOL: - # self, other, self >= other - # F, F, T - # F, T, F - # T, F, T - # T, T, T - return op.Or(self, op.Not(other)) - return op.GreaterOrEqual(self, other) -@torch_op("_operator::ge", trace_only=True) -def operator_ge(self: TTensor, other: TTensor) -> BOOL: - # operator.ge for SymInt - return op.GreaterOrEqual(self, other) +@torch_op( + ("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"), + trace_only=True, +) +def aten_ge_bool(self: BOOL, other: BOOL) -> BOOL: + """ge.Tensor(Tensor self, Tensor other) -> Tensor""" + + # self, other, self >= other + # F, F, T + # F, T, F + # T, F, T + # T, T, T + + return op.Or(self, op.Not(other)) def aten_geqrf(self: 
TensorType) -> tuple[TensorType, TensorType]: @@ -3767,192 +3882,6 @@ def aten_ger(self: TensorType, vec2: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::gru.input", trace_only=True) -def aten_gru( - input: TFloat, - hx: TFloat, - params: Sequence[TFloat], - has_biases: bool, - num_layers: int, - dropout: float, - train: bool, - bidirectional: bool, - batch_first: bool, -) -> tuple[TFloat, TFloat]: - """gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)""" - - # Determine number of directions - num_directions = 2 if bidirectional else 1 - - # Get dimensions - if batch_first: - # Convert from [batch, seq, input_size] to [seq, batch, input_size] - input = op.Transpose(input, perm=[1, 0, 2]) - - hidden_size = op.Shape(hx, start=2, end=3) - - # Process each layer - current_input = input - output_h_list = [] - - for layer_idx in range(num_layers): - # Extract hidden state for this layer - layer_start = layer_idx * num_directions - layer_end = (layer_idx + 1) * num_directions - layer_h = op.Slice(hx, layer_start, layer_end, axes=[0]) - - # Extract parameters for this layer - # Parameter layout: [W_ih, W_hh, b_ih, b_hh] for each direction - params_per_direction = 4 if has_biases else 2 - params_per_layer = params_per_direction * num_directions - param_start_idx = layer_idx * params_per_layer - - # Build weight matrices for ONNX GRU - # ONNX expects: W[zrh] shape [num_directions, 3*hidden_size, input_size] - # PyTorch provides: W_ih shape [3*hidden_size, input_size] - W_list = [] - R_list = [] - B_list = [] if has_biases else None - - for dir_idx in range(num_directions): - dir_param_start = param_start_idx + dir_idx * params_per_direction - W_ih = params[ - dir_param_start - ] # [3*hidden_size, input_size] - PyTorch order: [r,z,n] - W_hh = params[ - dir_param_start + 1 - ] # [3*hidden_size, hidden_size] - PyTorch order: [r,z,n] - - # Reorder gates from PyTorch [r,z,n] to ONNX [z,r,n] - # Split into individual gates - W_ir = op.Slice(W_ih, starts=[0], ends=hidden_size, axes=[0]) - W_iz = op.Slice(W_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) - W_in = op.Slice(W_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) - - W_hr = op.Slice(W_hh, starts=[0], ends=hidden_size, axes=[0]) - W_hz = op.Slice(W_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) - W_hn = op.Slice(W_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) - - # Reorder: [z,r,n] - W_ih_reordered = op.Concat( - W_iz, W_ir, W_in, axis=0 - ) # [3*hidden_size, input_size] - ONNX order - W_hh_reordered = op.Concat( - W_hz, W_hr, W_hn, axis=0 - ) # [3*hidden_size, hidden_size] - ONNX order - - # Add direction dimension - W_ih_expanded = op.Unsqueeze(W_ih_reordered, [0]) # [1, 3*hidden_size, input_size] - W_hh_expanded = op.Unsqueeze( - W_hh_reordered, [0] - ) # [1, 3*hidden_size, hidden_size] - - W_list.append(W_ih_expanded) - R_list.append(W_hh_expanded) - - if has_biases: - b_ih = params[dir_param_start + 2] # [3*hidden_size] - PyTorch order: [r,z,n] - b_hh = params[dir_param_start + 3] # [3*hidden_size] - PyTorch order: [r,z,n] - - # Reorder biases from PyTorch [r,z,n] to ONNX [z,r,n] - b_ir = op.Slice(b_ih, starts=[0], ends=hidden_size, axes=[0]) - b_iz = op.Slice(b_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) - b_in = op.Slice(b_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) - - b_hr = op.Slice(b_hh, starts=[0], ends=hidden_size, axes=[0]) 
- b_hz = op.Slice(b_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) - b_hn = op.Slice(b_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) - - # Reorder: [z,r,n] - b_ih_reordered = op.Concat( - b_iz, b_ir, b_in, axis=0 - ) # [3*hidden_size] - ONNX order - b_hh_reordered = op.Concat( - b_hz, b_hr, b_hn, axis=0 - ) # [3*hidden_size] - ONNX order - - # ONNX expects biases concatenated: [Wb[zrh], Rb[zrh]] - b_combined = op.Concat( - b_ih_reordered, b_hh_reordered, axis=0 - ) # [6*hidden_size] - b_expanded = op.Unsqueeze(b_combined, [0]) # [1, 6*hidden_size] - B_list.append(b_expanded) - - # Concatenate weights for all directions - W = op.Concat(*W_list, axis=0) if len(W_list) > 1 else W_list[0] - R = op.Concat(*R_list, axis=0) if len(R_list) > 1 else R_list[0] - B = ( - op.Concat(*B_list, axis=0) - if has_biases and len(B_list) > 1 - else (B_list[0] if has_biases else None) - ) - - # Call ONNX GRU operator - direction = "bidirectional" if bidirectional else "forward" - - # Extract hidden_size from hx shape: [num_layers * num_directions, batch, hidden_size] - hidden_size_attr = hx.shape[2] - - if B is not None: - Y, Y_h = op.GRU( - current_input, - W, - R, - B, - initial_h=layer_h, - direction=direction, - hidden_size=hidden_size_attr, - ) - else: - Y, Y_h = op.GRU( - current_input, - W, - R, - initial_h=layer_h, - direction=direction, - hidden_size=hidden_size_attr, - ) - - # Y shape: [seq_length, num_directions, batch_size, hidden_size] - # Reshape to [seq_length, batch_size, num_directions * hidden_size] - Y = op.Transpose( - Y, perm=[0, 2, 1, 3] - ) # [seq_length, batch_size, num_directions, hidden_size] - Y_shape = op.Shape(Y) - new_shape = op.Concat( - op.Slice(Y_shape, [0], [1]), # seq_length - op.Slice(Y_shape, [1], [2]), # batch_size - op.Reshape( - op.Mul( - op.Slice(Y_shape, [2], [3]), # num_directions - op.Slice(Y_shape, [3], [4]), # hidden_size - ), - op.Constant(value_ints=[-1]), - ), - axis=0, - ) - current_input = op.Reshape(Y, new_shape) - - # Apply dropout if not last layer and dropout > 0 - if layer_idx < num_layers - 1 and dropout > 0.0 and train: - current_input, _ = op.Dropout(current_input, dropout, train) - - # Store final hidden state - output_h_list.append(Y_h) - - # Concatenate all layer outputs - final_h = ( - output_h_list[0] if len(output_h_list) == 1 else op.Concat(*output_h_list, axis=0) - ) - - # Handle batch_first for output - if batch_first: - # Convert from [seq, batch, features] to [batch, seq, features] - current_input = op.Transpose(current_input, perm=[1, 0, 2]) - - return current_input, final_h - - @torch_op(("_operator::getitem", "aten::getitem")) def aten_getitem(self: Sequence[TTensor], i: INT64) -> TTensor: return op.SequenceAt(self, i) @@ -4064,28 +3993,28 @@ def aten_gru_cell( @torch_op( - ("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor"), + ("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"), trace_only=True, ) -def aten_gt(self: TTensor, other: TTensor) -> BOOL: +def aten_gt(self: TReal, other: TReal) -> BOOL: """gt.Tensor(Tensor self, Tensor other) -> Tensor""" - if self.dtype == ir.DataType.BOOL: - # self, other, self > other - # F, F, F - # F, T, F - # T, F, T - # T, T, F - - return op.And(self, op.Not(other)) - return op.Greater(self, other) -@torch_op("_operator::gt", trace_only=True) -def operator_gt(self: TTensor, other: TTensor) -> BOOL: - # operator.gt for SymInt - return op.Greater(self, other) +@torch_op( + ("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", 
"_operator::gt"), + trace_only=True, +) +def aten_gt_bool(self: BOOL, other: BOOL) -> BOOL: + """gt.Tensor(Tensor self, Tensor other) -> Tensor""" + # self, other, self > other + # F, F, F + # F, T, F + # T, F, T + # T, T, F + + return op.And(self, op.Not(other)) @torch_op("aten::hamming_window", trace_only=True) @@ -4198,7 +4127,7 @@ def reshape_to_atleast_2d(tensor): result = op.ConcatFromSequence(tensors_atleast_2d, axis=1, new_axis=0) # hstack expects a non-empty sequence of tensors. So we don't need to check for length - rank_1d_or_less = op.Less(op.Size(op.Shape(op.SequenceAt(tensors, 0))), 2) + rank_1d_or_less = op.Less(Rank(op.SequenceAt(tensors, 0)), 2) if rank_1d_or_less: result = op.Reshape(result, op.Constant(value_ints=[-1])) return result @@ -4903,28 +4832,29 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType: @torch_op( - ("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor"), + ("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"), trace_only=True, ) -def aten_le(self: TTensor, other: TTensor) -> BOOL: +def aten_le(self: TReal, other: TReal) -> BOOL: """le.Tensor(Tensor self, Tensor other) -> Tensor""" - if self.dtype == ir.DataType.BOOL: - # self, other, self <= other - # F, F, T - # F, T, T - # T, F, F - # T, T, T + return op.LessOrEqual(self, other) - return op.Or(other, op.Not(self)) - return op.LessOrEqual(self, other) +@torch_op( + ("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"), + trace_only=True, +) +def aten_le_bool(self: BOOL, other: BOOL) -> BOOL: + """le.Tensor(Tensor self, Tensor other) -> Tensor""" + # self, other, self <= other + # F, F, T + # F, T, T + # T, F, F + # T, T, T -@torch_op("_operator::le", trace_only=True) -def operator_le(self: TTensor, other: TTensor) -> BOOL: - # operator.le for SymInt - return op.LessOrEqual(self, other) + return op.Or(other, op.Not(self)) @torch_op(("aten::lerp.Tensor", "aten::lerp.Scalar")) @@ -5084,65 +5014,83 @@ def aten_logdet(self: TFloat) -> TFloat: return op.Log(op.Det(self)) -@torch_op("aten::logical_and", trace_only=True) -def aten_logical_and(self: TTensor, other: TTensor) -> BOOL: +@torch_op( + ( + "aten::logical_and", + "aten::bitwise_and.Tensor", + "aten::bitwise_and.Scalar", + "aten::bitwise_and.Scalar_Tensor", + ), + trace_only=True, +) +def aten_logical_and(self: BOOL, other: BOOL) -> BOOL: """logical_and(Tensor self, Tensor other) -> Tensor""" - assert self.dtype == other.dtype - - if self.dtype == ir.DataType.BOOL: - return op.And(self, other) - return op.And(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) + return op.And(self, other) -@torch_op("aten::logical_not", trace_only=True) -def aten_logical_not(self: TTensor) -> BOOL: +@torch_op(("aten::logical_not", "aten::bitwise_not"), trace_only=True) +def aten_logical_not(self: BOOL) -> BOOL: """logical_not(Tensor self) -> Tensor""" - if self.dtype == ir.DataType.BOOL: - return op.Not(self) - return op.Not(op.Cast(self, to=BOOL.dtype)) + return op.Not(self) -@torch_op("aten::logical_or", trace_only=True) -def aten_logical_or(self: TTensor, other: TTensor) -> BOOL: +@torch_op( + ( + "aten::logical_or", + "aten::bitwise_or.Tensor", + "aten::bitwise_or.Scalar", + "aten::bitwise_or.Scalar_Tensor", + "aten::add.Tensor", + "aten::add.Scalar", + ), + trace_only=True, +) +def aten_logical_or(self: BOOL, other: BOOL) -> BOOL: """logical_or(Tensor self, Tensor other) -> Tensor""" - assert self.dtype == other.dtype + return op.Or(self, other) - if self.dtype == 
ir.DataType.BOOL: - return op.Or(self, other) - return op.Or(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) - -@torch_op("aten::logical_xor", trace_only=True) -def aten_logical_xor(self: TTensor, other: TTensor) -> BOOL: +@torch_op( + ( + "aten::logical_xor", + "aten::bitwise_xor.Tensor", + "aten::bitwise_xor.Scalar", + "aten::bitwise_xor.Scalar_Tensor", + ), + trace_only=True, +) +def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL: """logical_xor(Tensor self, Tensor other) -> Tensor""" - assert self.dtype == other.dtype + return op.Xor(self, other) - if self.dtype == ir.DataType.BOOL: - return op.Xor(self, other) - return op.Xor(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) +@torch_op("aten::logit", private=True) +def _aten_logit_onnx(self: TFloat) -> TFloat: + return op.Log(op.Div(self, op.Sub(1.0, self))) -@torch_op("aten::logit", trace_only=True) -def aten_logit(self: TFloat, eps: Optional[float] = None) -> TFloat: - """logit(Tensor self, float? eps=None) -> Tensor""" - one = ir.tensor(1, dtype=self.dtype) - - if eps is None: - return op.Log(op.Div(self, op.Sub(one, self))) - - one_minus_eps = ir.tensor(1 - eps, dtype=self.dtype) - eps = ir.tensor(eps, dtype=self.dtype) - temporary_self = op.Where(self <= one_minus_eps, self, one_minus_eps) +@torch_op("aten::logit", private=True) +def _aten_logit_clamp_onnx(self: TFloat, eps: float) -> TFloat: + eps = op.CastLike(eps, self) + one = op.CastLike(1.0, self) + temporary_self = op.Where(self <= one - eps, self, one - eps) z = op.Where(temporary_self < eps, eps, temporary_self) return op.Log(op.Div(z, op.Sub(one, z))) +@torch_op("aten::logit", trace_only=True) +def aten_logit(self: TFloat, eps: Optional[float] = None) -> TFloat: + """logit(Tensor self, float? eps=None) -> Tensor""" + if eps is None: + return _aten_logit_onnx(self) + return _aten_logit_clamp_onnx(self, eps) + + def aten_logspace(start: float, end: float, steps: int, base: float = 10.0) -> TensorType: """logspace(Scalar start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor""" @@ -5195,234 +5143,30 @@ def aten_lstm_mps_backward( raise NotImplementedError() -@torch_op("aten::lstm.input", trace_only=True) -def aten_lstm( - input: TFloat, - hx: Sequence[TFloat], - params: Sequence[TFloat], - has_biases: bool, - num_layers: int, - dropout: float, - train: bool, - bidirectional: bool, - batch_first: bool, -) -> tuple[TFloat, TFloat, TFloat]: - """lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor)""" - - # Extract initial hidden and cell states - initial_h = hx[0] # Shape: [num_directions * num_layers, batch_size, hidden_size] - initial_c = hx[1] # Shape: [num_directions * num_layers, batch_size, hidden_size] - - # Determine number of directions - num_directions = 2 if bidirectional else 1 - - # Get dimensions - if batch_first: - # Convert from [batch, seq, input_size] to [seq, batch, input_size] - input = op.Transpose(input, perm=[1, 0, 2]) - - hidden_size = op.Shape(initial_h, start=2, end=3) - - # Process each layer - current_input = input - output_h_list = [] - output_c_list = [] - - for layer_idx in range(num_layers): - # Extract hidden and cell states for this layer - layer_start = layer_idx * num_directions - layer_end = (layer_idx + 1) * num_directions - layer_h = op.Slice(initial_h, layer_start, layer_end, axes=[0]) - layer_c = op.Slice(initial_c, layer_start, layer_end, axes=[0]) - - # Extract parameters for this layer - # Parameter layout: [W_ih, W_hh, b_ih, b_hh] for each direction - params_per_direction = 4 if has_biases else 2 - params_per_layer = params_per_direction * num_directions - param_start_idx = layer_idx * params_per_layer - - # Build weight matrices for ONNX LSTM - # ONNX expects: W[iofc] shape [num_directions, 4*hidden_size, input_size] - # PyTorch provides: W_ih shape [4*hidden_size, input_size] - W_list = [] - R_list = [] - B_list = [] if has_biases else None - - for dir_idx in range(num_directions): - dir_param_start = param_start_idx + dir_idx * params_per_direction - W_ih = params[ - dir_param_start - ] # [4*hidden_size, input_size] - PyTorch order: [i,f,g,o] - W_hh = params[ - dir_param_start + 1 - ] # [4*hidden_size, hidden_size] - PyTorch order: [i,f,g,o] - - # Reorder gates from PyTorch [i,f,g,o] to ONNX [i,o,f,g] - # Split into individual gates - W_ii = op.Slice(W_ih, starts=[0], ends=hidden_size, axes=[0]) - W_if = op.Slice(W_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) - W_ig = op.Slice(W_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) - W_io = op.Slice(W_ih, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) - - W_hi = op.Slice(W_hh, starts=[0], ends=hidden_size, axes=[0]) - W_hf = op.Slice(W_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) - W_hg = op.Slice(W_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) - W_ho = op.Slice(W_hh, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) - - # Reorder: [i,o,f,g] - W_ih_reordered = op.Concat( - W_ii, W_io, W_if, W_ig, axis=0 - ) # [4*hidden_size, input_size] - ONNX order - W_hh_reordered = op.Concat( - W_hi, W_ho, W_hf, W_hg, axis=0 - ) # [4*hidden_size, hidden_size] - ONNX order - - # Add direction dimension - W_ih_expanded = op.Unsqueeze(W_ih_reordered, [0]) # [1, 4*hidden_size, input_size] - W_hh_expanded = op.Unsqueeze( - W_hh_reordered, [0] - ) # [1, 4*hidden_size, hidden_size] - - W_list.append(W_ih_expanded) - R_list.append(W_hh_expanded) - - if has_biases: - b_ih = 
params[ - dir_param_start + 2 - ] # [4*hidden_size] - PyTorch order: [i,f,g,o] - b_hh = params[ - dir_param_start + 3 - ] # [4*hidden_size] - PyTorch order: [i,f,g,o] - - # Reorder biases from PyTorch [i,f,g,o] to ONNX [i,o,f,g] - b_ii = op.Slice(b_ih, starts=[0], ends=hidden_size, axes=[0]) - b_if = op.Slice(b_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) - b_ig = op.Slice(b_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) - b_io = op.Slice(b_ih, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) - - b_hi = op.Slice(b_hh, starts=[0], ends=hidden_size, axes=[0]) - b_hf = op.Slice(b_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) - b_hg = op.Slice(b_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) - b_ho = op.Slice(b_hh, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) - - # Reorder: [i,o,f,g] - b_ih_reordered = op.Concat( - b_ii, b_io, b_if, b_ig, axis=0 - ) # [4*hidden_size] - ONNX order - b_hh_reordered = op.Concat( - b_hi, b_ho, b_hf, b_hg, axis=0 - ) # [4*hidden_size] - ONNX order - - # ONNX expects biases concatenated: [Wb[iofc], Rb[iofc]] - b_combined = op.Concat( - b_ih_reordered, b_hh_reordered, axis=0 - ) # [8*hidden_size] - b_expanded = op.Unsqueeze(b_combined, [0]) # [1, 8*hidden_size] - B_list.append(b_expanded) - - # Concatenate weights for all directions - W = op.Concat(*W_list, axis=0) if len(W_list) > 1 else W_list[0] - R = op.Concat(*R_list, axis=0) if len(R_list) > 1 else R_list[0] - B = ( - op.Concat(*B_list, axis=0) - if has_biases and len(B_list) > 1 - else (B_list[0] if has_biases else None) - ) - - # Call ONNX LSTM operator - direction = "bidirectional" if bidirectional else "forward" - - # Extract hidden_size from initial_h shape: [num_layers * num_directions, batch, hidden_size] - hidden_size_attr = initial_h.shape[2] - - if B is not None: - Y, Y_h, Y_c = op.LSTM( - current_input, - W, - R, - B, - initial_h=layer_h, - initial_c=layer_c, - direction=direction, - hidden_size=hidden_size_attr, - ) - else: - Y, Y_h, Y_c = op.LSTM( - current_input, - W, - R, - initial_h=layer_h, - initial_c=layer_c, - direction=direction, - hidden_size=hidden_size_attr, - ) - - # Y shape: [seq_length, num_directions, batch_size, hidden_size] - # Reshape to [seq_length, batch_size, num_directions * hidden_size] - Y = op.Transpose( - Y, perm=[0, 2, 1, 3] - ) # [seq_length, batch_size, num_directions, hidden_size] - Y_shape = op.Shape(Y) - new_shape = op.Concat( - op.Slice(Y_shape, [0], [1]), # seq_length - op.Slice(Y_shape, [1], [2]), # batch_size - op.Reshape( - op.Mul( - op.Slice(Y_shape, [2], [3]), # num_directions - op.Slice(Y_shape, [3], [4]), # hidden_size - ), - op.Constant(value_ints=[-1]), - ), - axis=0, - ) - current_input = op.Reshape(Y, new_shape) - - # Apply dropout if not last layer and dropout > 0 - if layer_idx < num_layers - 1 and dropout > 0.0 and train: - current_input, _ = op.Dropout(current_input, dropout, train) - - # Store final hidden and cell states - output_h_list.append(Y_h) - output_c_list.append(Y_c) - - # Concatenate all layer outputs - final_h = ( - output_h_list[0] if len(output_h_list) == 1 else op.Concat(*output_h_list, axis=0) - ) - final_c = ( - output_c_list[0] if len(output_c_list) == 1 else op.Concat(*output_c_list, axis=0) - ) - - # Handle batch_first for output - if batch_first: - # Convert from [seq, batch, features] to [batch, seq, features] - current_input = op.Transpose(current_input, perm=[1, 0, 2]) +@torch_op( + ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"), + 
trace_only=True, +) +def aten_lt(self: TReal, other: TReal) -> BOOL: + """lt.Tensor(Tensor self, Tensor other) -> Tensor""" - return current_input, final_h, final_c + return op.Less(self, other) @torch_op( - ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor"), + ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"), trace_only=True, ) -def aten_lt(self: TTensor, other: TTensor) -> BOOL: +def aten_lt_bool(self: BOOL, other: BOOL) -> BOOL: """lt.Tensor(Tensor self, Tensor other) -> Tensor""" - if self.dtype == ir.DataType.BOOL: - # self, other, self < other - # F, F, F - # F, T, T - # T, F, F - # T, T, F - return op.And(other, op.Not(self)) - - return op.Less(self, other) - + # self, other, self < other + # F, F, F + # F, T, T + # T, F, F + # T, T, F -@torch_op("_operator::lt", trace_only=True) -def operator_lt(self: TTensor, other: TTensor) -> BOOL: - # operator.lt for SymInt - return op.Less(self, other) + return op.And(other, op.Not(self)) def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType: @@ -5596,16 +5340,20 @@ def aten_max_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, I return result, indices -@torch_op("aten::maximum", trace_only=True) -def aten_maximum(self: TTensor, other: TTensor) -> TTensor: +@torch_op("aten::maximum") +def aten_maximum(self: TReal, other: TReal) -> TReal: """maximum(Tensor self, Tensor other) -> Tensor""" - if self.dtype == ir.DataType.BOOL: - return op.Or(self, other) - return op.Max(self, other) +@torch_op("aten::maximum") +def aten_maximum_bool(self: BOOL, other: BOOL) -> BOOL: + """maximum(Tensor self, Tensor other) -> Tensor""" + + return op.Or(self, other) + + @torch_op("aten::mean") def aten_mean(self: TReal) -> TReal: """mean(Tensor self, *, ScalarType? 
dtype=None) -> Tensor"""
@@ -5638,7 +5386,7 @@ def aten_meshgrid(tensors: Sequence[TensorType]) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op("aten::min", trace_only=True)
+@torch_op("aten::min")
 def aten_min(self: TReal) -> TReal:
     """min(Tensor self) -> Tensor"""
 
@@ -5659,16 +5407,20 @@ def aten_min_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, T
     return result, indices
 
 
-@torch_op("aten::minimum", trace_only=True)
-def aten_minimum(self: TTensor, other: TTensor) -> TTensor:
+@torch_op("aten::minimum")
+def aten_minimum(self: TReal, other: TReal) -> TReal:
     """minimum(Tensor self, Tensor other) -> Tensor"""
 
-    if self.dtype == ir.DataType.BOOL:
-        return op.And(self, other)
-
     return op.Min(self, other)
 
 
+@torch_op("aten::minimum")
+def aten_minimum_bool(self: BOOL, other: BOOL) -> BOOL:
+    """minimum(Tensor self, Tensor other) -> Tensor"""
+
+    return op.And(self, other)
+
+
 def aten_miopen_batch_norm(
     input: TensorType,
     weight: TensorType,
@@ -6006,21 +5758,26 @@ def aten_msort(self: TensorType) -> TensorType:
 
 
 @torch_op(
-    ("aten::mul", "aten::mul.Tensor", "aten::multiply.Tensor"),
+    ("aten::mul", "aten::mul.Tensor", "_operator::mul", "aten::multiply.Tensor"),
     trace_only=True,
 )
-def aten_mul(self: TTensor, other: TTensor) -> TTensor:
+def aten_mul(self: TReal, other: TReal) -> TReal:
     """mul.Tensor(Tensor self, Tensor other) -> Tensor"""
 
-    if self.dtype == ir.DataType.BOOL:
-        return op.And(self, other)
-
     return op.Mul(self, other)
 
 
-@torch_op("_operator::mul", trace_only=True)
-def operator_mul(self: TTensor, other: TTensor) -> TTensor:
-    return op.Mul(self, other)
+@torch_op(
+    ("aten::mul", "aten::mul.Tensor", "aten::multiply.Tensor"),
+    trace_only=True,
+)
+def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
+    """ONNX Mul doesn't support Boolean, so use And as an equivalent operator."""
+
+    # TODO(justinchuby): Handle cases where type reconciliation is not enough,
+    # since different ONNX operators are used based on different data types.
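The Boolean overloads in this stretch of the diff (aten_mul_bool here, aten_minimum_bool above, and the aten_maximum_bool / aten_lt_bool variants earlier) all lean on the same identities between arithmetic on bools and ONNX logical ops. The following is a minimal, illustrative Python sketch of those identities using builtin bools only; the helper name is made up and it is not one of the registered functions.

from itertools import product

def _check_bool_identities() -> None:
    # mul/minimum -> And, add (alpha=1)/maximum -> Or,
    # lt -> And(other, Not(self)), logical_xor/bitwise_xor -> Xor.
    for a, b in product((False, True), repeat=2):
        assert bool(a * b) == (a and b) == min(a, b)
        assert bool(a + b) == (a or b) == max(a, b)
        assert (a < b) == (b and not a)
        assert (a != b) == (a ^ b)

_check_bool_identities()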
+
+    return op.And(self, other)
 
 
 @torch_op(
@@ -6262,6 +6019,7 @@ def aten_native_batch_norm(
     return norm, input_mean, input_rstd
 
 
+@torch_op("aten::native_batch_norm", private=True)
 def _aten_native_batch_norm_training_onnx(
     input: TFloat,
     weight: TFloat,
@@ -6313,6 +6071,7 @@ def _aten_native_batch_norm_training_onnx(
     return norm, mean, rstd, running_mean, new_running_var
 
 
+@torch_op("aten::native_batch_norm", private=True)
 def _aten_native_batch_norm_inference_onnx(
     input: TFloat,
     weight: TFloat,
@@ -6482,10 +6241,22 @@ def aten_native_group_norm(
     if bias is None:
         # Set to 0.0 as default, the shape is Channel size
         bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2))
+    # According to Torch, return rstd instead of var
+    norm, mean, rstd = _aten_native_group_norm_onnx(input, weight, bias, group, eps)
+    return norm, mean, rstd
+
+
+@torch_op("aten::native_group_norm", private=True)
+def _aten_native_group_norm_onnx(
+    input: TFloat,
+    weight: TFloat,
+    bias: TFloat,
+    group: INT64,
+    eps: float,
+) -> Tuple[TFloat, TFloat, TFloat]:
     # Because onnx.GroupNorm() need size=group for weight and bias
     # But the torch's aten function's input need size=channel, the size mismatched
     # So we have to use onnx.InstanceNorm() to simulate
-    # This implementation should be simplified after opset 21
     neg_1 = op.Constant(value_ints=[-1])
     # Create weight_instance_norm and bias_instance_norm, copied from Torch ONNX converter
     group_tensor = op.Reshape(group, neg_1)
@@ -6501,7 +6272,7 @@ def aten_native_group_norm(
     norm = op.Reshape(norm, op.Shape(input), allowzero=True)
     # Using the input weight and bias to do affine
     # But need to unsqueeze to the target shape for broading cast easy
-    input_rank = len(input.shape)
+    input_rank = Rank(input)
     axes_unsqueeze = op.Range(1, input_rank - 1, 1)
     weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze)
     bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze)
@@ -6522,9 +6293,7 @@ def aten_native_group_norm(
     sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean)
     # In Pytorch, vstd = 1/(sqrt(var + eps))
     var = op.ReduceMean(sqr_input_sub_mean, axes, keepdims=False)
-    eps = op.Constant(value=ir.tensor(eps, dtype=input.dtype))
-    one = op.Constant(value=ir.tensor(1.0, dtype=input.dtype))
-    rstd = op.Div(one, op.Sqrt(op.Add(var, eps)))
+    rstd = op.Div(1.0, op.Sqrt(var + eps))
     # Get the correct shape [N, group] for mean again
     mean = op.ReduceMean(input_N_group_neg1, axes, keepdims=False)
     return norm_result, mean, rstd
@@ -6736,7 +6505,16 @@ def aten_norm_except_dim(v: TensorType, pow: int = 2, dim: int = 0) -> TensorTyp
     raise NotImplementedError()
 
 
-@torch_op("aten::normal_functional", trace_only=True)
+@torch_op(
+    (
+        "aten::normal.Tensor_float",
+        "aten::normal.Tensor_Tensor",
+        "aten::normal.float_Tensor",
+        "aten::normal.float_float",
+        "aten::normal_functional",
+    ),
+    trace_only=True,
+)
 def aten_normal(
     self: TTensor,
     mean: float = 0.0,
@@ -6765,7 +6543,7 @@ def aten_normal_float_float(
     return op.Cast(result, to=dtype)
 
 
-@torch_op("aten::normal.float_Tensor", trace_only=True)
+@torch_op("aten::normal.float_Tensor")
 def aten_normal_float_tensor(mean: FLOAT, std: TFloat) -> TFloat:
     """normal.float_Tensor(float mean, Tensor std, *, Generator?
generator=None) -> Tensor""" @@ -6775,7 +6553,7 @@ def aten_normal_float_tensor(mean: FLOAT, std: TFloat) -> TFloat: return op.Add(op.Mul(std, sampled), mean_casted) -@torch_op("aten::normal.Tensor_float", trace_only=True) +@torch_op("aten::normal.Tensor_float") def aten_normal_tensor_float(mean: TFloat, std: FLOAT) -> TFloat: """normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor""" @@ -6784,7 +6562,7 @@ def aten_normal_tensor_float(mean: TFloat, std: FLOAT) -> TFloat: return op.Add(op.Mul(op.CastLike(std, sampled), sampled), mean) -@torch_op("aten::normal.Tensor_Tensor", trace_only=True) +@torch_op("aten::normal.Tensor_Tensor") def aten_normal_tensor_tensor(mean: TFloat, std: TFloat) -> TFloat: """normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor""" @@ -6928,41 +6706,34 @@ def aten_pinverse(self: TensorType, rcond: float = 1e-15) -> TensorType: raise NotImplementedError() -@torch_op("aten::pixel_shuffle", trace_only=True) +@torch_op("aten::pixel_shuffle") def aten_pixel_shuffle(self: TReal, upscale_factor: int) -> TReal: """pixel_shuffle(Tensor self, int upscale_factor) -> Tensor""" - if len(self.shape) == 4: - return op.DepthToSpace(self, blocksize=upscale_factor, mode="CRD") - + self_shape = op.Shape(self) + batch_dims = self_shape[:-3] + chw_in_dims = self_shape[-3:] # Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D) - batch_dims = op.Shape(self, end=-3) - chw_in_dims = op.Shape(self, start=-3) - reshaped_self = op.Reshape( self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0) ) depth_to_space = op.DepthToSpace(reshaped_self, blocksize=upscale_factor, mode="CRD") - final_dims = op.Shape(depth_to_space, start=1) - output_shape = op.Concat(batch_dims, final_dims, axis=0) + output_shape = op.Concat(batch_dims, op.Shape(depth_to_space)[1:], axis=0) return op.Reshape(depth_to_space, output_shape, allowzero=True) -@torch_op("aten::pixel_unshuffle", trace_only=True) +@torch_op("aten::pixel_unshuffle") def aten_pixel_unshuffle(self: TReal, downscale_factor: int) -> TReal: """pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor""" - if len(self.shape) == 4: - return op.SpaceToDepth(self, blocksize=downscale_factor) + self_shape = op.Shape(self) + batch_dims = self_shape[:-3] + chw_in_dims = self_shape[-3:] # Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D) - batch_dims = op.Shape(self, end=-3) - chw_in_dims = op.Shape(self, start=-3) - reshaped_self = op.Reshape( self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0) ) space_to_depth = op.SpaceToDepth(reshaped_self, blocksize=downscale_factor) - final_dims = op.Shape(space_to_depth, start=1) - output_shape = op.Concat(batch_dims, final_dims, axis=0) + output_shape = op.Concat(batch_dims, op.Shape(space_to_depth)[1:], axis=0) return op.Reshape(space_to_depth, output_shape, allowzero=True) @@ -7493,12 +7264,9 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType: @torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"), trace_only=True) -def aten_remainder(self: TTensor, other: TTensor) -> TTensor: +def aten_remainder(self: TFloat, other: TFloat) -> TFloat: """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" - if self.dtype.is_integer(): - return op.Mod(self, other) - # TODO(justinchuby): Improve fp16 precision by following the logic in # 
https://github.com/pytorch/pytorch/blob/3a823e46170778cc32783f27596c77d0103084a9/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L264-L277 @@ -7508,9 +7276,12 @@ def aten_remainder(self: TTensor, other: TTensor) -> TTensor: return op.Sub(self, op.Mul(rounded_quotient, other)) -@torch_op("_operator::mod", trace_only=True) -def operator_mod(self: TTensor, other: TTensor) -> TTensor: - # Modulus operator % on SymInt +@torch_op( + ("aten::remainder.Tensor", "aten::remainder.Scalar", "_operator::mod"), trace_only=True +) +def aten_remainder_int(self: TInt, other: TInt) -> TInt: + """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" + return op.Mod(self, other) @@ -7536,131 +7307,20 @@ def aten_repeat(self: TTensor, repeats: Sequence[TInt]) -> TTensor: return op.Tile(self_expanded, repeats) -@torch_op("aten::repeat_interleave.self_int", trace_only=True) -def aten_repeat_interleave_self_int( - self: TensorType, repeats: int, dim: Optional[int] = None -) -> TensorType: - """repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor - - The trick is to repeat in one direction orthogonal to reshape. - - .. code-block:: python - - x = torch.tensor([[0, 1, 2], [3, 4, 5]]) - x.repeat_interleave(2, dim=0) - - is equivalent to: - - .. code-block:: python - - x = torch.tensor([[0, 1, 2], [3, 4, 5]]) - x.repeat((1, 2)).reshape((-1, t.shape[1])) - """ - if dim is None: - raise NotImplementedError("No conversion available yet when dim is None.") - - self_rank = len(self.shape) - pos_dim = (dim + self_rank) % self_rank - unsqueezed = op.Unsqueeze(self, [pos_dim + 1]) - if isinstance(repeats, int): - tiles = [1] * (self_rank + 1) - tiles[pos_dim + 1] = repeats - tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype)) - else: - # repeats is a symbolic tensor - tile_repeat = op.Concat( - op.Constant(value=ir.tensor([1] * pos_dim, dtype=INT64.dtype)), - op.Reshape(repeats, op.Constant(value=ir.tensor([-1], dtype=INT64.dtype))), - op.Constant(value=ir.tensor([1] * (self_rank - pos_dim), dtype=INT64.dtype)), - axis=0, - ) - tiled = op.Expand(unsqueezed, tile_repeat) - if self_rank == 1: - return op.Identity(tiled) - final_shape = op.Concat( - op.Shape(self, start=0, end=dim), - op.Constant(value_ints=[-1]), - op.Shape(self, start=pos_dim + 1), - axis=0, - ) - return op.Reshape(tiled, final_shape) - - -@torch_op("aten::repeat_interleave.Tensor", trace_only=True) -def aten_repeat_interleave_Tensor( - self: TensorType, repeats: Optional[TensorType] = None, dim: Optional[int] = None +def aten_repeat_interleave( + repeats: TensorType, output_size: Optional[int] = None ) -> TensorType: - """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor - - When `repeats` is a tensor, each line is multiplied - by a different number. - There are multiple strategies. Here is one. - - .. code-block:: python - - import torch - - x = torch.tensor([[0, 1, 2], [3, 4, 5]]) - times = torch.tensor([2, 3], dtype=torch.int64) - y = x.repeat_interleave(times, dim=0) - print("repeat_interleave") - print(y) + """repeat_interleave.Tensor(Tensor repeats, *, int? 
output_size=None) -> Tensor""" - ci = times.cumsum(dim=0) - rows = torch.arange(ci[-1], dtype=torch.int64) < ci.reshape((-1, 1)) - srows = times.shape[0] - rows.to(torch.int64).sum(axis=0) - indices = srows.reshape((-1, )) - print("decomposed") - print(x[indices, :]) - """ - if repeats is None: - repeats = self - self = op.Range(0, op.Squeeze(op.Shape(repeats, start=-1), [0]), 1) - if dim is None: - # flatten - self = op.Reshape(self, [-1]) - rank = 1 - else: - rank = len(self.shape) - - if rank > 2: - shape_x0 = op.Shape(self, start=0, end=1) - shape_x = op.Shape(self, start=1) - self = op.Reshape(self, op.Concat(shape_x0, [-1], axis=0)) - elif rank == 1: - shape_x = None - self = op.Reshape(self, [-1, 1]) - else: - if rank != 2: - raise NotImplementedError( - f"rank(self)={rank} not implemented for repeat_interleave" - ) - shape_x = None - - ci = op.CumSum(repeats, [0]) - last_ci = op.Gather(ci, [-1]) - trange = op.Range(0, op.Squeeze(last_ci, [0]), 1) - rows = op.Less(trange, op.Unsqueeze(ci, [-1])) - srows = op.Sub( - op.Shape(self, start=0, end=1), - op.ReduceSum(op.Cast(rows, to=INT64.dtype), [0]), - ) - indices = op.Reshape(srows, [-1]) - values = op.GatherND(self, op.Unsqueeze(indices, [-1])) - if rank == 2: - return values - # shape_x is None at this stage. - assert shape_x is None # for mypy - return op.Reshape( - values, - op.Concat([-1], shape_x, axis=0) if shape_x else [-1], - ) + raise NotImplementedError() -@torch_op("aten::reshape", trace_only=True) -def aten_reshape(self: TTensor, shape: Sequence[INT64]) -> TTensor: +@torch_op("aten::reshape") +def aten_reshape(self: TTensor, shape: IntType) -> TTensor: """reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)""" - shape = common_ops.merge_dims(shape) + + # Reshape only support INT64 as 'shape' + shape = op.Cast(shape, to=INT64.dtype) return op.Reshape(self, shape) @@ -7732,29 +7392,23 @@ def aten_rnn_tanh_cell( def aten_roll(self: TTensor, shifts: Sequence[int], dims: Sequence[int] = ()) -> TTensor: """roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor""" - if isinstance(shifts, int): - shifts = [shifts] - - if isinstance(dims, int): - dims = [dims] - self_rank = len(self.shape) if self_rank == 0: return op.Identity(self) elif self.shape[0] == 0: # empty tensor return op.Identity(self) - - # NOTE: In pytorch, default value of dims is an empty list. - if len(dims) == 0: # Empty sequence - assert len(shifts) == 1, "shifts should be a single integer if dims is empty" - return _aten_roll_shift_no_dim_onnx(self, shifts[0]) else: - assert len(shifts) == len(dims) - result = self - for i, shift in enumerate(shifts): - dim = dims[i] - result = _aten_roll_shift_and_dim_onnx(result, shift, dim) - return result + # NOTE: In pytorch, default value of dims is an empty list. 
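The private helpers below (_aten_roll_shift_no_dim_onnx and _aten_roll_shift_and_dim_onnx) express a roll as two slices that are concatenated in swapped order. A small NumPy sketch of that strategy follows, with NumPy standing in for the ONNX Slice/Concat ops; roll_via_slices is an illustrative name, not part of this module.

import numpy as np

def roll_via_slices(x: np.ndarray, shift: int, dim: int) -> np.ndarray:
    size = x.shape[dim]
    # Mirror the ONNX helpers: a negative shift keeps -shift leading elements,
    # a positive shift keeps size - shift leading elements; the trailing block
    # (prefix) is then moved in front of the leading block (suffix).
    slice_length = -shift if shift < 0 else size - shift
    suffix = np.take(x, np.arange(0, slice_length), axis=dim)
    prefix = np.take(x, np.arange(slice_length, size), axis=dim)
    return np.concatenate([prefix, suffix], axis=dim)

x = np.arange(8).reshape(2, 4)
assert np.array_equal(roll_via_slices(x, 1, dim=1), np.roll(x, 1, axis=1))
assert np.array_equal(roll_via_slices(x, -1, dim=1), np.roll(x, -1, axis=1))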
+ if len(dims) == 0: # Empty sequence + # assert isinstance(shifts, int) + return _aten_roll_shift_no_dim_onnx(self, shifts) + else: + # assert len(shifts) == len(dims), but shifts is a tensor, dims is a list + result = self + for i, shift in enumerate(shifts): + dim = dims[i] + result = _aten_roll_shift_and_dim_onnx(result, shift, dim) + return result @torch_op("aten::roll", trace_only=True, complex=True) @@ -7763,12 +7417,6 @@ def aten_roll_complex( ) -> TTensor: """roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor""" - if isinstance(shifts, int): - shifts = [shifts] - - if isinstance(dims, int): - dims = [dims] - self_rank = len(self.shape) if self_rank == 1: return op.Identity(self) @@ -7779,34 +7427,37 @@ def aten_roll_complex( self_real = op.Slice(self, [0], [1], axes=[-1]) self_imag = op.Slice(self, [1], [2], axes=[-1]) if not dims: - assert len(shifts) == 1, "shifts should be a single integer if dims is empty" - shift_real = _aten_roll_shift_no_dim_onnx(self_real, shifts[0]) - shift_imag = _aten_roll_shift_no_dim_onnx(self_imag, shifts[0]) + # assert isinstance(shifts, int) + shift_real = _aten_roll_shift_no_dim_onnx(self_real, shifts) + shift_imag = _aten_roll_shift_no_dim_onnx(self_imag, shifts) result = op.Concat(shift_real, shift_imag, axis=-1) else: - assert len(shifts) == len(dims) + # assert len(shifts) == len(dims), but shifts is a tensor, dims is a list for i, dim in enumerate(dims): - self_real = _aten_roll_shift_and_dim_onnx(self_real, shifts[i], dim) - self_imag = _aten_roll_shift_and_dim_onnx(self_imag, shifts[i], dim) + shift = op.Gather(shifts, i, axis=0) + self_real = _aten_roll_shift_and_dim_onnx(self_real, shift, dim) + self_imag = _aten_roll_shift_and_dim_onnx(self_imag, shift, dim) result = op.Concat(self_real, self_imag, axis=-1) return result -def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: int) -> TTensor: +@torch_op("aten::roll", private=True) +def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: INT64) -> TTensor: neg_1 = op.Constant(value_ints=[-1]) # flatten the self tensor: from [[A,B],[C,D]] to [A,B,C,D] self_flatten = op.Reshape(self, neg_1) # Compute slice length - if shift < 0: + shift_tensor = op.Reshape(shift, neg_1) + if shift_tensor < 0: # For [A,B,C,D], if shift is -1, slice_length = -(-1) = 1, means move [A] to the end - slice_length = op.Constant(value_ints=[-shift]) + slice_length = -shift_tensor else: # For [A,B,C,D], if shift is 1, slice_length = 4 - 1 = 3, means move [A,B,C] to the end # The effect equals to move [D] to the beginning - slice_length = op.Size(self_flatten) - op.Constant(value_ints=[shift]) + slice_length = op.Size(self_flatten) - shift_tensor # Get second part of the tensor, e.g. [A,B,C] suffix = op.Slice(self_flatten, op.Constant(value_ints=[0]), slice_length) # Get first part of the tensor, e.g. 
[D] @@ -7816,13 +7467,15 @@ def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: int) -> TTensor: return op.Reshape(result, op.Shape(self)) -def _aten_roll_shift_and_dim_onnx(self: TTensor, shift: int, dim: int) -> TTensor: +@torch_op("aten::roll", private=True) +def _aten_roll_shift_and_dim_onnx(self: TTensor, shift: INT64, dim: int) -> TTensor: neg_1 = op.Constant(value_ints=[-1]) - dim_tensor = op.Constant(value_ints=[dim]) - if shift < 0: - slice_length = op.Constant(value_ints=[-shift]) + dim_tensor = op.Reshape(op.Constant(value_int=dim), neg_1) + shift_tensor = op.Reshape(shift, neg_1) + if shift_tensor < 0: + slice_length = -shift_tensor else: - slice_length = op.Shape(self, start=dim, end=dim + 1) - op.Constant(value_ints=[shift]) + slice_length = op.Gather(op.Shape(self), dim_tensor, axis=0) - shift_tensor # from [A,B,C,D] -> [D,A,B,C], [D] is prefix, [A,B,C] is suffix suffix = op.Slice(self, op.Constant(value_ints=[0]), slice_length, axes=dim_tensor) prefix = op.Slice(self, slice_length, op.Reshape(op.Size(self), neg_1), axes=dim_tensor) @@ -7901,7 +7554,7 @@ def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: @torch_op("aten::scalar_tensor", trace_only=True) def aten_scalar_tensor( - s: TensorType, + s: float, dtype: int = FLOAT.dtype, layout: str = "", device: str = "", @@ -7910,7 +7563,8 @@ def aten_scalar_tensor( """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" if dtype == -1: dtype = FLOAT.dtype - + # Set trace_only=True because different if branches return different dtypes + # which is not supported in an ONNX function return common_ops.cast_to(s, dtype=dtype) @@ -7939,35 +7593,31 @@ def aten_scalar_tensor_complex( return result -@torch_op("aten::scatter.src", trace_only=True) -def aten_scatter_src( - self: TTensor, - dim: int, # we have to use int here because ScatterElements() will use this attribute - index: TInt, - src: TTensor, -) -> TTensor: - """scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor""" - if len(index.shape) == 0: - index = op.Unsqueeze(index, [0]) - if len(src.shape) == 0: - src = op.Unsqueeze(src, [0]) - return op.ScatterElements(self, index, src, axis=dim) +@torch_op("aten::scalar_tensor", trace_only=True) +def aten_scalar_tensor_sym_number( + s: TensorType, + dtype: int = FLOAT.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> RealType: + """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor"""
+    if dtype == -1:
+        dtype = FLOAT.dtype
+    return common_ops.cast_to(s, dtype=dtype)
 
 
-@torch_op("aten::scatter.value", trace_only=True)
-def aten_scatter_value(
-    self: TTensor,
+@torch_op(("aten::scatter.value", "aten::scatter.src"), trace_only=True)
+def aten_scatter(
+    self: TReal,
     dim: int,  # we have to use int here because ScatterElements() will use this attribute
     index: TInt,
-    value: float,
-) -> TTensor:
-    """scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor"""
-    # Ensure value is a scalar tensor and expand it to match index shape
-    if len(index.shape) == 0:
-        index = op.Unsqueeze(index, [0])
-    scalar_tensor = ir.tensor([value], dtype=self.dtype)
-    src = op.ConstantOfShape(op.Shape(index), value=scalar_tensor)
-    return op.ScatterElements(self, index, src, axis=dim)
+    src: TReal,
+) -> TReal:
+    """scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"""
+
+    update = op.Expand(src, op.Shape(index))
+    return op.ScatterElements(self, index, update, axis=dim)
 
 
 @torch_op("aten::scatter_add", trace_only=True)
@@ -8326,7 +7976,7 @@ def aten_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat:
     if self_is_scalar:
         self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
     result = op.Softmax(self, axis=dim)
-    if dtype != -1 and dtype is not None:
+    if dtype != -1:
         result = op.Cast(result, to=dtype)
     if self_is_scalar:
         # Convert to scalar when input is scalar
@@ -8335,6 +7985,21 @@ def aten_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat:
     return result
 
 
+@torch_op(("aten::softmax.int", "aten::special_softmax"), trace_only=True)
+def aten_softmax_no_dtype(self: TFloat, dim: int) -> TFloat:
+    """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""
+
+    self_is_scalar = len(self.shape) == 0
+    if self_is_scalar:
+        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+    result = op.Softmax(self, axis=dim)
+    if self_is_scalar:
+        # Convert to scalar when input is scalar
+        result = op.Squeeze(result)
+
+    return result
+
+
 @torch_op("aten::sort", trace_only=True)
 def aten_sort(
     self: TReal, dim: int = -1, descending: bool = False, stable: bool = False
@@ -8551,7 +8216,9 @@ def aten_std_mean_correction(
 @torch_op(
     (
         "aten::sub.Tensor",
+        "aten::sub.Scalar",
         "aten::subtract.Tensor",
+        "aten::subtract.Scalar",
         "_operator::sub",
     ),
     trace_only=True,
@@ -8564,14 +8231,6 @@ def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     return op.Sub(self, other)
 
 
-@torch_op(("aten::sub.Scalar", "aten::subtract.Scalar"), trace_only=True)
-def aten_sub_scalar(self: TTensor, other: float, alpha: float = 1.0) -> TTensor:
-    """sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor"""
-
-    other = op.Constant(value=ir.tensor(other, dtype=self.dtype))
-    return aten_sub(self, other, alpha=alpha)
-
-
 @torch_op(
     (
         "aten::sub.Tensor",
@@ -8664,7 +8323,7 @@ def aten_symeig(
 def aten_t(self: TTensor) -> TTensor:
     """t(Tensor(a) self) -> Tensor(a)"""
 
-    rank = len(self.shape)
+    rank = Rank(self)
     if rank == 2:
         result = op.Transpose(self, perm=[1, 0])
     else:
@@ -8747,24 +8406,26 @@ def aten_threshold_backward(
     raise NotImplementedError()
 
 
-@torch_op("aten::tile", trace_only=True)
-def aten_tile(self: TTensor, dims: Sequence[int]) -> TTensor:
+@torch_op("aten::tile")
+def aten_tile(self: TTensor, dims: INT64) -> TTensor:
     """tile(Tensor self, int[] dims) -> Tensor"""
 
-    self_rank = len(self.shape)
-    dims_rank = len(dims)
-    diff = self_rank - dims_rank
+    self_rank = Rank(self)
+    dims_rank = op.Size(dims)
+    diff = 
op.Sub(self_rank, dims_rank) if diff > 0: # dims is shorter than self.shape # pad dims with 1 - exapnd_ones = [1] * diff - dims = [*exapnd_ones, *dims] + diff_1d = op.Reshape(diff, op.Constant(value_ints=[1])) + exapnd_ones = op.Expand(op.Constant(value_ints=[1]), diff_1d) + dims = op.Concat(exapnd_ones, dims, axis=0) - elif diff < 0: + if diff < 0: # dims is longer than self.shape # pad self.shape with 1 - exapnd_ones = op.Constant(value_ints=[1] * (-diff)) + diff_1d = op.Reshape(op.Abs(diff), op.Constant(value_ints=[1])) + exapnd_ones = op.Expand(op.Constant(value_ints=[1]), diff_1d) self_shape = op.Shape(self) self_final_shape = op.Concat(exapnd_ones, self_shape, axis=0) self = op.Reshape(self, self_final_shape, allowzero=True) @@ -8915,7 +8576,7 @@ def aten_triangular_solve( raise NotImplementedError() -@torch_op("aten::tril", trace_only=True) +@torch_op("aten::tril") def aten_tril(self: TTensor, diagonal: int = 0) -> TTensor: """tril(Tensor self, int diagonal=0) -> Tensor""" @@ -8943,7 +8604,7 @@ def aten_triplet_margin_loss( raise NotImplementedError() -@torch_op("aten::triu", trace_only=True) +@torch_op("aten::triu") def aten_triu(self: TTensor, diagonal: int = 0) -> TTensor: """triu(Tensor self, int diagonal=0) -> Tensor""" @@ -8963,14 +8624,6 @@ def aten_trunc(self: TFloat) -> TFloat: return op.Floor(op.Abs(self)) * op.Sign(self) -@torch_op("math::trunc", trace_only=True) -def python_math_trunc(self: TFloat) -> TInt: - """trunc(Tensor self) -> Tensor""" - # NOTE: This is used in SymInt/SymBool/SymFloat context, so - # we don't expect overflow to happen here. - return op.Cast(self, to=INT64.dtype) - - @torch_op("aten::type_as", trace_only=True) def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2: """type_as(Tensor self, Tensor other) -> Tensor""" @@ -8978,22 +8631,12 @@ def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2: return op.CastLike(self, other) -@torch_op("aten::unbind.int", trace_only=True) +@torch_op("aten::unbind.int") def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]: """unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]""" - if isinstance(self.shape[dim], int) and not version_utils.torch_older_than("2.7"): - # We can create a definitive split op if the input shape is static - # Only torch>=2.7 supports correctly generating the correct number of outputs for Split - num_outputs = self.shape[dim] - if num_outputs != 1: - outputs = op.Split(self, axis=dim, num_outputs=num_outputs) - else: - outputs = [self] - - return [op.Squeeze(out, [dim]) for out in outputs] - - return op.SplitToSequence(self, axis=dim, keepdims=False) + split_sizes = op.Constant(value_int=1) + return op.SplitToSequence(self, split_sizes, axis=dim, keepdims=False) @torch_op("aten::unflatten.int", trace_only=True) @@ -9417,22 +9060,23 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType: @torch_op(("aten::view", "aten::_unsafe_view"), trace_only=True) -def aten_view(self: TTensor, size: Sequence[INT64]) -> TTensor: +def aten_view(self: TTensor, size: IntType) -> TTensor: """view(Tensor(a) self, SymInt[] size) -> Tensor(a)""" - size = common_ops.merge_dims(size) + size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input return op.Reshape(self, size, allowzero=True) -@torch_op(("aten::view", "aten::_unsafe_view"), complex=True, trace_only=True) -def aten_view_complex(self: TTensor, size: Sequence[INT64]) -> TTensor: +@torch_op(("aten::view", "aten::_unsafe_view"), complex=True) +def aten_view_complex(self: TTensor, size: 
IntType) -> TTensor: """view(Tensor(a) self, SymInt[] size) -> Tensor(a)""" - complex_size = common_ops.merge_dims([*size, 2]) + size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input + complex_size = op.Concat(size, op.Constant(value_ints=[2]), axis=0) return op.Reshape(self, complex_size, allowzero=True) -@torch_op("aten::view_as", trace_only=True) +@torch_op("aten::view_as") def aten_view_as(self: TTensor, other: TTensor2) -> TTensor: """view_as(Tensor(a) self, Tensor other) -> Tensor(a)""" @@ -9476,11 +9120,11 @@ def aten_view_as_real_copy(self: TTensor) -> TTensor: return op.Identity(self) -@torch_op("aten::view_copy", trace_only=True) -def aten_view_copy(self: TTensor, size: Sequence[INT64]) -> TTensor: +@torch_op("aten::view_copy") +def aten_view_copy(self: TTensor, size: IntType) -> TTensor: """view_copy(Tensor self, SymInt[] size) -> Tensor""" - size = common_ops.merge_dims(size) + size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input return op.Reshape(self, size) @@ -9508,8 +9152,7 @@ def reshape_to_2d(tensor): "aten::where.ScalarSelf", "aten::where.ScalarOther", "aten::where.self", - ), - trace_only=True, + ) ) def aten_where(condition: BOOL, self: TTensor, other: TTensor) -> TTensor: """where.self(Tensor condition, Tensor self, Tensor other) -> Tensor""" @@ -9525,7 +9168,7 @@ def aten_xor(self: TensorType, other: TensorType) -> TensorType: @torch_op("aten::zeros", trace_only=True) def aten_zeros( - size: Sequence[INT64], + size: IntType, dtype: int = FLOAT.dtype, layout: str = "", device: str = "", @@ -9534,9 +9177,9 @@ def aten_zeros( """zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" if dtype == -1: dtype = FLOAT.dtype - - zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype))) - size = common_ops.merge_dims(size) + size = op.Cast(size, to=INT64.dtype) + zero = op.Constant(value_float=0.0) + zero = op.Cast(zero, to=dtype) return op.Expand(zero, size) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 5d7deb169..3d8189618 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -37,37 +37,6 @@ def sample_inputs_scalar_tensor(op_info, device, dtype, requires_grad, **kwargs) yield opinfo_core.SampleInput(item, dtype=dtype) -def sample_inputs_bilinear(op_info, device, dtype, requires_grad, **kwargs): - """Sample inputs for bilinear operation.""" - del op_info - del kwargs - - make_arg = functools.partial( - torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad - ) - - # Test cases: (batch_size, in1_features, in2_features, out_features) - cases = [ - (2, 3, 4, 5), # Basic case - (1, 2, 2, 1), # Minimal case - (3, 5, 7, 4), # Different dimensions - (2, 1, 1, 3), # Single input features - ] - - for batch_size, in1_features, in2_features, out_features in cases: - input1 = make_arg((batch_size, in1_features)) - input2 = make_arg((batch_size, in2_features)) - weight = make_arg((out_features, in1_features, in2_features)) - bias = make_arg((out_features,)) - - # Test with bias - yield opinfo_core.SampleInput(input1, args=(input2, weight, bias)) - - # Test without bias (only for first case to avoid too many tests) - if batch_size == 2: - yield opinfo_core.SampleInput(input1, args=(input2, weight, None)) - - def sample_inputs_bernoulli_p(op_info, device, dtype, requires_grad, **kwargs): del op_info @@ 
-118,35 +87,6 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra yield opinfo_core.SampleInput(t, kwargs={"p": p}) -def sample_inputs_broadcast_in_dim(op_info, device, dtype, requires_grad, **kwargs): - del op_info - del kwargs - - # cases: (input_shape, target_shape, broadcast_dimensions) - # broadcast_dimensions maps each input dim to an axis in target_shape - cases = ( - # scalar -> 1-D tensor - ((), (3,), ()), - # identity (no-op broadcast) - ((3,), (3,), (0,)), - # rank-preserving broadcast where singleton dims expand - ((1, 3, 1), (2, 3, 4), (0, 1, 2)), - # input rank 2 -> output rank 3, input dims map to trailing axes - ((3, 1), (2, 3, 4), (1, 2)), - # add leading broadcast axis - ((3, 4), (1, 3, 4), (1, 2)), - # insert broadcasting in middle axis - ((3,), (2, 3, 1), (1,)), - ) - make_arg = functools.partial( - torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad - ) - - for shape, target_shape, broadcast_dimensions in cases: - tensor = make_arg(shape) - yield opinfo_core.SampleInput(tensor, args=(target_shape, broadcast_dimensions)) - - def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs): del op_info # input_shape, output_size, kernal, dilation, padding, stride @@ -1396,109 +1336,6 @@ def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs) yield opinfo_core.SampleInput(input_, args=(src, *args)) -def sample_inputs_scatter_src(op_info, device, dtype, requires_grad, **kwargs): - del op_info - del kwargs - make_arg = functools.partial( - torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad - ) - - # Basic test cases for scatter.src - cases = [ - # (self_shape, index_shape, src_shape, dim) - ((5, 5), (2, 3), (2, 3), 0), # 2D scatter on dim=0 - ((5, 5), (3, 2), (3, 2), 1), # 2D scatter on dim=1 - ((3, 4, 5), (2, 2, 3), (2, 2, 3), 0), # 3D scatter on dim=0 - ((3, 4, 5), (2, 2, 3), (2, 2, 3), 1), # 3D scatter on dim=1 - ((3, 4, 5), (2, 2, 3), (2, 2, 3), 2), # 3D scatter on dim=2 - ((10,), (3,), (3,), 0), # 1D scatter - ] - - for self_shape, index_shape, src_shape, dim in cases: - self_tensor = make_arg(self_shape) - # Create valid indices for the given dimension without duplication - index_buffer_shape = list(index_shape) - index_buffer_shape[dim] = self_shape[dim] - index_tensor = torch.rand(index_buffer_shape, device=device).argsort(dim=dim)[ - tuple(slice(None, d, None) for d in index_shape) - ] - src_tensor = make_arg(src_shape) - yield opinfo_core.SampleInput(self_tensor, args=(dim, index_tensor, src_tensor)) - - # Additional test cases for scalar and single-element tensor combinations with dim=0 - # Test case: scalar index, scalar src (dim_size=5) - dim_size = 5 - data_1d = make_arg((dim_size,)) - valid_index = torch.randint(0, dim_size, (), device=device, dtype=torch.long) - scalar_src = make_arg(()) - yield opinfo_core.SampleInput(data_1d, args=(0, valid_index, scalar_src)) - - # Test case: single-element tensor index, scalar src (dim_size=7) - dim_size = 7 - data_1d = make_arg((dim_size,)) - valid_index_1d = torch.randint(0, dim_size, (1,), device=device, dtype=torch.long) - scalar_src = make_arg(()) - yield opinfo_core.SampleInput(data_1d, args=(0, valid_index_1d, scalar_src)) - - # Test case: scalar index, single-element tensor src (dim_size=3) - dim_size = 3 - data_1d = make_arg((dim_size,)) - valid_index = torch.randint(0, dim_size, (), device=device, dtype=torch.long) - src_1d = make_arg((1,)) - yield opinfo_core.SampleInput(data_1d, 
args=(0, valid_index, src_1d)) - - # Test case: single-element tensor index, single-element tensor src (dim_size=10) - dim_size = 10 - data_1d = make_arg((dim_size,)) - valid_index_1d = torch.randint(0, dim_size, (1,), device=device, dtype=torch.long) - src_1d = make_arg((1,)) - yield opinfo_core.SampleInput(data_1d, args=(0, valid_index_1d, src_1d)) - - -def sample_inputs_scatter_value(op_info, device, dtype, requires_grad, **kwargs): - del op_info - del kwargs - make_arg = functools.partial( - torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad - ) - - # Basic test cases for scatter.value - cases = [ - # (self_shape, index_shape, dim, value) - ((5, 5), (2, 3), 0, 1.0), # 2D scatter on dim=0 with scalar value - ((5, 5), (3, 2), 1, -2.5), # 2D scatter on dim=1 with scalar value - ((3, 4, 5), (2, 2, 3), 0, False), # 3D scatter on dim=0 with scalar value - ((3, 4, 5), (2, 2, 3), 1, 3.14), # 3D scatter on dim=1 with scalar value - ((3, 4, 5), (2, 2, 3), 2, -1), # 3D scatter on dim=2 with scalar value - ((10,), (3,), 0, 5.0), # 1D scatter with scalar value - ] - - for self_shape, index_shape, dim, value in cases: - self_tensor = make_arg(self_shape) - # Create valid indices for the given dimension without duplication - index_buffer_shape = list(index_shape) - index_buffer_shape[dim] = self_shape[dim] - index_tensor = torch.rand(index_buffer_shape, device=device).argsort(dim=dim)[ - tuple(slice(None, d, None) for d in index_shape) - ] - yield opinfo_core.SampleInput(self_tensor, args=(dim, index_tensor, value)) - - # Additional test cases for scalar and single-element tensor combinations with dim=0 - # Test case: scalar index with scalar value (dim_size=6, value_type=torch.long) - dim_size = 6 - data_1d = make_arg((dim_size,)) - valid_index = torch.randint(0, dim_size, (), device=device, dtype=torch.long) - random_value = torch.randint(0, 10, (), device=device, dtype=torch.long).item() - yield opinfo_core.SampleInput(data_1d, args=(0, valid_index, random_value)) - - # Test case: single-element tensor index with scalar value (dim_size=8, value_type=torch.float) - dim_size = 8 - data_1d = make_arg((dim_size,)) - valid_index_1d = torch.randint(0, dim_size, (1,), device=device, dtype=torch.long) - random_value = torch.rand((), device=device, dtype=torch.float).item() - yield opinfo_core.SampleInput(data_1d, args=(0, valid_index_1d, random_value)) - - def sample_inputs__scaled_dot_product_flash_attention( op_info, device, dtype, requires_grad, **kwargs ): @@ -2314,13 +2151,6 @@ def __init__(self): # To avoid name duplication, it is possible to rename the OpInfo and specify # the `op` field explicitly. 
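The entries below share one pattern: a sample_inputs_* generator built on torch_testing.make_tensor yields SampleInput objects, and an OpInfo names the target overload explicitly through op= so the test ID can differ from the ATen name. A hedged sketch of that pattern, assuming this module's existing functools / torch_testing / opinfo_core / common_dtype imports; the op, name, and shapes here are placeholders rather than real test data.

def sample_inputs_example(op_info, device, dtype, requires_grad, **kwargs):
    del op_info, kwargs  # Unused; kept to match the expected signature
    make_arg = functools.partial(
        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    # One SampleInput per case; extra operands go in args/kwargs.
    for shape in ((3,), (2, 3)):
        yield opinfo_core.SampleInput(make_arg(shape), args=(0,))


# A matching (placeholder) entry would then be added to OP_DB:
#     opinfo_core.OpInfo(
#         "ops.aten.example",             # unique test ID, referenced from ops_test_data.py
#         op=torch.ops.aten.squeeze.dim,  # explicit overload; placeholder choice
#         dtypes=common_dtype.floating_types(),
#         sample_inputs_func=sample_inputs_example,
#         supports_out=False,
#     ),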
OP_DB: List[opinfo_core.OpInfo] = [ - opinfo_core.OpInfo( - "bilinear", - op=torch.nn.functional.bilinear, - dtypes=common_dtype.floating_types(), - sample_inputs_func=sample_inputs_bilinear, - supports_out=False, - ), opinfo_core.OpInfo( "ops.aten.bernoulli.p", aten_name="bernoulli.p", @@ -2380,6 +2210,44 @@ def __init__(self): sample_inputs_func=sample_inputs_embedding_bag_padding_idx, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.embedding_bag.padding_idx_none", + op=torch.nn.functional.embedding_bag, + dtypes=common_dtype.floating_types_and_half(), + sample_inputs_func=lambda op_info, device, dtype, requires_grad: [ + opinfo_core.SampleInput( + torch.tensor( + [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]], + dtype=dtype, + device=device, + ), + args=( + torch.tensor([0, 1, 2, 3], dtype=torch.int64, device=device), + torch.tensor([0, 2], dtype=torch.int64, device=device), + ), + kwargs={"padding_idx": None}, + ) + ], + ), + opinfo_core.OpInfo( + "ops.aten.embedding_bag.padding_idx_int", + op=torch.nn.functional.embedding_bag, + dtypes=common_dtype.floating_types_and_half(), + sample_inputs_func=lambda op_info, device, dtype, requires_grad: [ + opinfo_core.SampleInput( + torch.tensor( + [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]], + dtype=dtype, + device=device, + ), + args=( + torch.tensor([0, 1, 2], dtype=torch.int64, device=device), + torch.tensor([0, 2], dtype=torch.int64, device=device), + ), + kwargs={"padding_idx": 0}, + ) + ], + ), opinfo_core.OpInfo( "ops.aten.embedding_renorm", aten_name="embedding_renorm", @@ -2411,9 +2279,18 @@ def __init__(self): opinfo_core.BinaryUfuncInfo( "ops.aten.floor_divide", aten_name="floor_divide", - dtypes=common_dtype.all_types_and_half(), + dtypes=common_dtype.floating_types_and_half(), rhs_make_tensor_kwargs=dict(exclude_zero=True), ), + opinfo_core.BinaryUfuncInfo( + "ops.aten.floor_divide.int", + aten_name="floor_divide", + op=torch.ops.aten.floor_divide, + dtypes=common_dtype.integral_types(), + # Create only positive inputs + lhs_make_tensor_kwargs=dict(low=0), + rhs_make_tensor_kwargs=dict(exclude_zero=True, low=0), + ), opinfo_core.OpInfo( "ops.aten.hamming_window", aten_name="hamming_window", @@ -2674,22 +2551,6 @@ def __init__(self): sample_inputs_func=sample_inputs_slice_scatter, supports_out=False, ), - opinfo_core.OpInfo( - "ops.aten.scatter.src", - op=torch.ops.aten.scatter.src, - aten_name="scatter.src", - dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool), - sample_inputs_func=sample_inputs_scatter_src, - supports_out=False, - ), - opinfo_core.OpInfo( - "ops.aten.scatter.value", - op=torch.ops.aten.scatter.value, - aten_name="scatter.value", - dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool), - sample_inputs_func=sample_inputs_scatter_value, - supports_out=False, - ), opinfo_core.OpInfo( "ops.aten._softmax", op=torch.ops.aten._softmax, # pylint: disable=protected-access @@ -2864,13 +2725,6 @@ def __init__(self): sample_inputs_func=sample_inputs_upsample_trilinear3d_vec, supports_out=False, ), - opinfo_core.ReductionOpInfo( - "ops.prims.broadcast_in_dim.default", - op=torch.ops.prims.broadcast_in_dim.default, - dtypes=common_dtype.all_types(), - sample_inputs_func=sample_inputs_broadcast_in_dim, - supports_out=False, - ), opinfo_core.ReductionOpInfo( "ops.prims.var.default", nan_policy="propagate", diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index b60fd8cf3..a9ed4b843 100644 --- 
a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -48,6 +48,7 @@ from torch.testing._internal.opinfo import definitions as opinfo_definitions from typing_extensions import Self +from onnxscript._internal import version_utils from onnxscript.function_libs.torch_lib import _flags from onnxscript.function_libs.torch_lib.ops import core as core_ops from onnxscript.function_libs.torch_lib.ops import fft as fft_ops @@ -184,6 +185,25 @@ def xfail( # Modify this section ########################################################## +def _embedding_bag_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + # ONNX attributes cannot be None; omit padding_idx if it's None. + if "padding_idx" in kwargs: + padding_idx = kwargs.pop("padding_idx") + if padding_idx is not None: + kwargs["padding_idx"] = int(padding_idx) + + # Ensure indices/offsets are int64 (positional: weight, indices, offsets, ...) + if len(args) >= 3: + if isinstance(args[1], torch.Tensor): + args[1] = args[1].to(torch.long) + if isinstance(args[2], torch.Tensor): + args[2] = args[2].to(torch.long) + + return args, kwargs + + def _amin_amax_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: @@ -458,13 +478,40 @@ def _where_input_wrangler( fft_ops.aten__fft_r2c, tolerance={torch.float64: (2e-6, 2e-6), torch.float32: (3e-2, 3e-4)}, ), - TorchLibOpInfo("ops.aten._local_scalar_dense", core_ops.aten__local_scalar_dense), TorchLibOpInfo( - "ops.aten._log_softmax", - core_ops.aten__log_softmax, + "ops.aten._local_scalar_dense", + core_ops.aten__local_scalar_dense, + ), + TorchLibOpInfo("ops.aten._log_softmax", core_ops.aten__log_softmax), + TorchLibOpInfo( + "ops.aten._log_softmax_half", + core_ops.aten__log_softmax_half, tolerance={torch.float16: (1e-3, 1e-3)}, + ) + .xfail( + reason="PyTorch does not implement _log_softmax for float16 on CPU", + dtypes=(torch.float16,), + enabled_if=version_utils.torch_older_than("2.2"), + ) + .xfail( + enabled_if=version_utils.onnxruntime_older_than("1.17"), + dtypes=(torch.float16,), + reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", + test_class_name="TestOutputConsistencyFullGraph", ), TorchLibOpInfo("ops.aten._softmax", core_ops.aten__softmax), + TorchLibOpInfo("ops.aten._softmax_half", core_ops.aten__softmax_half) + .xfail( + reason="PyTorch does not implement _softmax for float16 on CPU", + dtypes=(torch.float16,), + enabled_if=version_utils.torch_older_than("2.2"), + ) + .xfail( + enabled_if=version_utils.onnxruntime_older_than("1.17"), + dtypes=(torch.float16,), + reason="fixme: ORT failed. 
https://github.com/microsoft/onnxruntime/issues/16438", + test_class_name="TestOutputConsistencyFullGraph", + ), TorchLibOpInfo("all_dim", core_ops.aten_all_dim).skip( matcher=lambda sample: not (len(sample.kwargs) > 0) or isinstance(sample.kwargs.get("dim"), tuple), @@ -475,7 +522,10 @@ def _where_input_wrangler( reason="this overload requires dim to be a tuple", ), TorchLibOpInfo("allclose", core_ops.aten_allclose), - TorchLibOpInfo("all", core_ops.aten_all).skip( + TorchLibOpInfo( + "all", + core_ops.aten_all, + ).skip( matcher=lambda sample: len(sample.kwargs) != 0, reason="this Aten overload only support one tensor as input by design", ), @@ -510,14 +560,32 @@ def _where_input_wrangler( reason="zero sized inputs cannot be compared", ), TorchLibOpInfo("addmv", core_ops.aten_addmv, tolerance={torch.float16: (2e-3, 2e-2)}), - TorchLibOpInfo("addr", core_ops.aten_addr, tolerance={torch.float16: (3e-3, 4e-3)}), - TorchLibOpInfo("amax", core_ops.aten_amax, input_wrangler=_amin_amax_input_wrangler), - TorchLibOpInfo("amin", core_ops.aten_amin, input_wrangler=_amin_amax_input_wrangler), - TorchLibOpInfo("any", core_ops.aten_any).skip( + TorchLibOpInfo( + "addr", + core_ops.aten_addr, + tolerance={torch.float16: (3e-3, 4e-3)}, + ), + TorchLibOpInfo( + "amax", + core_ops.aten_amax, + input_wrangler=_amin_amax_input_wrangler, + ), + TorchLibOpInfo( + "amin", + core_ops.aten_amin, + input_wrangler=_amin_amax_input_wrangler, + ), + TorchLibOpInfo( + "any", + core_ops.aten_any, + ).skip( matcher=lambda sample: len(sample.kwargs) != 0, reason="this Aten overload only support one tensor as input by design", ), - TorchLibOpInfo("any_dim", core_ops.aten_any_dim).skip( + TorchLibOpInfo( + "any_dim", + core_ops.aten_any_dim, + ).skip( matcher=lambda sample: not (len(sample.kwargs) > 0) or isinstance(sample.kwargs.get("dim"), tuple), reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer", @@ -529,58 +597,85 @@ def _where_input_wrangler( TorchLibOpInfo("asin", core_ops.aten_asin), TorchLibOpInfo("asinh", core_ops.aten_asinh), TorchLibOpInfo("atan", core_ops.aten_atan), - TorchLibOpInfo("atan2", core_ops.aten_atan2), + TorchLibOpInfo("atan2", core_ops.aten_atan2, tolerance={torch.float16: (1e-3, 1e-3)}), TorchLibOpInfo("atanh", core_ops.aten_atanh), TorchLibOpInfo("atleast_1d", core_ops.aten_atleast_1d).skip( matcher=lambda sample: isinstance(sample.input, (list, tuple)), reason="takes single tensor as input", ), - TorchLibOpInfo("atleast_1d_Sequence", core_ops.aten_atleast_1d_sequence) + TorchLibOpInfo( + "atleast_1d_Sequence", + core_ops.aten_atleast_1d_sequence, + ) .skip( matcher=lambda sample: not isinstance(sample.input, (list, tuple)), reason="takes tensor sequences only", ) + .xfail( + enabled_if=version_utils.onnxruntime_older_than("1.16"), + reason=( + "fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x9370ed0_rank)." + "https://github.com/microsoft/onnxscript/issues/960" + ), + ) .xfail( reason=( "fixme: ORT shape inference failed." 
"https://github.com/microsoft/onnxscript/issues/1007" - ) + ), ), TorchLibOpInfo("atleast_2d", core_ops.aten_atleast_2d).skip( matcher=lambda sample: isinstance(sample.input, (list, tuple)), reason="takes single tensor as input", ), - TorchLibOpInfo("atleast_2d_Sequence", core_ops.aten_atleast_2d_sequence) + TorchLibOpInfo( + "atleast_2d_Sequence", + core_ops.aten_atleast_2d_sequence, + ) .skip( matcher=lambda sample: not isinstance(sample.input, (list, tuple)), reason="takes tensor sequences only", ) + .xfail( + enabled_if=version_utils.onnxruntime_older_than("1.16"), + reason=( + "fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x9370ed0_rank)." + "https://github.com/microsoft/onnxscript/issues/960" + ), + ) .xfail( reason=( "fixme: ORT shape inference failed." "https://github.com/microsoft/onnxscript/issues/1007" - ) + ), ), TorchLibOpInfo("atleast_3d", core_ops.aten_atleast_3d).skip( matcher=lambda sample: isinstance(sample.input, (list, tuple)), reason="takes single tensor as input", ), - TorchLibOpInfo("atleast_3d_Sequence", core_ops.aten_atleast_3d_sequence) + TorchLibOpInfo( + "atleast_3d_Sequence", + core_ops.aten_atleast_3d_sequence, + ) .skip( matcher=lambda sample: not isinstance(sample.input, (list, tuple)), reason="takes tensor sequences only", ) + .xfail( + enabled_if=version_utils.onnxruntime_older_than("1.16"), + reason=( + "fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x9370ed0_rank)." + "https://github.com/microsoft/onnxscript/issues/960" + ), + ) .xfail( reason=( "fixme: ORT shape inference failed." "https://github.com/microsoft/onnxscript/issues/1007" - ) + ), ), TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm, tolerance={torch.float16: (1e-3, 1e-2)}), TorchLibOpInfo("bernoulli", core_ops.aten_bernoulli, nondeterministic=True), - TorchLibOpInfo( - "bilinear", core_ops.aten_bilinear, tolerance={torch.float32: (2e-5, 2e-5)} - ), TorchLibOpInfo( # This string is a unique ID. In extra_opinfo.py, we # also define test data for this ID with @@ -592,10 +687,16 @@ def _where_input_wrangler( ), TorchLibOpInfo("ops.aten.bernoulli.p_deterministic", core_ops.aten_bernoulli_p), TorchLibOpInfo("bitwise_and", core_ops.aten_bitwise_and), - TorchLibOpInfo("bitwise_left_shift", core_ops.aten_bitwise_left_shift), + TorchLibOpInfo("bitwise_left_shift_int16", core_ops.aten_bitwise_left_shift_int16), + TorchLibOpInfo("bitwise_left_shift_int32", core_ops.aten_bitwise_left_shift_int32), + TorchLibOpInfo("bitwise_left_shift_int64", core_ops.aten_bitwise_left_shift_int64), + TorchLibOpInfo("bitwise_left_shift_int8", core_ops.aten_bitwise_left_shift_int8), TorchLibOpInfo("bitwise_not", core_ops.aten_bitwise_not), TorchLibOpInfo("bitwise_or", core_ops.aten_bitwise_or), - TorchLibOpInfo("bitwise_right_shift", core_ops.aten_bitwise_right_shift), + TorchLibOpInfo("bitwise_right_shift_int16", core_ops.aten_bitwise_right_shift_int16), + TorchLibOpInfo("bitwise_right_shift_int32", core_ops.aten_bitwise_right_shift_int32), + TorchLibOpInfo("bitwise_right_shift_int64", core_ops.aten_bitwise_right_shift_int64), + TorchLibOpInfo("bitwise_right_shift_int8", core_ops.aten_bitwise_right_shift_int8), TorchLibOpInfo("bitwise_xor", core_ops.aten_bitwise_xor), TorchLibOpInfo("ops.aten.blackman_window", core_ops.aten_blackman_window), TorchLibOpInfo("bmm", core_ops.aten_bmm), @@ -613,7 +714,10 @@ def _where_input_wrangler( reason="fixme: ORT aborts with zero-dim tensors. 
https://github.com/microsoft/onnxruntime/issues/16619", ), TorchLibOpInfo("ceil", core_ops.aten_ceil), - TorchLibOpInfo("chunk", core_ops.aten_chunk), + TorchLibOpInfo("chunk", core_ops.aten_chunk).skip( + enabled_if=version_utils.torch_older_than("2.7"), + reason="Test for chunk is not configured for torch<2.7", + ), TorchLibOpInfo("clamp_max", core_ops.aten_clamp_max_tensor).skip( reason="Size 0 inputs are not handled by design", matcher=lambda sample: sample.input.numel() == 0, @@ -649,6 +753,7 @@ def _where_input_wrangler( TorchLibOpInfo("deg2rad", core_ops.aten_deg2rad), # TorchLibOpInfo("detach", core_ops.aten_detach), # detach is not in OP-TEST-DB TorchLibOpInfo("diagonal", core_ops.aten_diagonal), + TorchLibOpInfo("diagonal_bool", core_ops.aten_diagonal_bool), TorchLibOpInfo("div", core_ops.aten_div).skip( matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None, reason="this variation does not take the rounding_mode argument", @@ -666,6 +771,7 @@ def _where_input_wrangler( # Numbers match sometimes but not other times reason="fixme: off-by-one. https://github.com/microsoft/onnxscript/issues/990", ), + TorchLibOpInfo("div_mode_int", core_ops.aten_div_mode_int), TorchLibOpInfo("dot", core_ops.aten_dot), TorchLibOpInfo( "empty", @@ -675,7 +781,8 @@ def _where_input_wrangler( ), TorchLibOpInfo("einsum", core_ops.aten_einsum, input_wrangler=_einsum_input_wrangler) .xfail( - reason="fixme: PyTorch produces int64 output with int32 input", dtypes=(torch.int32,) + reason="fixme: PyTorch produces int64 output with int32 input", + dtypes=(torch.int32,), ) .xfail( reason="fixme: ONNX shape inference fails: https://github.com/onnx/onnx/issues/5739", @@ -706,18 +813,25 @@ def _where_input_wrangler( TorchLibOpInfo("flatten", core_ops.aten_flatten), TorchLibOpInfo("floor", core_ops.aten_floor), TorchLibOpInfo("ops.aten.floor_divide", core_ops.aten_floor_divide), + TorchLibOpInfo("ops.aten.floor_divide.int", core_ops.aten_floor_divide_int), TorchLibOpInfo("fmod", core_ops.aten_fmod), TorchLibOpInfo("frac", core_ops.aten_frac), TorchLibOpInfo("full", core_ops.aten_full), - TorchLibOpInfo("full_like", core_ops.aten_full_like).skip( - enabled_if=ops_test_common.IS_MACOS, reason="fixme: memory allocation issue on CI" + TorchLibOpInfo( + "full_like", + core_ops.aten_full_like, + ).skip( + enabled_if=ops_test_common.IS_MACOS, + reason="fixme: memory allocation issue on CI", ), TorchLibOpInfo("gather", core_ops.aten_gather).skip( matcher=lambda sample: sample.input.numel() == 0 or sample.args[1].numel() == 0, reason="fixme: ORT does not support empty tensors as input", ), TorchLibOpInfo("ge", core_ops.aten_ge), + TorchLibOpInfo("ge_bool", core_ops.aten_ge_bool), TorchLibOpInfo("gt", core_ops.aten_gt), + TorchLibOpInfo("gt_bool", core_ops.aten_gt_bool), # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB # TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index), @@ -731,7 +845,9 @@ def _where_input_wrangler( reason="this Aten overload only supports tensor(bool) as indices", ), TorchLibOpInfo( - "index_put", core_ops.aten_index_put, input_wrangler=_index_put_input_wrangler + "index_put", + core_ops.aten_index_put, + input_wrangler=_index_put_input_wrangler, ) .skip( matcher=lambda sample: sample.args[0][0].dtype != torch.int64, @@ -771,13 +887,20 @@ def _where_input_wrangler( dtypes=(torch.int64, torch.int32), reason="fixme: Results do not match with PyTorch. 
https://github.com/microsoft/onnxscript/issues/854", ) - .skip( - matcher=lambda sample: sample.kwargs.get("dtype") in (torch.int64, torch.int32), + .xfail( + variant_name="tensor_overload", + dtypes=(torch.int64, torch.int32), reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", + enabled_if=not version_utils.torch_older_than("2.2"), ), TorchLibOpInfo("log", core_ops.aten_log), TorchLibOpInfo("le", core_ops.aten_le), - TorchLibOpInfo("lerp", core_ops.aten_lerp, tolerance={torch.float16: (2e-3, 2e-1)}), + TorchLibOpInfo("le_bool", core_ops.aten_le_bool), + TorchLibOpInfo( + "lerp", + core_ops.aten_lerp, + tolerance={torch.float16: (2e-3, 2e-1)}, + ), TorchLibOpInfo("log10", core_ops.aten_log10), TorchLibOpInfo("log1p", core_ops.aten_log1p), TorchLibOpInfo( @@ -816,6 +939,7 @@ def _where_input_wrangler( TorchLibOpInfo("logdet", core_ops.aten_logdet), TorchLibOpInfo("logsumexp", core_ops.aten_logsumexp), TorchLibOpInfo("lt", core_ops.aten_lt), + TorchLibOpInfo("lt_bool", core_ops.aten_lt_bool), TorchLibOpInfo("masked_fill", core_ops.aten_masked_fill).xfail( dtypes=(torch.bool,), reason="fixme: ORT does not have an implementation for Where with bool inputs.", @@ -831,12 +955,19 @@ def _where_input_wrangler( reason="values of matmul of [m, 0] and [0, n] matrices are undefined", ), TorchLibOpInfo("maximum", core_ops.aten_maximum), - TorchLibOpInfo("mean", core_ops.aten_mean, input_wrangler=_mean_input_wrangler).skip( + TorchLibOpInfo("maximum_bool", core_ops.aten_maximum_bool), + TorchLibOpInfo( + "mean", + core_ops.aten_mean, + input_wrangler=_mean_input_wrangler, + ).skip( matcher=lambda sample: sample.kwargs.get("dim") is not None, reason="this Aten overload only accept 1 inputs: self", ), TorchLibOpInfo( - "mean_dim", core_ops.aten_mean_dim, input_wrangler=_mean_input_wrangler + "mean_dim", + core_ops.aten_mean_dim, + input_wrangler=_mean_input_wrangler, ).skip( matcher=lambda sample: sample.kwargs.get("dim") is None, reason="this Aten overload can accept 2 inputs:(self, dim)", @@ -848,11 +979,15 @@ def _where_input_wrangler( or (len(sample.args) > 0 and not isinstance(sample.args[0], int)), reason="this ATen overload only support one tensor as input and another int as args", ), - TorchLibOpInfo("min", core_ops.aten_min).skip( + TorchLibOpInfo( + "min", + core_ops.aten_min, + ).skip( matcher=lambda sample: len(sample.args) > 0, reason="this ATen overload only supports one tensor as input by design", ), TorchLibOpInfo("minimum", core_ops.aten_minimum), + TorchLibOpInfo("minimum_bool", core_ops.aten_minimum_bool), TorchLibOpInfo("mm", core_ops.aten_mm).skip( matcher=lambda sample: torch.numel(sample.input) == 0, reason="values of matmul of [m, 0] and [0, n] matrices are undefined", @@ -861,19 +996,39 @@ def _where_input_wrangler( TorchLibOpInfo("mT", core_ops.aten_mT_complex, complex=True), TorchLibOpInfo("mul", core_ops.aten_mul), TorchLibOpInfo("mul", core_ops.aten_mul_complex, complex=True), - TorchLibOpInfo("mv", core_ops.aten_mv, tolerance={torch.float16: (3e-2, 1e-2)}), + TorchLibOpInfo( + "mv", + core_ops.aten_mv, + tolerance={torch.float16: (3e-2, 1e-2)}, + ), TorchLibOpInfo("narrow", core_ops.aten_narrow), TorchLibOpInfo("ops.aten.native_dropout", core_ops.aten_native_dropout), TorchLibOpInfo("ne", core_ops.aten_ne), TorchLibOpInfo("neg", core_ops.aten_neg), - TorchLibOpInfo("new_empty", core_ops.aten_new_empty, nondeterministic=True), TorchLibOpInfo( - "new_empty_strided", core_ops.aten_new_empty_strided, nondeterministic=True + 
"new_empty", + core_ops.aten_new_empty, + nondeterministic=True, + ), + TorchLibOpInfo( + "new_empty_strided", + core_ops.aten_new_empty_strided, + nondeterministic=True, + ), + TorchLibOpInfo( + "new_full", + core_ops.aten_new_full, + ), + TorchLibOpInfo( + "new_ones", + core_ops.aten_new_ones, + ), + TorchLibOpInfo( + "new_zeros", + core_ops.aten_new_zeros, ), - TorchLibOpInfo("new_full", core_ops.aten_new_full), - TorchLibOpInfo("new_ones", core_ops.aten_new_ones), - TorchLibOpInfo("new_zeros", core_ops.aten_new_zeros), TorchLibOpInfo("nn.functional.celu", nn_ops.aten_celu), + TorchLibOpInfo("nn.functional.celu_type_promoted", nn_ops.aten_celu_type_promoted), TorchLibOpInfo( "nn.functional.cross_entropy", # use cross_entropy as test case instead of cross_entropy_loss (not in OPS_DB) @@ -886,7 +1041,9 @@ def _where_input_wrangler( reason="ONNX SoftmaxCrossEntropyLoss op only accept argument[target] as int type", ), TorchLibOpInfo( - "nn.functional.dropout", core_ops.aten_dropout, input_wrangler=_dropout_input_wrangler + "nn.functional.dropout", + core_ops.aten_dropout, + input_wrangler=_dropout_input_wrangler, ).skip( matcher=lambda sample: len(sample.kwargs) == 0 or sample.kwargs.get("p", 0.0) > 0.0, reason="dropout is random so the result not match", @@ -897,12 +1054,27 @@ def _where_input_wrangler( core_ops.aten_embedding_bag, tolerance={torch.float32: (1e-4, 5e-4)}, compare_shape_only_for_output=(1, 2, 3), - ).skip(dtypes=(torch.float16,), reason="fixme: results mismatch in torch nightly."), + input_wrangler=_embedding_bag_input_wrangler, + ).skip( + dtypes=(torch.float16,), + reason="fixme: results mismatch in torch nightly.", + ), + TorchLibOpInfo( + "ops.aten.embedding_bag.padding_idx_none", + core_ops.aten_embedding_bag, + input_wrangler=_embedding_bag_input_wrangler, + ), + TorchLibOpInfo( + "ops.aten.embedding_bag.padding_idx_int", + core_ops.aten_embedding_bag_padding_idx, + input_wrangler=_embedding_bag_input_wrangler, + ), TorchLibOpInfo( "ops.aten.embedding_bag.padding_idx", core_ops.aten_embedding_bag_padding_idx, tolerance={torch.float16: (1e-2, 1e-2)}, compare_shape_only_for_output=(1, 2, 3), + input_wrangler=_embedding_bag_input_wrangler, ), TorchLibOpInfo( "ops.aten.embedding_renorm", @@ -932,7 +1104,10 @@ def _where_input_wrangler( tolerance={torch.float16: (5e-2, 1e-2)}, ), TorchLibOpInfo("nn.functional.pad", nn_ops.aten_pad) - .skip(variant_name="circular", reason="fixme: ORT does not support the circular mode") + .skip( + variant_name="circular", + reason="fixme: ORT does not support the circular mode", + ) .skip( variant_name="replicate_negative", reason="fixme: The implementation for negative paddings is not correct", @@ -940,21 +1115,34 @@ def _where_input_wrangler( TorchLibOpInfo( "nn.functional.pixel_shuffle", core_ops.aten_pixel_shuffle, - ).xfail( + ) + .xfail( dtypes=(torch.int32, torch.int64), reason="fixme: ONNX Runtime does not support int32/64 inputs", + ) + .xfail( + matcher=lambda sample: sample.input.numel() == 0, + reason="fixme: ORT does not support empty tensor as input", ), TorchLibOpInfo( "nn.functional.pixel_unshuffle", core_ops.aten_pixel_unshuffle, - ).xfail( + ) + .xfail( dtypes=(torch.int32, torch.int64), reason="fixme: ONNX Runtime does not support int32/64 inputs", + ) + .xfail( + matcher=lambda sample: sample.input.numel() == 0, + reason="fixme: ORT does not support empty tensor as input", ), TorchLibOpInfo( "ops.aten.reflection_pad1d", nn_ops.aten_reflection_pad1d, - ).xfail(dtypes=(torch.int64,), reason="Torch not implement 
reflection_pad1d for int64."), + ).xfail( + dtypes=(torch.int64,), + reason="Torch not implement reflection_pad1d for int64.", + ), TorchLibOpInfo( "nn.functional.reflection_pad2d", nn_ops.aten_reflection_pad2d, @@ -963,9 +1151,26 @@ def _where_input_wrangler( matcher=lambda sample: not (len(sample.args) > 1 and sample.args[1] == "reflect"), reason="this Aten overload need args[1] == 'reflect' for pad mode", ), - TorchLibOpInfo("nn.functional.relu", nn_ops.aten_relu), - TorchLibOpInfo("nn.functional.relu6", nn_ops.aten_relu6), - TorchLibOpInfo("ops.aten.replication_pad1d", nn_ops.aten_replication_pad1d), + TorchLibOpInfo( + "nn.functional.relu", + nn_ops.aten_relu, + ).xfail( + dtypes=(torch.int64,), + enabled_if=version_utils.onnxruntime_older_than("1.17"), + reason="fixme: ORT did not implement Relu for int64. https://github.com/microsoft/onnxruntime/issues/16654", + ), + TorchLibOpInfo( + "nn.functional.relu6", + nn_ops.aten_relu6, + ).xfail( + dtypes=(torch.int64,), + enabled_if=version_utils.onnxruntime_older_than("1.17"), + reason="fixme: ORT did not implement Relu for int64. https://github.com/microsoft/onnxruntime/issues/16654", + ), + TorchLibOpInfo( + "ops.aten.replication_pad1d", + nn_ops.aten_replication_pad1d, + ), TorchLibOpInfo( "nn.functional.replication_pad2d", nn_ops.aten_replication_pad2d, @@ -975,9 +1180,10 @@ def _where_input_wrangler( matcher=lambda sample: not (len(sample.args) > 1 and sample.args[1] == "replicate"), reason="this Aten overload need args[1] == 'replicate' for pad mode", ) - .skip( + .xfail( variant_name="replicate_negative", - reason="fixme: The implementation for negative paddings is not correct. Potentially an ORT issue", + enabled_if=not version_utils.torch_older_than("2.2"), + reason="fixme: negative padding is not implemented yet", ), TorchLibOpInfo( "nn.functional.replication_pad3d", @@ -993,9 +1199,15 @@ def _where_input_wrangler( ), TorchLibOpInfo("nn.functional.selu", core_ops.aten_selu), TorchLibOpInfo( - "nn.functional.mse_loss", nn_ops.aten_mse_loss, input_wrangler=_mse_loss_input_wrangler + "nn.functional.mse_loss", + nn_ops.aten_mse_loss, + input_wrangler=_mse_loss_input_wrangler, ), - TorchLibOpInfo("nonzero", core_ops.aten_nonzero, input_wrangler=_nonzero_input_wrangler) + TorchLibOpInfo( + "nonzero", + core_ops.aten_nonzero, + input_wrangler=_nonzero_input_wrangler, + ) .xfail( matcher=lambda sample: sample.kwargs.get("as_tuple"), reason="as_tuple=True is not supported", @@ -1058,41 +1270,17 @@ def _where_input_wrangler( nondeterministic=True, ), TorchLibOpInfo("ops.aten.randn", core_ops.aten_randn, nondeterministic=True).xfail( - dtypes=(torch.float16,), reason="fixme: Shape inference error" + dtypes=(torch.float16,), + reason="fixme: Shape inference error", ), TorchLibOpInfo("ops.aten.randn_like", core_ops.aten_randn_like, nondeterministic=True), TorchLibOpInfo("rad2deg", core_ops.aten_rad2deg), TorchLibOpInfo("reciprocal", core_ops.aten_reciprocal), - TorchLibOpInfo("remainder", core_ops.aten_remainder), - TorchLibOpInfo("repeat", core_ops.aten_repeat), - TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_self_int) - .skip( - matcher=lambda sample: not isinstance(sample.kwargs.get("repeats", None), int), - reason=("ignore cases when repeasts is a Tensor"), - ) - .skip(dtypes=(torch.bool,), reason="bool not supported") - .skip( - matcher=lambda sample: sample.kwargs.get("dim") is None, - reason="fixme: conversion not implemented if dim is None", - ) - .skip( - matcher=lambda sample: sample.input.numel() == 0, - 
reason="fixme: conversion not implemented when input tensor is empty", - ), - TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_Tensor) - .skip( - matcher=lambda sample: isinstance(sample.kwargs.get("repeats", None), int), - reason=("ignore cases when repeasts is an int"), - ) - .skip(dtypes=(torch.bool,), reason="bool not supported") - .skip( - matcher=lambda sample: sample.kwargs.get("dim") is None, - reason="fixme: conversion not implemented if dim is None", - ) - .skip( - matcher=lambda sample: sample.input.numel() == 0, - reason="fixme: conversion not implemented when input tensor is empty", + TorchLibOpInfo( + "remainder", + core_ops.aten_remainder, ), + TorchLibOpInfo("repeat", core_ops.aten_repeat), TorchLibOpInfo("reshape", core_ops.aten_reshape), TorchLibOpInfo("resolve_conj", core_ops.aten_resolve_conj), TorchLibOpInfo("resolve_neg", core_ops.aten_resolve_neg), @@ -1114,9 +1302,14 @@ def _where_input_wrangler( complex=True, ), TorchLibOpInfo( - "ops.aten.scalar_tensor", core_ops.aten_scalar_tensor_complex, complex=True + "ops.aten.scalar_tensor", + core_ops.aten_scalar_tensor_complex, + complex=True, ), - TorchLibOpInfo("scatter_add", core_ops.aten_scatter_add) + TorchLibOpInfo( + "scatter_add", + core_ops.aten_scatter_add, + ) .xfail( matcher=lambda sample: len(sample.input.shape) == 0, reason="fixme: Rank(0) input will lead ORT failed due to different rank(result) in if-else branch. https://github.com/onnx/onnx/issues/4986", @@ -1165,10 +1358,48 @@ def _where_input_wrangler( dtypes=(torch.float16,), reason="fixme: Tensor-likes are not close. Tests pass for float32.", ), - TorchLibOpInfo("split_with_sizes", core_ops.aten_split_with_sizes), - TorchLibOpInfo("split", core_ops.aten_split), + TorchLibOpInfo( + "split_with_sizes", + core_ops.aten_split_with_sizes, + ) + .xfail( + dtypes=(torch.float16,), + enabled_if=version_utils.onnxruntime_older_than("1.17"), + reason="fixme: ORT failed to produce the correct argument type: https://github.com/microsoft/onnxruntime/issues/16006", + ) + .xfail( + dtypes=(torch.bool,), + reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", + ), + TorchLibOpInfo( + "split", + core_ops.aten_split, + ) + .xfail( + dtypes=(torch.float16,), + enabled_if=version_utils.onnxruntime_older_than("1.17"), + reason="fixme: ORT failed to produce the correct argument type: https://github.com/microsoft/onnxruntime/issues/16006", + ) + .xfail( + variant_name="list_args", + dtypes=(torch.float16,), + enabled_if=version_utils.onnxruntime_older_than("1.17"), + reason="fixme: ORT failed to produce the correct argument type: https://github.com/microsoft/onnxruntime/issues/16006", + ) + .xfail( + dtypes=(torch.bool,), + reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", + ) + .xfail( + variant_name="list_args", + dtypes=(torch.bool,), + reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", + ), TorchLibOpInfo("sqrt", core_ops.aten_sqrt), - TorchLibOpInfo("squeeze_dim", core_ops.aten_squeeze_dim) + TorchLibOpInfo( + "squeeze_dim", + core_ops.aten_squeeze_dim, + ) .skip( matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)), reason="this Aten overload only support one tensor as input and one int as args by design", @@ -1178,7 +1409,11 @@ def _where_input_wrangler( and sample.input.shape[sample.args[0]] != 
1, reason="this Aten overload only support squeeze dim with size 1", ), - TorchLibOpInfo("squeeze_dim", core_ops.aten_squeeze_dim_complex, complex=True) + TorchLibOpInfo( + "squeeze_dim", + core_ops.aten_squeeze_dim_complex, + complex=True, + ) .skip( matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)), reason="this Aten overload only support one tensor as input and one int as args by design", @@ -1188,7 +1423,10 @@ def _where_input_wrangler( and sample.input.shape[sample.args[0]] != 1, reason="this Aten overload only support squeeze dim with size 1", ), - TorchLibOpInfo("squeeze", core_ops.aten_squeeze).skip( + TorchLibOpInfo( + "squeeze", + core_ops.aten_squeeze, + ).skip( matcher=lambda sample: len(sample.args) != 0, reason="this Aten overload only support one tensor as input by design", ), @@ -1197,14 +1435,20 @@ def _where_input_wrangler( TorchLibOpInfo("sub", core_ops.aten_sub, tolerance={torch.float16: (2e-3, 1e-3)}), TorchLibOpInfo("sub", core_ops.aten_sub_complex, complex=True), # TorchLibOpInfo("sym_size", core_ops.aten_sym_size), # no test case in OPS_DB - TorchLibOpInfo("t", core_ops.aten_t).xfail( + TorchLibOpInfo( + "t", + core_ops.aten_t, + ).xfail( enabled_if=not _flags.EXPERIMENTAL_PREFER_TRACING, reason="fixme: ORT Graph attribute inferencing failed on rank-1 input. https://github.com/onnx/onnx/issues/4986", test_class_name="TestOutputConsistencyFullGraph", ), TorchLibOpInfo("tan", core_ops.aten_tan), TorchLibOpInfo("tanh", core_ops.aten_tanh), - TorchLibOpInfo("tile", core_ops.aten_tile).skip( + TorchLibOpInfo( + "tile", + core_ops.aten_tile, + ).skip( matcher=lambda sample: any(dim == 0 for dim in sample.input.shape) or not sample.input.shape, reason="fixme: Logic not implemented for size 0 inputs in op.Reshape", @@ -1232,7 +1476,19 @@ def _where_input_wrangler( reason="fixme: ORT does not have an implementation of Trilu for int32.", ), TorchLibOpInfo("trunc", core_ops.aten_trunc), - TorchLibOpInfo("unbind", core_ops.aten_unbind), + TorchLibOpInfo( + "unbind", + core_ops.aten_unbind, + ) + .xfail( + dtypes=(torch.float16,), + enabled_if=version_utils.onnxruntime_older_than("1.17"), + reason="fixme: SplitToSequence op inference failed. 
https://github.com/microsoft/onnxruntime/issues/16006", + ) + .xfail( + dtypes=(torch.bool,), + reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", + ), TorchLibOpInfo("unflatten", core_ops.aten_unflatten), TorchLibOpInfo("unfold", core_ops.aten_unfold), TorchLibOpInfo("ops.aten.unfold", core_ops.aten_unfold), @@ -1251,7 +1507,10 @@ def _where_input_wrangler( ), TorchLibOpInfo("xlogy", special_ops.aten_special_xlogy), TorchLibOpInfo("zeros", core_ops.aten_zeros), - TorchLibOpInfo("arange_start_step", core_ops.aten_arange_start_step) + TorchLibOpInfo( + "arange_start_step", + core_ops.aten_arange_start_step, + ) .skip( matcher=lambda sample: len(sample.args) != 2, reason="arange_start_step overload takes three arguments (input, start, step)", @@ -1261,7 +1520,10 @@ def _where_input_wrangler( reason="dtype needs to be specified for non-float tensors", dtypes=(torch.float16, torch.int64, torch.int32), ), - TorchLibOpInfo("arange_start", core_ops.aten_arange_start) + TorchLibOpInfo( + "arange_start", + core_ops.aten_arange_start, + ) .skip( matcher=lambda sample: len(sample.args) != 1, reason="arange_start overload takes two arguments (input, start)", @@ -1271,7 +1533,10 @@ def _where_input_wrangler( reason="dtype needs to be specified for non-float tensors", dtypes=(torch.float16, torch.int64, torch.int32), ), - TorchLibOpInfo("arange", core_ops.aten_arange) + TorchLibOpInfo( + "arange", + core_ops.aten_arange, + ) .xfail( dtypes=(torch.int32,), reason="fixme: output shape mismatch in edge cases. https://github.com/microsoft/onnxscript/issues/974", @@ -1294,7 +1559,10 @@ def _where_input_wrangler( TorchLibOpInfo( "as_strided", core_ops.aten_as_strided, - ).xfail(variant_name="partial_views", reason="ONNX doesn't have partial view for tensor"), + ).xfail( + variant_name="partial_views", + reason="ONNX doesn't have partial view for tensor", + ), TorchLibOpInfo("clamp", core_ops.aten_clamp_tensor), TorchLibOpInfo( "ops.aten.col2im", @@ -1314,13 +1582,19 @@ def _where_input_wrangler( tolerance={torch.float32: (2e-4, 9e-4)}, ), TorchLibOpInfo("empty_like", core_ops.aten_empty_like, nondeterministic=True), - TorchLibOpInfo("grid_sampler_2d", core_ops.aten_grid_sampler_2d) + TorchLibOpInfo( + "grid_sampler_2d", + core_ops.aten_grid_sampler_2d, + ) .skip( # Torch implemented this using the cubic convolution algorithm with alhpa=-0.75, might be different than ORT matcher=lambda sample: sample.args[1] == 2, reason="fixme: 'bicubic' mode in ORT implemented differently with Torch", ) - .skip(dtypes=(torch.float16,), reason="fixme: Accuracy is not high enough"), + .skip( + dtypes=(torch.float16,), + reason="fixme: Accuracy is not high enough", + ), TorchLibOpInfo( "nn.functional.group_norm", nn_ops.aten_group_norm, @@ -1364,10 +1638,6 @@ def _where_input_wrangler( dtypes=(torch.float32 if sys.platform != "linux" else torch.complex64,), reason="fixme: test is unstable on macosx, windows", ), - TorchLibOpInfo("logical_and", core_ops.aten_logical_and), - TorchLibOpInfo("logical_not", core_ops.aten_logical_not), - TorchLibOpInfo("logical_or", core_ops.aten_logical_or), - TorchLibOpInfo("logical_xor", core_ops.aten_logical_xor), TorchLibOpInfo("logit", core_ops.aten_logit, tolerance={torch.float16: (1e-1, 7e-4)}), TorchLibOpInfo("max_dim", core_ops.aten_max_dim) .xfail( @@ -1381,7 +1651,10 @@ def _where_input_wrangler( or (len(sample.args) > 0 and not isinstance(sample.args[0], int)), reason="this ATen overload only support one tensor 
as input and another int as args", ), - TorchLibOpInfo("max", core_ops.aten_max).skip( + TorchLibOpInfo( + "max", + core_ops.aten_max, + ).skip( matcher=lambda sample: len(sample.args) > 0, reason="this ATen overload only supports one tensor as input by design", ), @@ -1439,7 +1712,8 @@ def _where_input_wrangler( reason="fixme: ORT only supports BatchNorm less than opset14", ), TorchLibOpInfo( - "ops.aten._native_batch_norm_legit.no_stats", core_ops.aten__native_batch_norm_no_stats + "ops.aten._native_batch_norm_legit.no_stats", + core_ops.aten__native_batch_norm_no_stats, ), TorchLibOpInfo( "ops.aten._native_batch_norm_legit_functional", @@ -1460,6 +1734,10 @@ def _where_input_wrangler( "ops.aten.native_group_norm", core_ops.aten_native_group_norm, tolerance={torch.float16: (1e-2, 7e-3)}, + ).xfail( + dtypes=(torch.float16,), + reason="fixme: 'GroupNormKernelImpl' not implemented for 'Half' in nightly and weekly", + enabled_if=version_utils.torch_older_than("2.2"), ), TorchLibOpInfo( "native_layer_norm", @@ -1541,7 +1819,9 @@ def _where_input_wrangler( tolerance={torch.float16: (1e-2, 1e-3)}, ), TorchLibOpInfo( - "ops.aten.conv3d", core_ops.aten_conv3d, tolerance={torch.float32: (3.7e-5, 1.8e-4)} + "ops.aten.conv3d", + core_ops.aten_conv3d, + tolerance={torch.float32: (3.7e-5, 1.8e-4)}, ), TorchLibOpInfo("nn.functional.gelu", nn_ops.aten_gelu), TorchLibOpInfo("nn.functional.glu", nn_ops.aten_glu), @@ -1622,6 +1902,11 @@ def _where_input_wrangler( nn_ops.aten_scaled_dot_product_attention, tolerance={torch.float32: (3e-4, 1.5e-5)}, ) + .skip( + matcher=lambda sample: (attn_mask := sample.kwargs.get("attn_mask")) is not None + and attn_mask.dtype == torch.bool, + reason="this overload takes a non-boolean mask", + ) .skip( matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0, reason="dropout is random so the results do not match", @@ -1630,12 +1915,6 @@ def _where_input_wrangler( dtypes=(torch.float16,), reason="fixme: ORT failed. 
https://github.com/microsoft/onnxruntime/issues/16438", test_class_name="TestOutputConsistencyFullGraph", - ) - .xfail( - matcher=lambda sample: len(sample.input.shape) != 4 - or len(sample.args[0].shape) != 4 - or len(sample.args[1].shape) != 4, - reason="torch sdpa is expected to pass in 4d q, k, and v.", ), TorchLibOpInfo( "ops.aten._scaled_dot_product_flash_attention", @@ -1644,7 +1923,15 @@ def _where_input_wrangler( # Output[0] is OK, but other outputs just have the same shape with zero values nondeterministic=True, compare_shape_only_for_output=(1, 2, 3, 4, 5, 6, 7, 8), - ).skip(device_type="cpu", reason="_scaled_dot_product_flash_attention only supports CUDA"), + ) + .skip( + enabled_if=version_utils.torch_older_than("2.1"), + reason="The operator is not supported in older version.", + ) + .skip( + device_type="cpu", + reason="_scaled_dot_product_flash_attention only supports CUDA", + ), TorchLibOpInfo( "ops.aten._scaled_dot_product_efficient_attention", nn_ops.aten__scaled_dot_product_efficient_attention, @@ -1652,10 +1939,34 @@ def _where_input_wrangler( # Output[0] is OK, but other outputs just have the same shape with zero values nondeterministic=True, compare_shape_only_for_output=(1, 2, 3), - ).skip( + ) + .skip( + enabled_if=version_utils.torch_older_than("2.1"), + reason="The operator is not supported in older version.", + ) + .skip( enabled_if=not torch.cuda.is_available(), reason="_scaled_dot_product_efficient_attention only supports CUDA", ), + TorchLibOpInfo( + "nn.functional.scaled_dot_product_attention_bool_mask", + nn_ops.aten_scaled_dot_product_attention_bool_mask, + tolerance={torch.float32: (3e-4, 1.5e-5)}, + ) + .skip( + matcher=lambda sample: (attn_mask := sample.kwargs.get("attn_mask")) is not None + and attn_mask.dtype != torch.bool, + reason="this overload takes a boolean mask", + ) + .skip( + matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0, + reason="dropout is random so the results do not match", + ) + .xfail( + dtypes=(torch.float16,), + reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", + test_class_name="TestOutputConsistencyFullGraph", + ), TorchLibOpInfo( "ops.aten.upsample_bilinear2d.default", nn_ops.aten_upsample_bilinear2d, @@ -1675,7 +1986,10 @@ def _where_input_wrangler( # Shape-only comparison is the appropriate testing approach for this case. compare_shape_only_for_output=(0,), ), - TorchLibOpInfo("ops.aten.upsample_bilinear2d.vec", nn_ops.aten_upsample_bilinear2d_vec), + TorchLibOpInfo( + "ops.aten.upsample_bilinear2d.vec", + nn_ops.aten_upsample_bilinear2d_vec, + ), TorchLibOpInfo( "ops.aten.upsample_bicubic2d.default", nn_ops.aten_upsample_bicubic2d, @@ -1695,7 +2009,10 @@ def _where_input_wrangler( # Shape-only comparison is the appropriate testing approach for this case. 
compare_shape_only_for_output=(0,), ), - TorchLibOpInfo("ops.aten.upsample_bicubic2d.vec", nn_ops.aten_upsample_bicubic2d_vec), + TorchLibOpInfo( + "ops.aten.upsample_bicubic2d.vec", + nn_ops.aten_upsample_bicubic2d_vec, + ), TorchLibOpInfo( "ops.aten.upsample_linear1d", nn_ops.aten_upsample_linear1d, @@ -1704,14 +2021,38 @@ def _where_input_wrangler( and sample.kwargs.get("scales") is not None, reason="fixme: align_corners=False output mismatch when scales are provided", ), - TorchLibOpInfo("ops.aten.upsample_nearest1d", nn_ops.aten_upsample_nearest1d), - TorchLibOpInfo("ops.aten.upsample_nearest1d.vec", nn_ops.aten_upsample_nearestnd_vec), - TorchLibOpInfo("ops.aten.upsample_nearest2d", nn_ops.aten_upsample_nearest2d), - TorchLibOpInfo("ops.aten.upsample_nearest2d.vec", nn_ops.aten_upsample_nearestnd_vec), - TorchLibOpInfo("ops.aten.upsample_nearest3d", nn_ops.aten_upsample_nearest3d), - TorchLibOpInfo("ops.aten.upsample_nearest3d.vec", nn_ops.aten_upsample_nearestnd_vec), - TorchLibOpInfo("ops.aten.upsample_trilinear3d.default", nn_ops.aten_upsample_trilinear3d), - TorchLibOpInfo("ops.aten.upsample_trilinear3d.vec", nn_ops.aten_upsample_trilinear3d_vec), + TorchLibOpInfo( + "ops.aten.upsample_nearest1d", + nn_ops.aten_upsample_nearest1d, + ), + TorchLibOpInfo( + "ops.aten.upsample_nearest1d.vec", + nn_ops.aten_upsample_nearestnd_vec, + ), + TorchLibOpInfo( + "ops.aten.upsample_nearest2d", + nn_ops.aten_upsample_nearest2d, + ), + TorchLibOpInfo( + "ops.aten.upsample_nearest2d.vec", + nn_ops.aten_upsample_nearestnd_vec, + ), + TorchLibOpInfo( + "ops.aten.upsample_nearest3d", + nn_ops.aten_upsample_nearest3d, + ), + TorchLibOpInfo( + "ops.aten.upsample_nearest3d.vec", + nn_ops.aten_upsample_nearestnd_vec, + ), + TorchLibOpInfo( + "ops.aten.upsample_trilinear3d.default", + nn_ops.aten_upsample_trilinear3d, + ), + TorchLibOpInfo( + "ops.aten.upsample_trilinear3d.vec", + nn_ops.aten_upsample_trilinear3d_vec, + ), TorchLibOpInfo("ones_like", core_ops.aten_ones_like), TorchLibOpInfo( "roll", @@ -1729,7 +2070,10 @@ def _where_input_wrangler( core_ops.aten_scatter_reduce, input_wrangler=_scatter_reduce_input_wrangler, ) - .xfail(variant_name="mean", reason="ONNX doesn't support reduce='mean' option") + .xfail( + variant_name="mean", + reason="ONNX doesn't support reduce='mean' option", + ) .xfail( variant_name="prod", dtypes=(torch.float16, torch.float64), @@ -1756,8 +2100,6 @@ def _where_input_wrangler( reason="onnxruntime does not support ml_dtypes.bfloat16", ), TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter), - TorchLibOpInfo("ops.aten.scatter.src", core_ops.aten_scatter_src), - TorchLibOpInfo("ops.aten.scatter.value", core_ops.aten_scatter_value), TorchLibOpInfo("slice", core_ops.aten_slice), TorchLibOpInfo("slice", core_ops.aten_slice_complex, complex=True), TorchLibOpInfo( @@ -1789,7 +2131,6 @@ def _where_input_wrangler( "Our implementation is based on that for CUDA" ), ), - TorchLibOpInfo("ops.prims.broadcast_in_dim.default", prims_ops.prims_broadcast_in_dim), TorchLibOpInfo( "ops.prims.var.default", prims_ops.prims_var, tolerance={torch.float16: (1e-3, 5e-2)} ), @@ -1803,13 +2144,40 @@ def _where_input_wrangler( ops_test_common.duplicate_opinfo(OPS_DB, "atleast_1d", ("atleast_1d_Sequence",)) ops_test_common.duplicate_opinfo(OPS_DB, "atleast_2d", ("atleast_2d_Sequence",)) ops_test_common.duplicate_opinfo(OPS_DB, "atleast_3d", ("atleast_3d_Sequence",)) +ops_test_common.duplicate_opinfo( + OPS_DB, + "bitwise_left_shift", + ( + "bitwise_left_shift_int8", + 
"bitwise_left_shift_int16", + "bitwise_left_shift_int32", + "bitwise_left_shift_int64", + ), +) +ops_test_common.duplicate_opinfo( + OPS_DB, + "bitwise_right_shift", + ( + "bitwise_right_shift_int8", + "bitwise_right_shift_int16", + "bitwise_right_shift_int32", + "bitwise_right_shift_int64", + ), +) ops_test_common.duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate")) ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",)) -ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode",)) +ops_test_common.duplicate_opinfo(OPS_DB, "diagonal", ("diagonal_bool",)) +ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode", "div_mode_int")) +ops_test_common.duplicate_opinfo(OPS_DB, "ge", ("ge_bool",)) +ops_test_common.duplicate_opinfo(OPS_DB, "gt", ("gt_bool",)) ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",)) +ops_test_common.duplicate_opinfo(OPS_DB, "le", ("le_bool",)) +ops_test_common.duplicate_opinfo(OPS_DB, "lt", ("lt_bool",)) ops_test_common.duplicate_opinfo(OPS_DB, "max", ("max_dim",)) +ops_test_common.duplicate_opinfo(OPS_DB, "maximum", ("maximum_bool",)) ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",)) ops_test_common.duplicate_opinfo(OPS_DB, "min", ("min_dim",)) +ops_test_common.duplicate_opinfo(OPS_DB, "minimum", ("minimum_bool",)) ops_test_common.duplicate_opinfo( OPS_DB, "nn.functional.pad", @@ -1819,6 +2187,20 @@ def _where_input_wrangler( "nn.functional.replication_pad3d", ), ) +ops_test_common.duplicate_opinfo( + OPS_DB, + "nn.functional.scaled_dot_product_attention", + ("nn.functional.scaled_dot_product_attention_bool_mask",), +) +ops_test_common.duplicate_opinfo( + OPS_DB, + "nn.functional.celu", + ("nn.functional.celu_type_promoted",), +) +ops_test_common.duplicate_opinfo( + OPS_DB, "ops.aten._log_softmax", ("ops.aten._log_softmax_half",) +) +ops_test_common.duplicate_opinfo(OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",)) ops_test_common.duplicate_opinfo(OPS_DB, "prod", ("prod_dim_int",)) ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",)) ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",))