Skip to content

Commit 6718ef0

Browse files
authored
Enhanced type annotations and simplified implementation of scatter.value (#2612)
follow #2605 --------- Signed-off-by: Linsho Kaku <linsho@preferred.jp>
1 parent aa2cf4a commit 6718ef0

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7738,26 +7738,26 @@ def aten_scalar_tensor_sym_number(
77387738

77397739
@torch_op("aten::scatter.src", trace_only=True)
77407740
def aten_scatter_src(
7741-
self: TReal,
7741+
self: TTensor,
77427742
dim: int, # we have to use int here because ScatterElements() will use this attribute
77437743
index: TInt,
7744-
src: TReal,
7745-
) -> TReal:
7744+
src: TTensor,
7745+
) -> TTensor:
77467746
"""scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"""
77477747
return op.ScatterElements(self, index, src, axis=dim)
77487748

77497749

77507750
@torch_op("aten::scatter.value", trace_only=True)
77517751
def aten_scatter_value(
7752-
self: TReal,
7752+
self: TTensor,
77537753
dim: int, # we have to use int here because ScatterElements() will use this attribute
77547754
index: TInt,
7755-
value: TReal,
7756-
) -> TReal:
7755+
value: float,
7756+
) -> TTensor:
77577757
"""scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor"""
77587758
# Ensure value is a scalar tensor and expand it to match index shape
7759-
scalar_tensor = op.CastLike(value, self)
7760-
src = op.Expand(scalar_tensor, op.Shape(index))
7759+
scalar_tensor = ir.tensor([value], dtype=self.dtype)
7760+
src = op.ConstantOfShape(op.Shape(index), value=scalar_tensor)
77617761
return op.ScatterElements(self, index, src, axis=dim)
77627762

77637763

tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,9 +1407,9 @@ def sample_inputs_scatter_value(op_info, device, dtype, requires_grad, **kwargs)
14071407
# (self_shape, index_shape, dim, value)
14081408
((5, 5), (2, 3), 0, 1.0), # 2D scatter on dim=0 with scalar value
14091409
((5, 5), (3, 2), 1, -2.5), # 2D scatter on dim=1 with scalar value
1410-
((3, 4, 5), (2, 2, 3), 0, 0.0), # 3D scatter on dim=0 with scalar value
1410+
((3, 4, 5), (2, 2, 3), 0, False), # 3D scatter on dim=0 with scalar value
14111411
((3, 4, 5), (2, 2, 3), 1, 3.14), # 3D scatter on dim=1 with scalar value
1412-
((3, 4, 5), (2, 2, 3), 2, -1.0), # 3D scatter on dim=2 with scalar value
1412+
((3, 4, 5), (2, 2, 3), 2, -1), # 3D scatter on dim=2 with scalar value
14131413
((10,), (3,), 0, 5.0), # 1D scatter with scalar value
14141414
]
14151415

0 commit comments

Comments
 (0)