
Commit 1dd9d04

xadupre authored, with co-authors github-code-quality[bot], gramalingam, and github-advanced-security[bot]

Add converter for unique_consecutive (#2694)

Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
Co-authored-by: G. Ramalingam <grama@microsoft.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
1 parent 10e541e commit 1dd9d04
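
For context, torch.unique_consecutive only collapses duplicates that are adjacent to each other, which is the behaviour the new converter reproduces with ONNX operators. A quick eager-PyTorch illustration (not part of the commit):

import torch

x = torch.tensor([0, 1, 2, 2, 3, 3, 0, 0])
values, inverse, counts = torch.unique_consecutive(
    x, return_inverse=True, return_counts=True
)
# values:  tensor([0, 1, 2, 3, 0])  -- the trailing zeros form a new run, so 0 appears twice
# inverse: tensor([0, 1, 2, 2, 3, 3, 4, 4])
# counts:  tensor([1, 1, 2, 2, 2])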

File tree

2 files changed: +90 −2 lines

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 45 additions & 2 deletions
@@ -51,6 +51,7 @@
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType
 
+_INT32_MAX = 2147483647
 _INT64_MAX = 9223372036854775807
 _INT64_MIN = -9223372036854775808
 _MATH_PI = math.pi
@@ -9183,15 +9184,57 @@ def aten_unfold_copy(self: TensorType, dimension: int, size: int, step: int) ->
     raise NotImplementedError()
 
 
+@torch_op("aten::unique_consecutive", trace_only=True)
 def aten_unique_consecutive(
-    self: TensorType,
+    x: TensorType,
     return_inverse: bool = False,
     return_counts: bool = False,
     dim: Optional[int] = None,
 ) -> tuple[TensorType, TensorType, TensorType]:
     """unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor)"""
+    assert x.dtype in {INT64.dtype, INT32.dtype}, (
+        "unique_consecutive not implemented for other type than int32, int64"
+    )
+    rank_x = len(x.shape)
 
-    raise NotImplementedError()
+    zero = op.Constant(value=ir.tensor([0], dtype=x.dtype))
+    zero64 = op.Constant(value=ir.tensor([0], dtype=INT64.dtype))
+    minus_one = op.Constant(value=ir.tensor([-1], dtype=INT64.dtype))
+
+    if dim is None:
+        if rank_x != 1:
+            x = op.Reshape(x, minus_one)
+    else:
+        assert rank_x == 1 and dim == 0, (
+            f"Not implemented for x={x!r} with rank={rank_x} and dim={dim}."
+        )
+
+    lag = op.Concat(
+        # Hopefully this will never be equal to the first value of the tensor x
+        # ideally we could do differently but with a higher cost
+        op.Constant(value=ir.tensor([_INT32_MAX], dtype=x.dtype)),
+        op.Slice(x, zero64, minus_one, zero64),
+        axis=0,
+    )
+    eq = op.Equal(x, lag)
+    diff = op.Not(eq)
+    res = op.Compress(x, diff, axis=0)
+
+    zero_no_dim = op.Constant(value=ir.tensor(0, dtype=x.dtype))
+    one_no_dim = op.Constant(value=ir.tensor(1, dtype=x.dtype))
+    one = op.Constant(value=ir.tensor([1], dtype=x.dtype))
+
+    inverse = op.Sub(op.CumSum(op.Cast(diff, to=x.dtype), zero), one)
+    shape_x = op.Shape(x)
+    indices = op.Range(zero_no_dim, op.Squeeze(shape_x), one_no_dim)
+    points = op.Compress(indices, diff, axis=0)
+    lagp = op.Concat(
+        op.Slice(points, one, op.Shape(points), zero),
+        shape_x,
+        axis=0,
+    )
+    counts = op.Sub(lagp, points)
+    return res, inverse, counts
 
 
 @torch_op("aten::_unique", trace_only=True)
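
The converter builds unique_consecutive out of ONNX primitives: Concat/Slice create a one-position lag of the input, Equal/Not mark where each new run starts, Compress keeps the first element of every run, CumSum over the run-start mask (minus one) yields the inverse indices, and the gap between consecutive run-start positions yields the counts. A minimal NumPy sketch of the same lag-and-compress scheme, for illustration only (not part of the commit; the sentinel mirrors the _INT32_MAX padding used above):

import numpy as np

def unique_consecutive_1d(x: np.ndarray):
    # Lag x by one position, padded with a sentinel that is unlikely to
    # equal x[0] (the converter pads with _INT32_MAX for the same reason).
    sentinel = np.array([np.iinfo(np.int32).max], dtype=x.dtype)
    lag = np.concatenate([sentinel, x[:-1]])
    starts_run = x != lag                                  # op.Not(op.Equal(x, lag))
    values = x[starts_run]                                 # op.Compress(x, diff)
    inverse = np.cumsum(starts_run.astype(np.int64)) - 1   # run index per element
    start_idx = np.nonzero(starts_run)[0]                  # op.Compress(indices, diff)
    end_idx = np.concatenate([start_idx[1:], [len(x)]])
    counts = end_idx - start_idx                           # op.Sub(lagp, points)
    return values, inverse, counts

# unique_consecutive_1d(np.array([0, 1, 2, 2, 3, 3, 0, 0]))
# -> values [0 1 2 3 0], inverse [0 1 2 2 3 3 4 4], counts [1 1 2 2 2]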

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 45 additions & 0 deletions
@@ -406,6 +406,51 @@ def forward(self, x):
         onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False)
         _testing.assert_onnx_program(onnx_program)
 
+    def test_aten_unique_consecutive(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return torch.unique_consecutive(x)
+
+        model = Model()
+        x = torch.tensor([0, 1, 2, 2, 3, 3, 0, 0], dtype=torch.int64)
+        onnx_program = torch.onnx.export(
+            model,
+            (x,),
+            dynamic_shapes=({0: "length"},),
+            dynamo=True,
+        )
+        _testing.assert_onnx_program(onnx_program)
+
+    def test_aten_unique_consecutive_int32(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return torch.unique_consecutive(x)
+
+        model = Model()
+        x = torch.tensor([0, 1, 2, 2, 3, 3, 0, 0], dtype=torch.int32)
+        onnx_program = torch.onnx.export(
+            model,
+            (x,),
+            dynamic_shapes=({0: "length"},),
+            dynamo=True,
+        )
+        _testing.assert_onnx_program(onnx_program)
+
+    def test_aten_unique_consecutive_return(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return torch.unique_consecutive(x, return_inverse=True, return_counts=True)
+
+        model = Model()
+        x = torch.tensor([0, 1, 2, 2, 3, 3, 3, 0, 0], dtype=torch.int64)
+        onnx_program = torch.onnx.export(
+            model,
+            (x,),
+            dynamic_shapes=({0: "length"},),
+            dynamo=True,
+        )
+        _testing.assert_onnx_program(onnx_program)
+
     def test_aten_stft_1(self):
         class Model(torch.nn.Module):
             def forward(self, x):
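
In these tests, _testing.assert_onnx_program compares the exported program against eager PyTorch. Outside the test harness, a rough standalone check could look like the sketch below; the onnxruntime session and the output file name are assumptions for illustration, not part of the commit:

import torch
import onnxruntime as ort

class Model(torch.nn.Module):
    def forward(self, x):
        return torch.unique_consecutive(x, return_inverse=True, return_counts=True)

x = torch.tensor([0, 1, 2, 2, 3, 3, 3, 0, 0], dtype=torch.int64)
expected = Model()(x)  # values, inverse, counts from eager PyTorch

onnx_program = torch.onnx.export(Model(), (x,), dynamo=True)
onnx_program.save("unique_consecutive.onnx")  # hypothetical file name

sess = ort.InferenceSession("unique_consecutive.onnx")
feed = {sess.get_inputs()[0].name: x.numpy()}
got = sess.run(None, feed)
for g, e in zip(got, expected):
    torch.testing.assert_close(torch.from_numpy(g), e)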
