Commit b6a2d02

Remove redundant registration of operator::add (#2631)
I forgot to remove the previous registration.
1 parent f44b314 commit b6a2d02
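
For context, @torch_op maps one or more qualified op names (ATen overloads like "aten::add.Tensor", or Python operator targets like "_operator::add") to a torch_lib function. Before this commit, "_operator::add" appeared both in the aten_add registration and in the separate operator_add registration, so one of the two was redundant. A toy sketch of why a registry that resolves each name to a single function makes the duplicate pointless (this is an illustrative stand-in, not onnxscript's actual registration code):

import warnings
from typing import Callable

# Toy registry: each qualified op name resolves to exactly one implementation.
_REGISTRY: dict[str, Callable] = {}

def register(*names: str):
    """Illustrative stand-in for @torch_op: map each name to the decorated function."""
    def decorator(func: Callable) -> Callable:
        for name in names:
            if name in _REGISTRY:
                warnings.warn(
                    f"{name!r} is already registered to {_REGISTRY[name].__name__}; "
                    f"re-registering it to {func.__name__} is redundant"
                )
            _REGISTRY[name] = func
        return func
    return decorator

# The pre-commit state: "_operator::add" shows up in two registrations.
@register("aten::add.Tensor", "aten::add.Scalar", "_operator::add")
def aten_add(a, b, alpha=1.0):
    return a + alpha * b

@register("_operator::add")  # second registration of the same name: the redundancy this commit removes
def operator_add(a, b):
    return a + b

After the change, each name maps to exactly one function: "aten::add.Tensor" to aten_add, "aten::add.Scalar" to the new aten_add_scalar, and "_operator::add" to operator_add.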

File tree

1 file changed: +18 -4 lines changed
  • onnxscript/function_libs/torch_lib/ops/core.py

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 18 additions & 4 deletions
@@ -132,7 +132,7 @@ def aten_acosh(self: TFloat) -> TFloat:
     return op.Acosh(self)


-@torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True)
+@torch_op("aten::add.Tensor", trace_only=True)
 def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""

@@ -148,7 +148,15 @@ def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor:
     return op.Add(self, other)


-@torch_op(("_operator::add"), trace_only=True)
+@torch_op("aten::add.Scalar", trace_only=True)
+def aten_add_scalar(self: TTensor, other: float, alpha: float = 1.0) -> TTensor:
+    """add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor"""
+
+    other = op.Constant(value=ir.tensor(other, dtype=self.dtype))
+    return aten_add(self, other, alpha=alpha)
+
+
+@torch_op("_operator::add", trace_only=True)
 def operator_add(self: TTensor, other: TTensor) -> TTensor:
     return op.Add(self, other)
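
The new aten_add_scalar wrapper materializes the Python scalar as a Constant whose dtype matches self (via ir.tensor(other, dtype=self.dtype)) and then reuses aten_add, so the downstream Add node sees two inputs with the same element type, as ONNX requires; the aten_sub_scalar wrapper added further down follows the same pattern. A rough NumPy analogue of that promotion step, only to illustrate the dtype behaviour (the real code builds an ONNX Constant node, not a NumPy array, and aten_add handles the alpha scaling):

import numpy as np

def add_scalar(self: np.ndarray, other: float, alpha: float = 1.0) -> np.ndarray:
    # Mirrors op.Constant(value=ir.tensor(other, dtype=self.dtype)) in the diff:
    # promote the Python scalar to self's element type before the add.
    other_arr = np.asarray(other, dtype=self.dtype)
    # aten add.Scalar semantics: self + alpha * other
    return self + alpha * other_arr

x = np.arange(4, dtype=np.float16)
print(add_scalar(x, 2.5).dtype)  # float16: the scalar does not upcast the result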

@@ -8113,9 +8121,7 @@ def aten_std_mean_correction(
 @torch_op(
     (
         "aten::sub.Tensor",
-        "aten::sub.Scalar",
         "aten::subtract.Tensor",
-        "aten::subtract.Scalar",
         "_operator::sub",
     ),
     trace_only=True,
@@ -8128,6 +8134,14 @@ def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     return op.Sub(self, other)


+@torch_op(("aten::sub.Scalar", "aten::subtract.Scalar"), trace_only=True)
+def aten_sub_scalar(self: TTensor, other: float, alpha: float = 1.0) -> TTensor:
+    """sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor"""
+
+    other = op.Constant(value=ir.tensor(other, dtype=self.dtype))
+    return aten_sub(self, other, alpha=alpha)
+
+
 @torch_op(
     (
         "aten::sub.Tensor",
