Skip to content

Commit 45b5189

Browse files
authored
[torchlib] Fix concat when input tensor has shape (0,) (#2661)
Filter out size-0 tensors. When there is only one input, create an identity op instead of a concat op. Fix #2660 Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 69025f7 commit 45b5189

File tree

2 files changed

+39
-3
lines changed

2 files changed

+39
-3
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,7 +1521,7 @@ def aten_cartesian_prod(tensors: Sequence[TensorType]) -> TensorType:
15211521
raise NotImplementedError()
15221522

15231523

1524-
@torch_op("aten::cat", trace_only=True, complex=True)
1524+
@torch_op(("aten::cat", "aten::concat", "aten::concatenate"), trace_only=True, complex=True)
15251525
def aten_cat_complex(tensors: Sequence[TTensor], dim: int = 0) -> TTensor:
15261526
"""cat(Tensor[] tensors, int dim=0) -> Tensor"""
15271527
# Real representation unsqueezes the last dimension
@@ -1534,8 +1534,18 @@ def aten_cat_complex(tensors: Sequence[TTensor], dim: int = 0) -> TTensor:
15341534
def aten_cat(tensors: Sequence[TTensor], dim: int = 0) -> TTensor:
15351535
"""cat(Tensor[] tensors, int dim=0) -> Tensor"""
15361536

1537-
# Remove None tensors
1538-
tensors = [tensor for tensor in tensors if tensor is not None]
1537+
filtered_tensors = []
1538+
for tensor in tensors:
1539+
# Remove None tensors
1540+
if tensor is None:
1541+
continue
1542+
# Remove empty tensors
1543+
if tensor.shape == (0,):
1544+
continue
1545+
filtered_tensors.append(tensor)
1546+
assert filtered_tensors, "aten::cat received all None or empty tensors"
1547+
if len(filtered_tensors) == 1:
1548+
return op.Identity(filtered_tensors[0])
15391549
return op.Concat(*tensors, axis=dim)
15401550

15411551

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,32 @@ def forward(self, x2d, x3d, x4d, x5d):
276276
)
277277
_testing.assert_onnx_program(onnx_program)
278278

279+
def test_concat_with_empty_tensor(self):
280+
class Model(torch.nn.Module):
281+
def forward(self, x):
282+
return torch.cat([x, torch.tensor([]), x], dim=0)
283+
284+
onnx_program = torch.onnx.export(
285+
Model(),
286+
(torch.tensor([1, 2]),),
287+
dynamo=True,
288+
verbose=False,
289+
)
290+
_testing.assert_onnx_program(onnx_program)
291+
292+
def test_concat_with_empty_tensor_single_element(self):
293+
class Model(torch.nn.Module):
294+
def forward(self, x):
295+
return torch.cat([x, torch.tensor([])], dim=1)
296+
297+
onnx_program = torch.onnx.export(
298+
Model(),
299+
(torch.tensor([[1, 2]]),),
300+
dynamo=True,
301+
verbose=False,
302+
)
303+
_testing.assert_onnx_program(onnx_program)
304+
279305

280306
if __name__ == "__main__":
281307
unittest.main()

0 commit comments

Comments
 (0)