Skip to content

Commit 7b04774

Browse files
authored
[torchlib] Modify aten_unbind to use None for split_sizes (#2536)
According to https://onnx.ai/onnx/operators/onnx__SplitToSequence.html#summary, `If the argument split is not specified, a default scalar value of 1 is used as the value of split`, and this is the only case when `keepdims` can be set to `0`. Fixes #2533
1 parent 8974f5e commit 7b04774

File tree

1 file changed

+2
-3
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+2
-3
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8718,12 +8718,11 @@ def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2:
87188718
return op.CastLike(self, other)
87198719

87208720

8721-
@torch_op("aten::unbind.int")
8721+
@torch_op("aten::unbind.int", trace_only=True)
87228722
def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]:
87238723
"""unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]"""
87248724

8725-
split_sizes = op.Constant(value_int=1)
8726-
return op.SplitToSequence(self, split_sizes, axis=dim, keepdims=False)
8725+
return op.SplitToSequence(self, axis=dim, keepdims=False)
87278726

87288727

87298728
@torch_op("aten::unflatten.int", trace_only=True)

0 commit comments

Comments
 (0)