Skip to content

Commit a1be5c8

Browse files
authored
[torchlib] Fix mod on SymInt (#2686)
Fix the error `<class 'AttributeError'>: 'int' object has no attribute 'dtype'` by splitting the implementations for the `_operator::*` ops out of the shared `aten_*` functions.

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 70e751a commit a1be5c8

File tree

1 file changed

+35
-7
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+35
-7
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3732,7 +3732,7 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType:
37323732

37333733

37343734
@torch_op(
3735-
("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"),
3735+
("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor"),
37363736
trace_only=True,
37373737
)
37383738
def aten_ge(self: TTensor, other: TTensor) -> BOOL:
@@ -3749,6 +3749,12 @@ def aten_ge(self: TTensor, other: TTensor) -> BOOL:
37493749
return op.GreaterOrEqual(self, other)
37503750

37513751

3752+
@torch_op("_operator::ge", trace_only=True)
def operator_ge(self: TTensor, other: TTensor) -> BOOL:
    """Element-wise ``self >= other`` for Python ``operator.ge`` on SymInt inputs.

    Registered separately from ``aten_ge`` because SymInt operands can be
    plain Python ints, which carry no ``dtype`` attribute.
    """
    comparison = op.GreaterOrEqual(self, other)
    return comparison
3756+
3757+
37523758
def aten_geqrf(self: TensorType) -> tuple[TensorType, TensorType]:
37533759
"""geqrf(Tensor self) -> (Tensor a, Tensor tau)"""
37543760

@@ -4058,7 +4064,7 @@ def aten_gru_cell(
40584064

40594065

40604066
@torch_op(
4061-
("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"),
4067+
("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor"),
40624068
trace_only=True,
40634069
)
40644070
def aten_gt(self: TTensor, other: TTensor) -> BOOL:
@@ -4076,6 +4082,12 @@ def aten_gt(self: TTensor, other: TTensor) -> BOOL:
40764082
return op.Greater(self, other)
40774083

40784084

4085+
@torch_op("_operator::gt", trace_only=True)
def operator_gt(self: TTensor, other: TTensor) -> BOOL:
    """Element-wise ``self > other`` for Python ``operator.gt`` on SymInt inputs.

    Registered separately from ``aten_gt`` because SymInt operands can be
    plain Python ints, which carry no ``dtype`` attribute.
    """
    comparison = op.Greater(self, other)
    return comparison
4089+
4090+
40794091
@torch_op("aten::hamming_window", trace_only=True)
40804092
def aten_hamming_window(
40814093
window_length: int,
@@ -4891,7 +4903,7 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType:
48914903

48924904

48934905
@torch_op(
4894-
("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"),
4906+
("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor"),
48954907
trace_only=True,
48964908
)
48974909
def aten_le(self: TTensor, other: TTensor) -> BOOL:
@@ -4909,6 +4921,12 @@ def aten_le(self: TTensor, other: TTensor) -> BOOL:
49094921
return op.LessOrEqual(self, other)
49104922

49114923

4924+
@torch_op("_operator::le", trace_only=True)
def operator_le(self: TTensor, other: TTensor) -> BOOL:
    """Element-wise ``self <= other`` for Python ``operator.le`` on SymInt inputs.

    Registered separately from ``aten_le`` because SymInt operands can be
    plain Python ints, which carry no ``dtype`` attribute.
    """
    comparison = op.LessOrEqual(self, other)
    return comparison
4928+
4929+
49124930
@torch_op(("aten::lerp.Tensor", "aten::lerp.Scalar"))
49134931
def aten_lerp(self: TTensor, end: TTensor, weight: TTensor) -> TTensor:
49144932
"""lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor"""
@@ -5384,7 +5402,7 @@ def aten_lstm(
53845402

53855403

53865404
@torch_op(
5387-
("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"),
5405+
("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor"),
53885406
trace_only=True,
53895407
)
53905408
def aten_lt(self: TTensor, other: TTensor) -> BOOL:
@@ -5401,6 +5419,12 @@ def aten_lt(self: TTensor, other: TTensor) -> BOOL:
54015419
return op.Less(self, other)
54025420

54035421

5422+
@torch_op("_operator::lt", trace_only=True)
def operator_lt(self: TTensor, other: TTensor) -> BOOL:
    """Element-wise ``self < other`` for Python ``operator.lt`` on SymInt inputs.

    Registered separately from ``aten_lt`` because SymInt operands can be
    plain Python ints, which carry no ``dtype`` attribute.
    """
    comparison = op.Less(self, other)
    return comparison
5426+
5427+
54045428
def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType:
54055429
"""lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor"""
54065430

@@ -7468,9 +7492,7 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType:
74687492
raise NotImplementedError()
74697493

74707494

7471-
@torch_op(
7472-
("aten::remainder.Tensor", "aten::remainder.Scalar", "_operator::mod"), trace_only=True
7473-
)
7495+
@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"), trace_only=True)
74747496
def aten_remainder(self: TTensor, other: TTensor) -> TTensor:
74757497
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""
74767498

@@ -7486,6 +7508,12 @@ def aten_remainder(self: TTensor, other: TTensor) -> TTensor:
74867508
return op.Sub(self, op.Mul(rounded_quotient, other))
74877509

74887510

7511+
@torch_op("_operator::mod", trace_only=True)
def operator_mod(self: TTensor, other: TTensor) -> TTensor:
    """``self % other`` for Python ``operator.mod`` on SymInt inputs.

    Registered separately from ``aten_remainder`` because SymInt operands can
    be plain Python ints, which carry no ``dtype`` attribute.
    """
    # ONNX Mod with the default fmod=0 uses Python-style integer modulus
    # (result takes the sign of the divisor), matching `%` on SymInt.
    remainder = op.Mod(self, other)
    return remainder
7515+
7516+
74897517
def aten_rename(self: TensorType, names: Optional[str]) -> TensorType:
74907518
"""rename(Tensor(a) self, Dimname[]? names) -> Tensor(a)"""
74917519

0 commit comments

Comments
 (0)