@@ -51,6 +51,7 @@
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType
 
+_INT32_MAX = 2147483647
 _INT64_MAX = 9223372036854775807
 _INT64_MIN = -9223372036854775808
 _MATH_PI = math.pi
@@ -9183,15 +9184,57 @@ def aten_unfold_copy(self: TensorType, dimension: int, size: int, step: int) ->
     raise NotImplementedError()
 
 
+@torch_op("aten::unique_consecutive", trace_only=True)
 def aten_unique_consecutive(
-    self: TensorType,
+    x: TensorType,
     return_inverse: bool = False,
     return_counts: bool = False,
     dim: Optional[int] = None,
 ) -> tuple[TensorType, TensorType, TensorType]:
     """unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor)"""
+    assert x.dtype in {INT64.dtype, INT32.dtype}, (
+        "unique_consecutive is only implemented for int32 and int64 inputs"
+    )
+    rank_x = len(x.shape)
 
-    raise NotImplementedError()
+    zero = op.Constant(value=ir.tensor([0], dtype=x.dtype))
+    zero64 = op.Constant(value=ir.tensor([0], dtype=INT64.dtype))
+    minus_one = op.Constant(value=ir.tensor([-1], dtype=INT64.dtype))
+
+    if dim is None:
+        if rank_x != 1:
+            x = op.Reshape(x, minus_one)
+    else:
+        assert rank_x == 1 and dim == 0, (
+            f"Not implemented for x={x!r} with rank={rank_x} and dim={dim}."
+        )
+
+    lag = op.Concat(
+        # Sentinel that forces the first element to start a new group. It is
+        # assumed never to equal x[0]; a fully robust variant would cost more.
+        op.Constant(value=ir.tensor([_INT32_MAX], dtype=x.dtype)),
+        op.Slice(x, zero64, minus_one, zero64),
+        axis=0,
+    )
+    eq = op.Equal(x, lag)
+    diff = op.Not(eq)
+    res = op.Compress(x, diff, axis=0)
+
+    zero_no_dim = op.Constant(value=ir.tensor(0, dtype=x.dtype))
+    one_no_dim = op.Constant(value=ir.tensor(1, dtype=x.dtype))
+    one = op.Constant(value=ir.tensor([1], dtype=x.dtype))
+
+    inverse = op.Sub(op.CumSum(op.Cast(diff, to=x.dtype), zero), one)
+    shape_x = op.Shape(x)
+    indices = op.Range(zero_no_dim, op.Squeeze(shape_x), one_no_dim)
+    points = op.Compress(indices, diff, axis=0)
+    lagp = op.Concat(
+        op.Slice(points, one, op.Shape(points), zero),
+        shape_x,
+        axis=0,
+    )
+    counts = op.Sub(lagp, points)
+    return res, inverse, counts
 
 
 @torch_op("aten::_unique", trace_only=True)
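For reference, here is a minimal numpy sketch of the lag-comparison scheme the new graph implements (the function, variable names, and sample input are mine, for illustration only; it mirrors the Concat/Equal/Not/Compress/CumSum sequence above):

import numpy as np

def unique_consecutive_ref(x: np.ndarray):
    # Prepend a sentinel so the first element always starts a new group
    # (same assumption as the graph: the sentinel never equals x[0]).
    sentinel = np.array([np.iinfo(np.int32).max], dtype=x.dtype)
    lag = np.concatenate([sentinel, x[:-1]])          # x shifted right by one
    starts = x != lag                                 # True where a new group begins
    res = x[starts]                                   # one representative per group
    inverse = np.cumsum(starts.astype(np.int64)) - 1  # group id of every element
    points = np.flatnonzero(starts)                   # start index of each group
    counts = np.diff(np.append(points, x.size))       # group lengths
    return res, inverse, counts

print(unique_consecutive_ref(np.array([1, 1, 2, 2, 3, 1, 1, 2], dtype=np.int64)))
# (array([1, 2, 3, 1, 2]), array([0, 0, 1, 1, 2, 3, 3, 4]), array([2, 2, 1, 2, 1]))

The counts computation is the same trick as in the graph: lagp is points shifted left by one with the total length appended, so lagp - points is the distance between consecutive group starts.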
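The expected semantics can be cross-checked against PyTorch directly (illustrative snippet, not part of this change):

import torch

values, inverse, counts = torch.unique_consecutive(
    torch.tensor([1, 1, 2, 2, 3, 1, 1, 2]),
    return_inverse=True,
    return_counts=True,
)
# values  -> tensor([1, 2, 3, 1, 2])
# inverse -> tensor([0, 0, 1, 1, 2, 3, 3, 4])
# counts  -> tensor([2, 2, 1, 2, 1])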