Skip to content

Commit 7f3325b

Browse files
authored
support for scalar args to aten::scatter (#2613)
close #2600 Signed-off-by: Linsho Kaku <linsho@preferred.jp>
1 parent 6718ef0 commit 7f3325b

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7744,6 +7744,10 @@ def aten_scatter_src(
77447744
src: TTensor,
77457745
) -> TTensor:
77467746
"""scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"""
7747+
if len(index.shape) == 0:
7748+
index = op.Unsqueeze(index, [0])
7749+
if len(src.shape) == 0:
7750+
src = op.Unsqueeze(src, [0])
77477751
return op.ScatterElements(self, index, src, axis=dim)
77487752

77497753

@@ -7756,6 +7760,8 @@ def aten_scatter_value(
77567760
) -> TTensor:
77577761
"""scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor"""
77587762
# Promote a 0-d index to 1-D, then build a src tensor of the index's shape filled with value
7763+
if len(index.shape) == 0:
7764+
index = op.Unsqueeze(index, [0])
77597765
scalar_tensor = ir.tensor([value], dtype=self.dtype)
77607766
src = op.ConstantOfShape(op.Shape(index), value=scalar_tensor)
77617767
return op.ScatterElements(self, index, src, axis=dim)

tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1394,6 +1394,35 @@ def sample_inputs_scatter_src(op_info, device, dtype, requires_grad, **kwargs):
13941394
src_tensor = make_arg(src_shape)
13951395
yield opinfo_core.SampleInput(self_tensor, args=(dim, index_tensor, src_tensor))
13961396

1397+
# Additional test cases for scalar and single-element tensor combinations with dim=0
1398+
# Test case: scalar index, scalar src (dim_size=5)
1399+
dim_size = 5
1400+
data_1d = make_arg((dim_size,))
1401+
valid_index = torch.randint(0, dim_size, (), device=device, dtype=torch.long)
1402+
scalar_src = make_arg(())
1403+
yield opinfo_core.SampleInput(data_1d, args=(0, valid_index, scalar_src))
1404+
1405+
# Test case: single-element tensor index, scalar src (dim_size=7)
1406+
dim_size = 7
1407+
data_1d = make_arg((dim_size,))
1408+
valid_index_1d = torch.randint(0, dim_size, (1,), device=device, dtype=torch.long)
1409+
scalar_src = make_arg(())
1410+
yield opinfo_core.SampleInput(data_1d, args=(0, valid_index_1d, scalar_src))
1411+
1412+
# Test case: scalar index, single-element tensor src (dim_size=3)
1413+
dim_size = 3
1414+
data_1d = make_arg((dim_size,))
1415+
valid_index = torch.randint(0, dim_size, (), device=device, dtype=torch.long)
1416+
src_1d = make_arg((1,))
1417+
yield opinfo_core.SampleInput(data_1d, args=(0, valid_index, src_1d))
1418+
1419+
# Test case: single-element tensor index, single-element tensor src (dim_size=10)
1420+
dim_size = 10
1421+
data_1d = make_arg((dim_size,))
1422+
valid_index_1d = torch.randint(0, dim_size, (1,), device=device, dtype=torch.long)
1423+
src_1d = make_arg((1,))
1424+
yield opinfo_core.SampleInput(data_1d, args=(0, valid_index_1d, src_1d))
1425+
13971426

13981427
def sample_inputs_scatter_value(op_info, device, dtype, requires_grad, **kwargs):
13991428
del op_info
@@ -1423,6 +1452,21 @@ def sample_inputs_scatter_value(op_info, device, dtype, requires_grad, **kwargs)
14231452
]
14241453
yield opinfo_core.SampleInput(self_tensor, args=(dim, index_tensor, value))
14251454

1455+
# Additional test cases for scalar and single-element tensor combinations with dim=0
1456+
# Test case: scalar index with scalar value (dim_size=6, value_type=torch.long)
1457+
dim_size = 6
1458+
data_1d = make_arg((dim_size,))
1459+
valid_index = torch.randint(0, dim_size, (), device=device, dtype=torch.long)
1460+
random_value = torch.randint(0, 10, (), device=device, dtype=torch.long).item()
1461+
yield opinfo_core.SampleInput(data_1d, args=(0, valid_index, random_value))
1462+
1463+
# Test case: single-element tensor index with scalar value (dim_size=8, value_type=torch.float)
1464+
dim_size = 8
1465+
data_1d = make_arg((dim_size,))
1466+
valid_index_1d = torch.randint(0, dim_size, (1,), device=device, dtype=torch.long)
1467+
random_value = torch.rand((), device=device, dtype=torch.float).item()
1468+
yield opinfo_core.SampleInput(data_1d, args=(0, valid_index_1d, random_value))
1469+
14261470

14271471
def sample_inputs__scaled_dot_product_flash_attention(
14281472
op_info, device, dtype, requires_grad, **kwargs

0 commit comments

Comments
 (0)