Skip to content

Commit 3156bed

Browse files
authored
[torchlib] Fix aten_gather to correctly handle scalar indices (#2566)
Fixes #2564 Signed-off-by: Linsho Kaku <linsho@preferred.jp>
1 parent f529292 commit 3156bed

File tree

1 file changed

+6
-2
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+6
-2
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3814,11 +3814,15 @@ def aten_gather(
38143814
else:
38153815
return op.Expand(self, op.Shape(index))
38163816

3817-
if len(index.shape) == 0:
3818-
return op.Identity(self)
3817+
is_scalar_index = len(index.shape) == 0
3818+
if is_scalar_index:
3819+
index = op.Unsqueeze(index, [0])
38193820

38203821
index = op.Cast(index, to=INT64.dtype)
38213822
result = op.GatherElements(self, index, axis=dim)
3823+
3824+
if is_scalar_index:
3825+
result = op.Squeeze(result, [0])
38223826
return result
38233827

38243828

0 commit comments

Comments
 (0)