Improves implementation of aten_index_put #2641
base: main
Changes from all commits
3cb493f
c9472f3
434fcfb
5cf4882
ae6adca
8510364
e4d574a
86d482d
8305cac
02dda0e
e6f7633
e108dc3
988e9f6
f4a1196
f56ccc7
1e30097
f3731ed
b6ebed8
1c1fc1a
37a861f
d721b3d
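
For context, aten_index_put is the ONNX Script (torchlib) implementation of the ATen index_put op, the functional form behind torch.Tensor.index_put_, used by the PyTorch ONNX exporter. A minimal eager-mode reference of the behavior being converted, with shapes chosen purely for illustration:

import torch

x = torch.zeros(3, 4)
idx = (torch.tensor([0, 2]),)                # one index tensor per indexed dimension
values = torch.ones(2, 4)
x.index_put_(idx, values, accumulate=False)  # rows 0 and 2 become ones
# With accumulate=True the values are added to the existing entries instead.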
@@ -4233,6 +4233,30 @@ def aten_index_put(
    See implementation of `torch.onnx.symbolic_opset11.index_put
    <https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
    """
+    if (
+        len(indices) > 1
+        and any(
+            isinstance(index, torch.onnx._internal.exporter._tensors.SymbolicTensor)  # pylint: disable=protected-access
+            for index in indices
+        )
+        and len(values.shape) == 1
+    ):
+        return _aten_index_put_dynamic(self, indices, values, accumulate=accumulate)
+
+    not_none = [i for i, ind in enumerate(indices) if ind is not None]
+    if (
+        len(not_none) == 1
+        and len(indices[not_none[0]].shape) == 1
+        and len(self.shape) == len(values.shape)
+    ):
+        return _aten_index_put_scatter_nd(self, indices, values, accumulate)
+
+    if len(indices) == 1 and set(indices[0].shape[:-1]) == {1} and indices[0].shape[0] == 1:
+        # shape(self) = (5,5), shape(indices[0]) = (1,2), shape(values) = (2,5)
+        # This case was only found in ops_data test.
+        return _aten_index_put_scatter_nd(
+            self, [op.Reshape(indices[0], [-1])], values, accumulate
+        )
+
    def _make_reshape_list_broadcastable(reshape_list, values_shape):
        # Remove ones until the rank of reshape_list matches values_shape.
@@ -4245,7 +4269,13 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape):
        # the reshape list should be : [[2, 1], [1, 3], [2, 1]]
        for i, r in enumerate(reshape_list):
            if r not in (1, values_shape[i]):
-                value_index = values_shape.index(r)
+                try:
+                    value_index = values_shape.index(r)
+                except ValueError as e:
+                    raise RuntimeError(
+                        f"Unable to find element {r!r} in shape {values_shape}, "
+                        f"reshape_list={reshape_list}"
+                    ) from e
                # Swap elements
                # For the example above the current reshape list is [1, 2] for last dim,
                # to make it broadcastable, we swap the elements
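
A quick illustration of the broadcasting requirement this helper enforces (the shapes below are assumptions for illustration, not taken from the PR): with values of shape (2, 3), an index of length 2 reshaped to (2, 1) and an index of length 3 reshaped to (1, 3) broadcast together to exactly the values shape.

import numpy as np

idx_a = np.array([0, 4]).reshape(2, 1)     # per-axis reshape pattern [2, 1]
idx_b = np.array([1, 2, 3]).reshape(1, 3)  # per-axis reshape pattern [1, 3]
values = np.arange(6.0).reshape(2, 3)
assert np.broadcast_shapes(idx_a.shape, idx_b.shape) == values.shape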
@@ -4269,15 +4299,22 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape):
            reshape_update = self.shape[i]
        else:
            idx = indices[i]
-            reshape_update = math.prod(idx.shape)
-            # when Index is more than 1D, flatten it and also the values shape
-            # Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3)
-            # Indices -> (2*4,) and values shape (2*4, 32)
-            if len(idx.shape) > 1:
-                values_shape = (reshape_update, *values_shape[len(idx.shape) :])
-
-            # Flatten index (always working with 1D index in each dim)
-            idx = op.Reshape(idx, [-1])
+            if all(isinstance(s, int) for s in idx.shape):
+                reshape_update = math.prod(idx.shape)
+                # when Index is more than 1D, flatten it and also the values shape
+                # Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3)
+                # Indices -> (2*4,) and values shape (2*4, 32)
+                if len(idx.shape) > 1:
+                    values_shape = (reshape_update, *values_shape[len(idx.shape) :])
+
+                # Flatten index (always working with 1D index in each dim)
+                idx = op.Reshape(idx, [-1])
+            else:
+                raise RuntimeError(
+                    f"Unable to handle index {indices[i]} for axis={i} "
+                    f"because one of the dimension is not static as shape="
+                    f"{idx.shape}, indices={indices}"
+                )

        # Create a reshape pattern: one value per index dimension,
        # with the current dimension set to the update size.
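
The flattening described in the comment above can be checked in eager mode; a small numpy sketch with the same shapes, a (10, 3) input, a (2, 4) index and (2, 4, 3) values, assuming the index has no duplicate entries:

import numpy as np

x = np.zeros((10, 3))
idx = np.arange(8).reshape(2, 4)            # 2-D index into dim 0, no duplicates
values = np.random.rand(2, 4, 3)
a = x.copy(); a[idx] = values                              # multi-dimensional index
b = x.copy(); b[idx.reshape(-1)] = values.reshape(8, 3)    # flattened index and values
assert np.array_equal(a, b)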
@@ -4302,14 +4339,131 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape):
    # Flatten values to match the indices
    flat_values = op.Reshape(values, [-1])

-    if accumulate:
-        result = op.ScatterND(self, new_index, flat_values, reduction="add")
-    else:
-        result = op.ScatterND(self, new_index, flat_values)
+    scatter_kwargs = dict(reduction="add") if accumulate else {}
+    result = op.ScatterND(self, new_index, flat_values, **scatter_kwargs)
    return result


+def _aten_index_put_scatter_nd(
+    x: TReal,
+    indices: Sequence[INT64],
+    values: TReal,
+    accumulate: bool = False,
+) -> TReal:
+    def _1dint(i: int):
+        return op.Constant(value_ints=ir.AttrInt64s("value_ints", [i]))
+
+    not_none = [i for i, ind in enumerate(indices) if ind is not None]
+    assert len(not_none) == 1, f"Unable to handle that case: not_none={not_none}"
+    unsq = op.Unsqueeze(indices[not_none[0]], _1dint(1))
+    if not_none[0] == 0:
+        return op.ScatterND(x, unsq, values, reduction="add" if accumulate else "none")
+
+    perm = list(range(len(x.shape)))
+    perm[not_none[0]], perm[0] = perm[0], perm[not_none[0]]
+    return op.Transpose(
+        op.ScatterND(
+            op.Transpose(x, perm=perm),
+            unsq,
+            op.Transpose(values, perm=perm),
+            reduction="add" if accumulate else "none",
+        ),
+        perm=perm,
+    )
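
The Transpose/ScatterND/Transpose composition in _aten_index_put_scatter_nd writes whole slices along the single non-None indexed axis. A rough eager-mode sketch of the same idea in numpy, assuming values has the same rank as x and matches it except along the indexed axis, which has length len(idx); the helper name is ours, and accumulate corresponds to reduction="add":

import numpy as np

def index_put_along_axis(x, idx, values, axis, accumulate=False):
    # Illustrative helper, not part of the PR: swap the indexed axis to the
    # front, scatter slice-by-slice, then swap back.
    xt = np.swapaxes(x, 0, axis).copy()
    vt = np.swapaxes(values, 0, axis)
    if accumulate:
        np.add.at(xt, idx, vt)   # like ScatterND with reduction="add"
    else:
        xt[idx] = vt             # like ScatterND with the default reduction
    return np.swapaxes(xt, 0, axis)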
+
+
+def _aten_index_put_dynamic(
+    x: TReal,
+    indices: Sequence[INT64],
+    values: TReal,
+    accumulate: bool = False,
+) -> TReal:
+    def _1dint(i: int):
+        return op.Constant(value_ints=ir.AttrInt64s("value_ints", [i]))
+
+    def _0dint(i: int):
+        return op.Constant(value_int=ir.AttrInt64("value_int", i))
+
+    def _make_range_or_cast(ind, shape_x, static_shape: bool, dim: int):
+        if ind is not None:
+            return op.Cast(ind, to=INT64.dtype), False
+        return (
+            op.Cast(
+                op.Range(  # Range does not return a typed result
+                    _0dint(0),
+                    op.Squeeze(op.Shape(x, start=dim, end=dim + 1)),
+                    _0dint(1),
+                ),
+                to=INT64.dtype,
+            ),
+            True,
+        )
+
+    shape_x = op.Shape(x)
+    exped = []
+    fixed = []
+    reshape_value_shape2 = []
+    expand_value_shape = []
+    for i, ind in enumerate(indices):
+        if isinstance(ind, torch.onnx._internal.exporter._tensors.SymbolicTensor):  # pylint: disable=protected-access
+            ind.dtype = ir.DataType.INT64
Collaborator: Why do we need the above line? Just wondering ... shouldn't it already have dtype set?

Member (Author): Maybe it is not useful anymore, but when I did this PR, it was needed.
+        ind, expanded = _make_range_or_cast(ind, shape_x, False, i)
+        if expanded:
+            exped.append((i, ind))
+            expand_value_shape.append(op.Shape(x, start=i, end=i + 1))
+            reshape_value_shape2.append(_1dint(1))
+        else:
+            expand_value_shape.append(_1dint(1))
+            reshape_value_shape2.append(op.Shape(ind))
+            fixed.append((i, ind))
+
+    reshape_value_shape1 = [_1dint(1)] * len(indices)
+    if len(fixed) <= 1:
+        reshape_value_shape1 = None
+    elif fixed:
+        reshape_value_shape1[fixed[-1][0]] = _1dint(-1)
+
+    def _mkstride(x, i):
+        if i >= len(x.shape) - 1:
+            return _1dint(1)
+        if i == len(x.shape) - 2:
+            return op.Shape(x, start=i + 1)
+        return op.ReduceProd(op.Shape(x, start=i + 1), keepdims=1)
+
+    shape = [1] * (len(x.shape) + 1)
+    reshaped_fixed = []
+    if fixed:
+        new_shape = shape.copy()
+        new_shape[-1] = -1
+        reshaped_fixed = [op.Reshape(op.Mul(_mkstride(x, i), f), new_shape) for i, f in fixed]
+
+    reshaped_exped = []
+    for i, e in exped:
+        new_shape = shape.copy()
+        new_shape[i] = -1
+        reshaped_exped.append(op.Reshape(op.Mul(_mkstride(x, i), e), new_shape))
+
+    # final sum
+    unflat = None
+    for a in [*reshaped_fixed, *reshaped_exped]:
+        if unflat is None:
+            unflat = a
+            continue
+        unflat = op.Add(unflat, a)
+
+    # value_shape
+    expanded_values = values
+    if reshape_value_shape1 is not None:
+        expanded_values = op.Reshape(expanded_values, op.Concat(*reshape_value_shape1, axis=0))
+    expanded_values = op.Expand(expanded_values, op.Concat(*expand_value_shape, axis=0))
+    flat_ind = op.Reshape(unflat, _1dint(-1))
+    expanded_values = op.Reshape(expanded_values, _1dint(-1))
+    flat_x = op.Reshape(x, _1dint(-1))
+    scat_kwargs = {"reduction": "add"} if accumulate else {}
+    flat_up_x = op.ScatterElements(flat_x, flat_ind, expanded_values, **scat_kwargs)
+    return op.Reshape(flat_up_x, op.Shape(x))
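
The _aten_index_put_dynamic helper linearizes the update: each per-axis index (explicit, or an implicit Range over the full axis) is multiplied by that axis's stride, the pieces are broadcast-added into a single flat offset tensor, and one ScatterElements is applied to the flattened input before reshaping back. A small numpy sketch of that identity for a 2-D case (names and shapes are ours, for illustration only):

import numpy as np

x = np.zeros((3, 4))
idx0 = np.array([2, 0])                       # explicit index on axis 0
rng1 = np.arange(x.shape[1])                  # implicit full range on axis 1
stride0, stride1 = x.shape[1], 1              # row-major strides, in elements
flat_idx = (idx0[:, None] * stride0 + rng1[None, :] * stride1).reshape(-1)
values = np.ones((2, 4))
flat_x = x.reshape(-1).copy()
flat_x[flat_idx] = values.reshape(-1)         # ScatterElements, no reduction
result = flat_x.reshape(x.shape)              # rows 2 and 0 are now ones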
@torch_op("aten::index_put", trace_only=True)
def aten_index_put_bool(
    self: TReal,