
Commit 8974f5e

xadupre and justinchuby authored
Implements repeat_interleave (#2477)
Similar to #2464. Does not support all the cases but we can add them in other PRs.

Signed-off-by: xadupre <xadupre@microsoft.com>
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 07f3e4c commit 8974f5e
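For context, torch.repeat_interleave has two call forms, and the new functions in this commit target one overload of each: an integer repeat count applied along a dimension, and a per-element repeat tensor. A minimal illustration of the operator itself (assumed usage, not part of the commit):

    import torch

    x = torch.tensor([[0, 1, 2], [3, 4, 5]])

    # Integer repeats: every row is repeated twice along dim 0.
    torch.repeat_interleave(x, 2, dim=0)
    # tensor([[0, 1, 2], [0, 1, 2], [3, 4, 5], [3, 4, 5]])

    # Tensor repeats: row i is repeated repeats[i] times.
    torch.repeat_interleave(x, torch.tensor([1, 2]), dim=0)
    # tensor([[0, 1, 2], [3, 4, 5], [3, 4, 5]])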

File tree

onnxscript/function_libs/torch_lib/ops/core.py
tests/function_libs/torch_lib/e2e_ops_tests.py
tests/function_libs/torch_lib/ops_test_data.py

3 files changed: +201 -4 lines changed


onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 106 additions & 4 deletions
@@ -7292,12 +7292,114 @@ def aten_repeat(self: TTensor, repeats: Sequence[TInt]) -> TTensor:
     return op.Tile(self_expanded, repeats)
 
 
-def aten_repeat_interleave(
-    repeats: TensorType, output_size: Optional[int] = None
+@torch_op("aten::repeat_interleave.self_int", trace_only=True)
+def aten_repeat_interleave_self_int(
+    self: TensorType, repeats: int, dim: Optional[int] = None
 ) -> TensorType:
-    """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor"""
+    """repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor
 
-    raise NotImplementedError()
+    The trick is to repeat in one direction orthogonal to reshape.
+
+    .. code-block:: python
+
+        x = torch.tensor([[0, 1, 2], [3, 4, 5]])
+        x.repeat_interleave(2, dim=0)
+
+    is equivalent to:
+
+    .. code-block:: python
+
+        x = torch.tensor([[0, 1, 2], [3, 4, 5]])
+        x.repeat((1, 2)).reshape((-1, t.shape[1]))
+    """
+    if dim is None:
+        raise NotImplementedError("No conversion available yet when dim is None.")
+
+    self_rank = len(self.shape)
+    pos_dim = (dim + self_rank) % self_rank
+    unsqueezed = op.Unsqueeze(self, [pos_dim + 1])
+    tiles = [1] * (self_rank + 1)
+    tiles[pos_dim + 1] = repeats
+    tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype))
+    tiled = op.Tile(unsqueezed, tile_repeat)
+    if self_rank == 1:
+        return op.Identity(tiled)
+    final_shape = op.Concat(
+        op.Shape(self, start=0, end=dim),
+        op.Constant(value_ints=[-1]),
+        op.Shape(self, start=dim + 1),
+        axis=0,
+    )
+    return op.Reshape(tiled, final_shape)
+
+
+@torch_op("aten::repeat_interleave.Tensor", trace_only=True)
+def aten_repeat_interleave_Tensor(
+    self: TensorType, repeats: Optional[TensorType] = None, dim: Optional[int] = None
+) -> TensorType:
+    """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor
+
+    When `repeats` is a tensor, each line is multiplied
+    by a different number.
+    There are multiple strategies. Here is one.
+
+    .. code-block:: python
+
+        import torch
+
+        x = torch.tensor([[0, 1, 2], [3, 4, 5]])
+        times = torch.tensor([2, 3], dtype=torch.int64)
+        y = x.repeat_interleave(times, dim=0)
+        print("repeat_interleave")
+        print(y)
+
+        ci = times.cumsum(dim=0)
+        rows = torch.arange(ci[-1], dtype=torch.int64) < ci.reshape((-1, 1))
+        srows = times.shape[0] - rows.to(torch.int64).sum(axis=0)
+        indices = srows.reshape((-1, ))
+        print("decomposed")
+        print(x[indices, :])
+    """
+    if repeats is None:
+        repeats = self
+        self = op.Range(0, op.Squeeze(op.Shape(repeats, start=-1), [0]), 1)
+    if dim is None:
+        # flatten
+        self = op.Reshape(self, [-1])
+        rk = 1
+    else:
+        rk = len(self.shape)
+
+    if rk > 2:
+        shape_x0 = op.Shape(self, start=0, end=1)
+        shape_x = op.Shape(self, start=1)
+        self = op.Reshape(self, op.Concat(shape_x0, [-1], axis=0))
+    elif rk == 1:
+        shape_x = None
+        self = op.Reshape(self, [-1, 1])
+    else:
+        if rk != 2:
+            raise NotImplementedError(f"rank(self)={rk} not implemented for repeat_interleave")
+        shape_x = None
+
+    ci = op.CumSum(repeats, [0])
+    last_ci = op.Gather(ci, [-1])
+    trange = op.Range(0, op.Squeeze(last_ci, [0]), 1)
+    rows = op.Less(trange, op.Unsqueeze(ci, [-1]))
+    srows = op.Sub(
+        op.Shape(self, start=0, end=1),
+        op.ReduceSum(op.Cast(rows, to=INT64.dtype), [0]),
+    )
+    indices = op.Reshape(srows, [-1])
+    values = op.GatherND(self, op.Unsqueeze(indices, [-1]))
+    if rk == 2:
+        return values
+    # shape_x is None at this stage.
+    assert shape_x is None  # for mypy
+    return op.Reshape(
+        values,
+        op.Concat([-1], shape_x, axis=0) if shape_x else [-1],
+    )
 
 
 @torch_op("aten::reshape")

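The tensor-repeats conversion instead builds gather indices arithmetically: CumSum gives each row's end offset, Range enumerates the output slots, Less marks which cumulative blocks each slot falls under, and ReduceSum turns that mask into the source row index fed to GatherND. A sketch of the same index construction in plain PyTorch (helper name hypothetical; assumes a 2-D input with per-row repeats along dim 0):

    import torch

    def repeat_interleave_rows(x: torch.Tensor, repeats: torch.Tensor) -> torch.Tensor:
        # CumSum -> Range -> Less -> ReduceSum -> gather, as in aten_repeat_interleave_Tensor.
        ci = repeats.cumsum(dim=0)                        # end offset of each row's output block
        trange = torch.arange(ci[-1], dtype=torch.int64)  # one slot per output row
        mask = trange < ci.reshape((-1, 1))               # mask[i, j]: slot j lies within the first i+1 blocks
        indices = repeats.shape[0] - mask.to(torch.int64).sum(dim=0)  # source row of each output slot
        return x[indices, :]                              # row gather (GatherND in the ONNX graph)

    x = torch.tensor([[0, 1, 2], [3, 4, 5]])
    times = torch.tensor([2, 3], dtype=torch.int64)
    assert torch.equal(repeat_interleave_rows(x, times), torch.repeat_interleave(x, times, dim=0))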
tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 61 additions & 0 deletions
@@ -76,6 +76,67 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
         _testing.assert_onnx_program(onnx_program)
 
+    def test_repeat_interleave_integer_1(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return torch.repeat_interleave(x, 3, dim=1)
+
+        onnx_program = torch.onnx.export(
+            Model(), (torch.randn(2, 3),), dynamo=True, optimize=False
+        )
+        _testing.assert_onnx_program(onnx_program)
+
+    def test_repeat_interleave_integer_2(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return torch.repeat_interleave(x, 3, dim=1)
+
+        onnx_program = torch.onnx.export(
+            Model(), (torch.randn(2, 3, 4),), dynamo=True, optimize=False
+        )
+        _testing.assert_onnx_program(onnx_program)
+
+    def test_repeat_interleave_tensor(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, ind):
+                return torch.repeat_interleave(x, ind, dim=0)
+
+        onnx_program = torch.onnx.export(
+            Model(),
+            (
+                torch.arange(6, dtype=torch.float32).reshape((2, 3)),
+                torch.tensor([1, 2], dtype=torch.int64),
+            ),
+            dynamo=True,
+            optimize=False,
+        )
+        _testing.assert_onnx_program(onnx_program)
+
+    def test_repeat_interleave_tensor_none(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, ind):
+                return torch.repeat_interleave(x, ind)
+
+        inputs = (
+            torch.arange(4, dtype=torch.float32).reshape((2, 2)),
+            torch.tensor([1, 2, 3, 2], dtype=torch.int64),
+        )
+        onnx_program = torch.onnx.export(
+            Model(),
+            inputs,
+            dynamo=True,
+            optimize=False,
+        )
+        onnx_program = torch.onnx.export(
+            Model(),
+            inputs,
+            input_names=["x", "ind"],
+            output_names=["output"],
+            opset_version=18,
+            dynamo=True,
+        )
+        _testing.assert_onnx_program(onnx_program)
+
     def test_sdpa_with_bool_attn_mask(self):
         class ScaledDotProductAttention(torch.nn.Module):
             def forward(self, query, key, value, attn_mask):

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 34 additions & 0 deletions
@@ -1250,6 +1250,40 @@ def _where_input_wrangler(
         core_ops.aten_remainder,
     ),
     TorchLibOpInfo("repeat", core_ops.aten_repeat),
+    TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_self_int)
+    .skip(
+        matcher=lambda sample: not isinstance(sample.kwargs.get("repeats", None), int),
+        reason=("ignore cases when repeasts is a Tensor"),
+    )
+    .skip(
+        dtypes=(torch.bool,),
+        reason="bool not supported",
+    )
+    .skip(
+        matcher=lambda sample: sample.kwargs.get("dim") is None,
+        reason="fixme: conversion not implemented if dim is None",
+    )
+    .skip(
+        matcher=lambda sample: sample.input.numel() == 0,
+        reason="fixme: conversion not implemented when input tensor is empty",
+    ),
+    TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_Tensor)
+    .skip(
+        matcher=lambda sample: isinstance(sample.kwargs.get("repeats", None), int),
+        reason=("ignore cases when repeasts is an int"),
+    )
+    .skip(
+        dtypes=(torch.bool,),
+        reason="bool not supported",
+    )
+    .skip(
+        matcher=lambda sample: sample.kwargs.get("dim") is None,
+        reason="fixme: conversion not implemented if dim is None",
+    )
+    .skip(
+        matcher=lambda sample: sample.input.numel() == 0,
+        reason="fixme: conversion not implemented when input tensor is empty",
+    ),
     TorchLibOpInfo("reshape", core_ops.aten_reshape),
     TorchLibOpInfo("resolve_conj", core_ops.aten_resolve_conj),
     TorchLibOpInfo("resolve_neg", core_ops.aten_resolve_neg),
