Skip to content

Commit 366f7be

Browse files
authored
[torchlib] Fix repeat_interleave when repeats is a symbolic tensor (#2548)
1 parent 50d7e87 commit 366f7be

File tree

2 files changed

+44
-12
lines changed

2 files changed

+44
-12
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7332,16 +7332,25 @@ def aten_repeat_interleave_self_int(
73327332
self_rank = len(self.shape)
73337333
pos_dim = (dim + self_rank) % self_rank
73347334
unsqueezed = op.Unsqueeze(self, [pos_dim + 1])
7335-
tiles = [1] * (self_rank + 1)
7336-
tiles[pos_dim + 1] = repeats
7337-
tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype))
7338-
tiled = op.Tile(unsqueezed, tile_repeat)
7335+
if isinstance(repeats, int):
7336+
tiles = [1] * (self_rank + 1)
7337+
tiles[pos_dim + 1] = repeats
7338+
tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype))
7339+
else:
7340+
# repeats is a symbolic tensor
7341+
tile_repeat = op.Concat(
7342+
op.Constant(value=ir.tensor([1] * pos_dim, dtype=INT64.dtype)),
7343+
op.Reshape(repeats, op.Constant(value=ir.tensor([-1], dtype=INT64.dtype))),
7344+
op.Constant(value=ir.tensor([1] * (self_rank - pos_dim), dtype=INT64.dtype)),
7345+
axis=0,
7346+
)
7347+
tiled = op.Expand(unsqueezed, tile_repeat)
73397348
if self_rank == 1:
73407349
return op.Identity(tiled)
73417350
final_shape = op.Concat(
73427351
op.Shape(self, start=0, end=dim),
73437352
op.Constant(value_ints=[-1]),
7344-
op.Shape(self, start=dim + 1),
7353+
op.Shape(self, start=pos_dim + 1),
73457354
axis=0,
73467355
)
73477356
return op.Reshape(tiled, final_shape)
@@ -7380,20 +7389,22 @@ def aten_repeat_interleave_Tensor(
73807389
if dim is None:
73817390
# flatten
73827391
self = op.Reshape(self, [-1])
7383-
rk = 1
7392+
rank = 1
73847393
else:
7385-
rk = len(self.shape)
7394+
rank = len(self.shape)
73867395

7387-
if rk > 2:
7396+
if rank > 2:
73887397
shape_x0 = op.Shape(self, start=0, end=1)
73897398
shape_x = op.Shape(self, start=1)
73907399
self = op.Reshape(self, op.Concat(shape_x0, [-1], axis=0))
7391-
elif rk == 1:
7400+
elif rank == 1:
73927401
shape_x = None
73937402
self = op.Reshape(self, [-1, 1])
73947403
else:
7395-
if rk != 2:
7396-
raise NotImplementedError(f"rank(self)={rk} not implemented for repeat_interleave")
7404+
if rank != 2:
7405+
raise NotImplementedError(
7406+
f"rank(self)={rank} not implemented for repeat_interleave"
7407+
)
73977408
shape_x = None
73987409

73997410
ci = op.CumSum(repeats, [0])
@@ -7406,7 +7417,7 @@ def aten_repeat_interleave_Tensor(
74067417
)
74077418
indices = op.Reshape(srows, [-1])
74087419
values = op.GatherND(self, op.Unsqueeze(indices, [-1]))
7409-
if rk == 2:
7420+
if rank == 2:
74107421
return values
74117422
# shape_x is None at this stage.
74127423
assert shape_x is None # for mypy

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,27 @@ def forward(self, x, ind):
137137
)
138138
_testing.assert_onnx_program(onnx_program)
139139

140+
def test_repeat_interleave_symbolic_tensor(self):
    """Export a model whose repeat_interleave count is symbolic.

    Each input is repeated along dim 1 by the *other* input's dim-1
    size, so under dynamo export the ``repeats`` argument arrives as a
    symbolic tensor rather than a Python int.
    """

    class CrossRepeat(torch.nn.Module):
        def forward(self, x, y):
            # repeats comes from the other tensor's shape -> symbolic.
            left = torch.repeat_interleave(x, y.shape[1], dim=1)
            right = torch.repeat_interleave(y, x.shape[1], dim=1)
            return left * right

    example_inputs = (
        torch.arange(4, dtype=torch.float32).reshape((2, 2)),
        torch.arange(6, dtype=torch.float32).reshape((2, 3)),
    )
    onnx_program = torch.onnx.export(
        CrossRepeat(),
        example_inputs,
        input_names=["x", "y"],
        output_names=["output"],
        opset_version=18,
        dynamo=True,
    )
    _testing.assert_onnx_program(onnx_program)
160+
140161
def test_sdpa_with_bool_attn_mask(self):
141162
class ScaledDotProductAttention(torch.nn.Module):
142163
def forward(self, query, key, value, attn_mask):

0 commit comments

Comments (0)