From 3cb493f871acfe164bae5b6b7790219936860348 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 6 Oct 2025 17:55:29 +0200 Subject: [PATCH 01/15] Implements aten_index_put if inputs are SymbolicTensor --- .../function_libs/torch_lib/ops/core.py | 105 +++++++++++++++++- .../function_libs/torch_lib/e2e_ops_tests.py | 43 +++++++ 2 files changed, 143 insertions(+), 5 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 1a688a4277..08d4fab280 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4383,6 +4383,11 @@ def aten_index_put( See implementation of `torch.onnx.symbolic_opset11.index_put `_. """ + if any( + isinstance(indice, torch.onnx._internal.exporter._tensors.SymbolicTensor) + for indice in indices + ): + return _aten_index_put_dynamic(self, indices, values, accumulate=accumulate) def _make_reshape_list_broadcastable(reshape_list, values_shape): # Remove ones until the rank of reshape_list matches values_shape. @@ -4452,14 +4457,104 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape): # Flatten values to match the indices flat_values = op.Reshape(values, [-1]) - if accumulate: - result = op.ScatterND(self, new_index, flat_values, reduction="add") - else: - result = op.ScatterND(self, new_index, flat_values) - + scatter_kwargs = dict(reduction="add") if accumulate else {} + result = op.ScatterND(self, new_index, flat_values, **scatter_kwargs) return result +def _aten_index_put_dynamic( + x: TReal, + indices: Sequence[INT64], + values: TReal, + accumulate: bool = False, +) -> TReal: + def _make_range_or_cast(ind, shape_x, static_shape: bool, dim: int): + if ind is not None: + return op.Cast(ind, to=INT64.dtype), False + return ( + op.Cast( + op.Range( # Range does not return a typed result + 0, + op.Squeeze(op.Shape(x, start=dim, end=dim + 1)), + 1, + ), + to=INT64.dtype, + ), + True, + ) + + rk1s = [(ind is None or len(ind.shape) == 1) for ind in indices] + assert all(rk1s) and len(rk1s) == len(x.shape), ( + f"input_put not implemented for indices={indices}, " + f"where rk1s={rk1s}, rank(x)={len(x.shape)}" + ) + shape_x = op.Shape(x) + exped = [] + fixed = [] + reshape_value_shape2 = [] + expand_value_shape = [] + for i, ind in enumerate(indices): + if isinstance(ind, torch.onnx._internal.exporter._tensors.SymbolicTensor): + ind.dtype = ir.DataType.INT64 + ind, expanded = _make_range_or_cast(ind, shape_x, False, i) + if expanded: + exped.append((i, ind)) + expand_value_shape.append(op.Shape(x, start=i, end=i + 1)) + reshape_value_shape2.append([1]) + else: + expand_value_shape.append([1]) + reshape_value_shape2.append(op.Shape(ind)) + fixed.append((i, ind)) + + reshape_value_shape1 = [1] * len(indices) + if len(fixed) <= 1: + reshape_value_shape1 = None + elif fixed: + reshape_value_shape1[fixed[-1][0]] = -1 + + def _mkstride(x, i): + if i >= len(x.shape) - 1: + return [1] + if i == len(x.shape) - 2: + return op.Shape(x, start=i + 1) + return op.ReduceProd(op.Shape(x, start=i + 1), keepdims=1) + + shape = [1] * (len(x.shape) + 1) + mfixed = [] + if fixed: + new_shape = shape.copy() + new_shape[-1] = -1 + mfixed = [op.Reshape(op.Mul(_mkstride(x, i), f), new_shape) for i, f in fixed] + + mexped = [] + for i, e in exped: + new_shape = shape.copy() + new_shape[i] = -1 + mexped.append(op.Reshape(op.Mul(_mkstride(x, i), e), new_shape)) + + # final sum + unflat = None + for a in [*mfixed, *mexped]: + if unflat is None: + unflat = a + 
continue + unflat = op.Add(unflat, a) + + # value_shape + expanded_values = values + if reshape_value_shape1 is not None: + expanded_values = op.Reshape(expanded_values, op.Concat(*reshape_value_shape1, axis=0)) + # Bug here: Error calling operator 'Concat' with args + # (SymbolicTensor(name='anonymous:124529632436112', producer=anonymous_node:124529631522416, index=0), [1], [1]) + expanded_values = op.Expand(expanded_values, op.Concat(*expand_value_shape, axis=0)) + flat_ind = op.Reshape(unflat, [-1]) + expanded_values = op.Reshape(expanded_values, [-1]) + flat_x = op.Reshape(x, [-1]) + scat_kwargs = {"reduction": "add"} if accumulate else {} + flat_up_x = op.ScatterElements(flat_x, flat_ind, expanded_values, **scat_kwargs) + return op.Reshape(flat_up_x, op.Shape(x)) + + @torch_op("aten::index_put", trace_only=True) def aten_index_put_bool( self: TReal, diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 1b0410c27f..437017af97 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -225,6 +225,49 @@ def forward(self, q, k, v): ) _testing.assert_onnx_program(onnx_program) + def test_index_put_dynamic(self): + for dimension in [3, 4, 2]: + with self.subTest(dimension=dimension): + + class Model(torch.nn.Module): + def __init__(self, dimension): + super().__init__() + self.params = torch.zeros( + (4, 5) + if dimension == 2 + else ((2, 4, 5) if dimension == 3 else (1, 1, 4, 5)) + ) + self.dimension = dimension + + def forward(self, update, index1, index2): + copy = self.params.clone() + if self.dimension == 2: + copy[index1, index2] = update + elif self.dimension == 3: + copy[:, index1, index2] = update + else: + copy[:, :, index1, index2] = update + return copy + + update = (torch.arange(2) + 10).reshape((2,)).to(torch.float32) + index1 = torch.tensor([1, 2], dtype=torch.int64) + index2 = torch.tensor([3, 4], dtype=torch.int64) + feeds = dict(zip(["update", "index1", "index2"], (update, index1, index2))) + onnx_program = torch.onnx.export( + Model(dimension), + tuple(feeds.values()), + input_names=["update", "index1", "index2"], + output_names=["output"], + opset_version=18, + dynamo=True, + dynamic_shapes={ + "update": {0: "dn"}, + "index1": {0: "dn"}, + "index2": {0: "dn"}, + }, + ) + _testing.assert_onnx_program(onnx_program) + if __name__ == "__main__": unittest.main() From c9472f3bfcf9eb42aa7818de1deb86851fde23f1 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 6 Oct 2025 18:27:21 +0200 Subject: [PATCH 02/15] disable case with one index --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 08d4fab280..aefce5b038 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4383,7 +4383,7 @@ def aten_index_put( See implementation of `torch.onnx.symbolic_opset11.index_put `_. 
""" - if any( + if len(indices) > 1 and any( isinstance(indice, torch.onnx._internal.exporter._tensors.SymbolicTensor) for indice in indices ): From 434fcfb0130ae099a78964fca1d83f38220b8254 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 6 Oct 2025 19:15:44 +0200 Subject: [PATCH 03/15] type constant --- .../function_libs/torch_lib/ops/core.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index aefce5b038..e434529d81 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4468,15 +4468,21 @@ def _aten_index_put_dynamic( values: TReal, accumulate: bool = False, ) -> TReal: + def _1dint(i: int): + return op.Constant(value_ints=ir.AttrInt64s("value_ints", [i])) + + def _0dint(i: int): + return op.Constant(value_int=ir.AttrInt64("value_int", i)) + def _make_range_or_cast(ind, shape_x, static_shape: bool, dim: int): if ind is not None: return op.Cast(ind, to=INT64.dtype), False return ( op.Cast( op.Range( # Range does not return a typed result - 0, + _0dint(0), op.Squeeze(op.Shape(x, start=dim, end=dim + 1)), - 1, + _0dint(1), ), to=INT64.dtype, ), @@ -4500,21 +4506,21 @@ def _make_range_or_cast(ind, shape_x, static_shape: bool, dim: int): if expanded: exped.append((i, ind)) expand_value_shape.append(op.Shape(x, start=i, end=i + 1)) - reshape_value_shape2.append([1]) + reshape_value_shape2.append(_1dint(1)) else: - expand_value_shape.append([1]) + expand_value_shape.append(_1dint(1)) reshape_value_shape2.append(op.Shape(ind)) fixed.append((i, ind)) - reshape_value_shape1 = [1] * len(indices) + reshape_value_shape1 = [_1dint(1)] * len(indices) if len(fixed) <= 1: reshape_value_shape1 = None elif fixed: - reshape_value_shape1[fixed[-1][0]] = -1 + reshape_value_shape1[fixed[-1][0]] = _1dint(-1) def _mkstride(x, i): if i >= len(x.shape) - 1: - return [1] + return _1dint(1) if i == len(x.shape) - 2: return op.Shape(x, start=i + 1) return op.ReduceProd(op.Shape(x, start=i + 1), keepdims=1) @@ -4547,9 +4553,9 @@ def _mkstride(x, i): # Bug here: Error calling operator 'Concat' with args # (SymbolicTensor(name='anonymous:124529632436112', producer=anonymous_node:124529631522416, index=0), [1], [1]) expanded_values = op.Expand(expanded_values, op.Concat(*expand_value_shape, axis=0)) - flat_ind = op.Reshape(unflat, [-1]) - expanded_values = op.Reshape(expanded_values, [-1]) - flat_x = op.Reshape(x, [-1]) + flat_ind = op.Reshape(unflat, _1dint(-1)) + expanded_values = op.Reshape(expanded_values, _1dint(-1)) + flat_x = op.Reshape(x, _1dint(-1)) scat_kwargs = {"reduction": "add"} if accumulate else {} flat_up_x = op.ScatterElements(flat_x, flat_ind, expanded_values, **scat_kwargs) return op.Reshape(flat_up_x, op.Shape(x)) From ae6adca76ea8a6286053a07534ecb4dbd1a840e4 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 7 Oct 2025 10:15:39 +0200 Subject: [PATCH 04/15] another fix --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index ecca0de8e5..45adf39df3 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4386,7 +4386,7 @@ def aten_index_put( if len(indices) > 1 and any( isinstance(indice, torch.onnx._internal.exporter._tensors.SymbolicTensor) for indice in indices - ): + ) and 
len(values.shape) == 1: return _aten_index_put_dynamic(self, indices, values, accumulate=accumulate) def _make_reshape_list_broadcastable(reshape_list, values_shape): From 851036421aa9373039cd019405f5a9fc179c7a83 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 7 Oct 2025 10:31:58 +0200 Subject: [PATCH 05/15] rename --- .../function_libs/torch_lib/ops/core.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 45adf39df3..330cf02ef0 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4383,10 +4383,14 @@ def aten_index_put( See implementation of `torch.onnx.symbolic_opset11.index_put `_. """ - if len(indices) > 1 and any( - isinstance(indice, torch.onnx._internal.exporter._tensors.SymbolicTensor) - for indice in indices - ) and len(values.shape) == 1: + if ( + len(indices) > 1 + and any( + isinstance(indice, torch.onnx._internal.exporter._tensors.SymbolicTensor) + for indice in indices + ) + and len(values.shape) == 1 + ): return _aten_index_put_dynamic(self, indices, values, accumulate=accumulate) def _make_reshape_list_broadcastable(reshape_list, values_shape): @@ -4526,21 +4530,21 @@ def _mkstride(x, i): return op.ReduceProd(op.Shape(x, start=i + 1), keepdims=1) shape = [1] * (len(x.shape) + 1) - mfixed = [] + r_fixed = [] if fixed: new_shape = shape.copy() new_shape[-1] = -1 - mfixed = [op.Reshape(op.Mul(_mkstride(x, i), f), new_shape) for i, f in fixed] + r_fixed = [op.Reshape(op.Mul(_mkstride(x, i), f), new_shape) for i, f in fixed] - mexped = [] + r_exped = [] for i, e in exped: new_shape = shape.copy() new_shape[i] = -1 - mexped.append(op.Reshape(op.Mul(_mkstride(x, i), e), new_shape)) + r_exped.append(op.Reshape(op.Mul(_mkstride(x, i), e), new_shape)) # final sum unflat = None - for a in [*mfixed, *mexped]: + for a in [*r_fixed, *r_exped]: if unflat is None: unflat = a continue @@ -4550,8 +4554,6 @@ def _mkstride(x, i): expanded_values = values if reshape_value_shape1 is not None: expanded_values = op.Reshape(expanded_values, op.Concat(*reshape_value_shape1, axis=0)) - # Bug here: Error calling operator 'Concat' with args - # (SymbolicTensor(name='anonymous:124529632436112', producer=anonymous_node:124529631522416, index=0), [1], [1]) expanded_values = op.Expand(expanded_values, op.Concat(*expand_value_shape, axis=0)) flat_ind = op.Reshape(unflat, _1dint(-1)) expanded_values = op.Reshape(expanded_values, _1dint(-1)) From e4d574a4025a5e4137985c5963164b0e785f3d27 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 7 Oct 2025 11:27:44 +0200 Subject: [PATCH 06/15] style --- onnxscript/function_libs/torch_lib/ops/core.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 330cf02ef0..3d38c7d806 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4494,10 +4494,6 @@ def _make_range_or_cast(ind, shape_x, static_shape: bool, dim: int): ) rk1s = [(ind is None or len(ind.shape) == 1) for ind in indices] - assert all(rk1s) and len(rk1s) == len(x.shape), ( - f"input_put not implemented for indices={indices}, " - f"where rk1s={rk1s}, rank(x)={len(x.shape)}" - ) shape_x = op.Shape(x) exped = [] fixed = [] From 86d482d588ac58e5c0df8a3902cb7ee255296537 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 7 Oct 2025 11:47:22 +0200 Subject: 
[PATCH 07/15] lint --- onnxscript/function_libs/torch_lib/ops/core.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 3d38c7d806..078b4207de 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4386,7 +4386,7 @@ def aten_index_put( if ( len(indices) > 1 and any( - isinstance(indice, torch.onnx._internal.exporter._tensors.SymbolicTensor) + isinstance(indice, torch.onnx._internal.exporter._tensors.SymbolicTensor) # pylint: disable=protected-access for indice in indices ) and len(values.shape) == 1 @@ -4493,14 +4493,13 @@ def _make_range_or_cast(ind, shape_x, static_shape: bool, dim: int): True, ) - rk1s = [(ind is None or len(ind.shape) == 1) for ind in indices] shape_x = op.Shape(x) exped = [] fixed = [] reshape_value_shape2 = [] expand_value_shape = [] for i, ind in enumerate(indices): - if isinstance(ind, torch.onnx._internal.exporter._tensors.SymbolicTensor): + if isinstance(ind, torch.onnx._internal.exporter._tensors.SymbolicTensor): # pylint: disable=protected-access ind.dtype = ir.DataType.INT64 ind, expanded = _make_range_or_cast(ind, shape_x, False, i) if expanded: From e108dc3d0aef7bf3fcc38e00e093e3636b21523f Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 16 Oct 2025 14:58:22 +0200 Subject: [PATCH 08/15] rename --- onnxscript/function_libs/torch_lib/ops/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 65cec71d37..476daf2b24 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4226,8 +4226,8 @@ def aten_index_put( if ( len(indices) > 1 and any( - isinstance(indice, torch.onnx._internal.exporter._tensors.SymbolicTensor) # pylint: disable=protected-access - for indice in indices + isinstance(index, torch.onnx._internal.exporter._tensors.SymbolicTensor) # pylint: disable=protected-access + for index in indices ) and len(values.shape) == 1 ): From 988e9f6f0f638bf186afd90adf75476e0490d3e8 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 17 Oct 2025 15:13:39 +0200 Subject: [PATCH 09/15] handle one more case for index_put --- .../function_libs/torch_lib/ops/core.py | 36 +++++++++++++++++++ .../function_libs/torch_lib/e2e_ops_tests.py | 25 +++++++++++++ 2 files changed, 61 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 476daf2b24..130de462ab 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4233,6 +4233,14 @@ def aten_index_put( ): return _aten_index_put_dynamic(self, indices, values, accumulate=accumulate) + n_none = [i for i, ind in enumerate(indices) if ind is not None] + if ( + len(n_none) == 1 + and len(indices[n_none[0]].shape) == 1 + and len(self.shape) == len(values.shape) + ): + return _aten_index_put_scatter_nd(self, indices, values, accumulate) + def _make_reshape_list_broadcastable(reshape_list, values_shape): # Remove ones until the rank of reshape_list matches values_shape. 
while len(reshape_list) > len(values_shape) and 1 in reshape_list: @@ -4306,6 +4314,34 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape): return result +def _aten_index_put_scatter_nd( + x: TReal, + indices: Sequence[INT64], + values: TReal, + accumulate: bool = False, +) -> TReal: + def _1dint(i: int): + return op.Constant(value_ints=ir.AttrInt64s("value_ints", [i])) + + n_none = [i for i, ind in enumerate(indices) if ind is not None] + assert len(n_none) == 1, f"Unable to handle that case: n_none={n_none}" + unsq = op.Unsqueeze(indices[n_none[0]], _1dint(1)) + if n_none[0] == 0: + return op.ScatterND(x, unsq, values) + + perm = list(range(len(x.shape))) + perm[n_none[0]], perm[0] = perm[0], perm[n_none[0]] + return op.Transpose( + op.ScatterND( + op.Transpose(x, perm=perm), + unsq, + op.Transpose(values, perm=perm), + reduction="add" if accumulate else "none", + ), + perm=perm, + ) + + def _aten_index_put_dynamic( x: TReal, indices: Sequence[INT64], diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 2402d024c7..26629cd0fa 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -5,6 +5,7 @@ import unittest +import numpy as np import torch from torch.onnx._internal.exporter import _testing @@ -268,6 +269,30 @@ def forward(self, update, index1, index2): ) _testing.assert_onnx_program(onnx_program) + def test_index_put_scatter_nd(self): + class Model(torch.nn.Module): + def forward(self, x, index, update): + x = x.clone() + return torch.ops.aten.index_put(x, [None, index, None], update) + + shape = (2, 3, 2) + N = int(np.prod(shape)) + x = torch.arange(N, dtype=torch.float32).reshape(shape) + update = (torch.arange(N, dtype=torch.float32).reshape(shape) + 1) * 100 + index = ((torch.arange(shape[-2])).to(torch.int64) + 1) % shape[-2] + + feeds = dict(zip(["x", "index", "update"], (x, index, update))) + onnx_program = torch.onnx.export( + Model(), + tuple(feeds.values()), + input_names=["x", "index", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + dynamic_shapes=({0: "a", 1: "b", 2: "c"}, {0: "d"}, {0: "e", 1: "f", 2: "g"}), + ) + _testing.assert_onnx_program(onnx_program) + def test_bitwise_and_scalar(self): class Model(torch.nn.Module): def forward(self, x): From 1e300970d927f129e4a3a5b8558f8592765e4c3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 20 Oct 2025 15:19:54 +0200 Subject: [PATCH 10/15] raise an exception --- .../function_libs/torch_lib/ops/core.py | 25 ++++++++++++------- .../function_libs/torch_lib/e2e_ops_tests.py | 1 + .../function_libs/torch_lib/ops_test_data.py | 2 +- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 130de462ab..c659fd6e89 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4276,15 +4276,22 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape): reshape_update = self.shape[i] else: idx = indices[i] - reshape_update = math.prod(idx.shape) - # when Index is more than 1D, flatten it and also the values shape - # Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3) - # Indices -> (2*4,) and values shape (2*4, 32) - if len(idx.shape) > 1: - values_shape = (reshape_update, *values_shape[len(idx.shape) :]) - - # Flatten index (always working with 1D index 
in each dim)
+                idx = op.Reshape(idx, [-1])
+            else:
+                raise RuntimeError(
+                    f"Unable to handle index {indices[i]} for axis={i} "
+                    f"because one of the dimensions is not static as shape="
+                    f"{idx.shape}, indices={indices}"
+                )
 
         # Create a reshape pattern: one value per index dimension,
         # with the current dimension set to the update size.
diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py
index 26629cd0fa..a0724f6ff6 100644
--- a/tests/function_libs/torch_lib/e2e_ops_tests.py
+++ b/tests/function_libs/torch_lib/e2e_ops_tests.py
@@ -269,6 +269,7 @@ def forward(self, update, index1, index2):
         )
         _testing.assert_onnx_program(onnx_program)
 
+
     def test_index_put_scatter_nd(self):
         class Model(torch.nn.Module):
             def forward(self, x, index, update):
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index b60fd8cf31..4fb2446b20 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -313,7 +313,7 @@ def _im2col_input_wrangler(
 def _index_put_input_wrangler(
     args: list[Any], kwargs: dict[str, Any]
 ) -> tuple[list[Any], dict[str, Any]]:
-    args[1] = [np.array(elem) for elem in args[1]]
+    args[1] = [(elem.detach().cpu().numpy() if hasattr(elem, "detach") else np.array(elem)) for elem in args[1]]
     return args, kwargs
 

From f3731edfe75717b79f3b2dbce0d198c57c7f6f35 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Mon, 20 Oct 2025 16:33:53 +0200
Subject: [PATCH 11/15] one more unittest

---
 onnxscript/function_libs/torch_lib/ops/core.py | 13 ++++++++++++-
 tests/function_libs/torch_lib/e2e_ops_tests.py | 17 +++++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index c659fd6e89..391f2ce178 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -4241,6 +4241,11 @@ def aten_index_put(
     ):
         return _aten_index_put_scatter_nd(self, indices, values, accumulate)
 
+    if len(indices) == 1 and set(indices[0].shape[:-1]) == {1} and indices[0].shape[0] == 1:
+        # shape(self) = (5,5), shape(indices[0]) = (1,2), shape(values) = (2,5)
+        # This case was only found in ops_data test.
+        return _aten_index_put_scatter_nd(self, [op.Reshape(indices[0], [-1])], values, accumulate)
+
     def _make_reshape_list_broadcastable(reshape_list, values_shape):
         # Remove ones until the rank of reshape_list matches values_shape.
while len(reshape_list) > len(values_shape) and 1 in reshape_list: @@ -4252,7 +4257,13 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape): # the reshape list should be : [[2, 1], [1, 3], [2, 1]] for i, r in enumerate(reshape_list): if r not in (1, values_shape[i]): - value_index = values_shape.index(r) + try: + value_index = values_shape.index(r) + except ValueError as e: + raise RuntimeError( + f"Unable to find element {r!r} in shape {values_shape}, " + f"reshape_list={reshape_list}" + ) from e # Swap elements # For the example above the current reshape list is [1, 2] for last dim, # to make it broadcastable, we swap the elements diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index a0724f6ff6..2193830837 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -269,6 +269,23 @@ def forward(self, update, index1, index2): ) _testing.assert_onnx_program(onnx_program) + def test_index_put_55_12_25(self): + class Model(torch.nn.Module): + def forward(self, x, index, update): + return torch.ops.aten.index_put(x, [index], update) + + x = torch.zeros((6, 5), dtype=torch.float32) + index = torch.tensor([[2, 1]], dtype=torch.int64) + update = (torch.arange(10) + 10).reshape((2, -1)).to(torch.float32) + onnx_program = torch.onnx.export( + Model(), + (x, index, update), + input_names=["x", "index", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) def test_index_put_scatter_nd(self): class Model(torch.nn.Module): From b6ebed8e7709e37710b8a7942bbf6ee63a35fc1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 20 Oct 2025 17:04:47 +0200 Subject: [PATCH 12/15] fix reduction --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- tests/function_libs/torch_lib/e2e_ops_tests.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 391f2ce178..7d3b1feabb 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4345,7 +4345,7 @@ def _1dint(i: int): assert len(n_none) == 1, f"Unable to handle that case: n_none={n_none}" unsq = op.Unsqueeze(indices[n_none[0]], _1dint(1)) if n_none[0] == 0: - return op.ScatterND(x, unsq, values) + return op.ScatterND(x, unsq, values, reduction="add" if accumulate else "none") perm = list(range(len(x.shape))) perm[n_none[0]], perm[0] = perm[0], perm[n_none[0]] diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 2193830837..decdbb2bbd 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -287,6 +287,24 @@ def forward(self, x, index, update): ) _testing.assert_onnx_program(onnx_program) + def test_index_put_55_2_25(self): + class Model(torch.nn.Module): + def forward(self, x, index, update): + return torch.ops.aten.index_put(x, [index], update, accumulate=True) + + x = torch.ones((6, 5), dtype=torch.float32) + index = torch.tensor([4, 3], dtype=torch.int64) + update = (torch.arange(10) + 10).reshape((2, -1)).to(torch.float32) + onnx_program = torch.onnx.export( + Model(), + (x, index, update), + input_names=["x", "index", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + 
_testing.assert_onnx_program(onnx_program) + def test_index_put_scatter_nd(self): class Model(torch.nn.Module): def forward(self, x, index, update): From 1c1fc1ac92e9178dd17a5f8a4c62c3f1ed55f8a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 20 Oct 2025 17:22:37 +0200 Subject: [PATCH 13/15] lint --- onnxscript/function_libs/torch_lib/ops/core.py | 4 +++- tests/function_libs/torch_lib/ops_test_data.py | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 7d3b1feabb..264405aac6 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4244,7 +4244,9 @@ def aten_index_put( if len(indices) == 1 and set(indices[0].shape[:-1]) == {1} and indices[0].shape[0] == 1: # shape(self) = (5,5), shape(indices[0]) = (1,2), shape(values) = (2,5) # This case was only found in ops_data test. - return _aten_index_put_scatter_nd(self, [op.Reshape(indices[0], [-1])], values, accumulate) + return _aten_index_put_scatter_nd( + self, [op.Reshape(indices[0], [-1])], values, accumulate + ) def _make_reshape_list_broadcastable(reshape_list, values_shape): # Remove ones until the rank of reshape_list matches values_shape. diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 4fb2446b20..950805ac68 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -313,7 +313,10 @@ def _im2col_input_wrangler( def _index_put_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: - args[1] = [(elem.detach().cpu().numpy() if hasattr(elem, "detach") else np.array(elem)) for elem in args[1]] + args[1] = [ + (elem.detach().cpu().numpy() if hasattr(elem, "detach") else np.array(elem)) + for elem in args[1] + ] return args, kwargs From 37a861fad39b2865014993b977bafc9640a5c5e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 28 Oct 2025 10:19:07 +0100 Subject: [PATCH 14/15] rename some variables --- .../function_libs/torch_lib/ops/core.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 264405aac6..b5f0b94557 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4233,10 +4233,10 @@ def aten_index_put( ): return _aten_index_put_dynamic(self, indices, values, accumulate=accumulate) - n_none = [i for i, ind in enumerate(indices) if ind is not None] + not_none = [i for i, ind in enumerate(indices) if ind is not None] if ( - len(n_none) == 1 - and len(indices[n_none[0]].shape) == 1 + len(not_none) == 1 + and len(indices[not_none[0]].shape) == 1 and len(self.shape) == len(values.shape) ): return _aten_index_put_scatter_nd(self, indices, values, accumulate) @@ -4343,14 +4343,14 @@ def _aten_index_put_scatter_nd( def _1dint(i: int): return op.Constant(value_ints=ir.AttrInt64s("value_ints", [i])) - n_none = [i for i, ind in enumerate(indices) if ind is not None] - assert len(n_none) == 1, f"Unable to handle that case: n_none={n_none}" - unsq = op.Unsqueeze(indices[n_none[0]], _1dint(1)) - if n_none[0] == 0: + not_none = [i for i, ind in enumerate(indices) if ind is not None] + assert len(not_none) == 1, f"Unable to handle that case: not_none={not_none}" + unsq = 
op.Unsqueeze(indices[not_none[0]], _1dint(1))
+    if not_none[0] == 0:
         return op.ScatterND(x, unsq, values, reduction="add" if accumulate else "none")
 
     perm = list(range(len(x.shape)))
-    perm[n_none[0]], perm[0] = perm[0], perm[n_none[0]]
+    perm[not_none[0]], perm[0] = perm[0], perm[not_none[0]]
     return op.Transpose(
         op.ScatterND(
             op.Transpose(x, perm=perm),
@@ -4421,21 +4421,21 @@ def _mkstride(x, i):
         return op.ReduceProd(op.Shape(x, start=i + 1), keepdims=1)
 
     shape = [1] * (len(x.shape) + 1)
-    r_fixed = []
+    reshaped_fixed = []
     if fixed:
         new_shape = shape.copy()
         new_shape[-1] = -1
-        r_fixed = [op.Reshape(op.Mul(_mkstride(x, i), f), new_shape) for i, f in fixed]
+        reshaped_fixed = [op.Reshape(op.Mul(_mkstride(x, i), f), new_shape) for i, f in fixed]
 
-    r_exped = []
+    reshaped_exped = []
     for i, e in exped:
         new_shape = shape.copy()
         new_shape[i] = -1
-        r_exped.append(op.Reshape(op.Mul(_mkstride(x, i), e), new_shape))
+        reshaped_exped.append(op.Reshape(op.Mul(_mkstride(x, i), e), new_shape))
 
     # final sum
     unflat = None
-    for a in [*r_fixed, *r_exped]:
+    for a in [*reshaped_fixed, *reshaped_exped]:
         if unflat is None:
             unflat = a
             continue

From ce6ba0b24af669cceff4dae5eda84ad05c9e46bf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Sun, 9 Nov 2025 15:13:20 +0100
Subject: [PATCH 15/15] better error message

---
 onnxscript/function_libs/torch_lib/ops/nn.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 4f81cc7907..19e1771871 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -1840,7 +1840,8 @@ def aten_scaled_dot_product_attention(
         key, value = _attention_repeat_kv_for_group_query(query, key, value)
     else:
         assert query.shape[1] == key.shape[1] == value.shape[1], (
-            "SDPA (MHA) requires q_num_heads = kv_num_heads"
+            "SDPA (MHA) requires q_num_heads = kv_num_heads, "
+            f"query.shape={query.shape}, key.shape={key.shape}, value.shape={value.shape}"
         )
 
     if attn_mask is None:
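
For reference, the lowering in _aten_index_put_dynamic can be checked against plain NumPy: each indexed coordinate is multiplied by its stride (_mkstride), the per-axis contributions are reshaped onto distinct broadcast axes, and the Add chain sums them into flat offsets so a single 1-D ScatterElements applies the whole update. A minimal sketch, assuming the dimension=3 shapes from test_index_put_dynamic; names such as strides and offsets are illustrative, not from the patch:

# Sketch (not part of the patch): NumPy equivalent of the flat-offset scatter
# built by _aten_index_put_dynamic; shapes follow test_index_put_dynamic.
import numpy as np

x = np.zeros((2, 4, 5), dtype=np.float32)
index1 = np.array([1, 2])                  # indexes axis 1 (a "fixed" axis)
index2 = np.array([3, 4])                  # indexes axis 2 (a "fixed" axis)
update = np.array([10.0, 11.0], dtype=np.float32)

# Strides in elements, as _mkstride computes them: product of trailing dims.
strides = [4 * 5, 5, 1]

# Axis 0 is not indexed, so it behaves like op.Range over its full extent
# (the "expanded" branch); broadcasting the reshaped contributions and
# summing them reproduces the flat offsets built by the Mul/Reshape/Add chain.
offsets = (
    strides[0] * np.arange(2).reshape(2, 1)  # expanded axis -> shape (2, 1)
    + strides[1] * index1.reshape(1, 2)      # fixed axes    -> shape (1, 2)
    + strides[2] * index2.reshape(1, 2)
)

# Flatten everything and scatter, mirroring Reshape(..., [-1]) + ScatterElements.
# accumulate=True corresponds to reduction="add" (np.add.at) instead of assignment.
flat = x.reshape(-1).copy()
flat[offsets.reshape(-1)] = np.broadcast_to(update, offsets.shape).reshape(-1)
result = flat.reshape(x.shape)

expected = x.copy()
expected[:, index1, index2] = update         # the aten.index_put semantics
assert np.array_equal(result, expected)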
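
The _aten_index_put_scatter_nd path added in patch 09 handles the complementary case where exactly one axis is indexed: moving that axis to the front turns the update into a rank-1 ScatterND over whole slices, and because perm only swaps two axes it is its own inverse, so the same permutation serves both Transpose calls. A sketch under the same caveats, using the data from test_index_put_scatter_nd:

# Sketch (not part of the patch): the Transpose/ScatterND/Transpose sandwich
# of _aten_index_put_scatter_nd, for a single indexed axis.
import numpy as np

x = np.arange(12, dtype=np.float32).reshape(2, 3, 2)
index = np.array([1, 2, 0])              # indexes axis 1 only
update = (np.arange(12, dtype=np.float32).reshape(2, 3, 2) + 1) * 100

perm = [1, 0, 2]                         # swap the indexed axis with axis 0
xt = np.transpose(x, perm).copy()
xt[index] = np.transpose(update, perm)   # rank-1 scatter over the leading axis
result = np.transpose(xt, perm)          # the two-axis swap is its own inverse

expected = x.copy()
expected[:, index, :] = update           # the aten.index_put semantics
assert np.array_equal(result, expected)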