@@ -53,7 +53,5 @@

# Assume there are 2 processes (2 devices)
with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
-    print("before \n")
    result = pipe(prompt).images[0]
-    print("after ")
    result.save(f"result_{distributed_state.process_index}.png")
44 changes: 40 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2,7 +2,7 @@

import logging
import operator
-from typing import Callable, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
@@ -217,21 +217,57 @@ def aten_ops_native_group_norm(
    )


-@dynamo_tensorrt_converter(torch.ops.aten.cat.default, supports_dynamic_shapes=True)
+def parse_cat_args(
+    args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+) -> Tuple[List[Any], int]:
+    """
+    Process inputs for torch.ops.aten.cat.default.
+
+    Handles both valid call patterns:
+    1. args = ((t1, t2, ...), dim)
+    2. args = ((t1, t2, ...),), with an optional dim passed via kwargs
+
+    Returns:
+        (input_tensors, dim)
+        input_tensors: list of the tensor arguments
+        dim: integer concatenation dimension (default 0)
+    """
+
+    if len(args) > 1 and isinstance(args[0], (list, tuple)):
+        input_tensors = list(args[0])
+        dim = args_bounds_check(args, 1, 0)
+
+    else:
+        # If the single arg is itself a tuple/list, unwrap it
+        if len(args) == 1 and isinstance(args[0], (list, tuple)):
+            input_tensors = list(args[0])
+        else:
+            input_tensors = list(args)
+
+        dim = kwargs.get("dim", 0)
+
+    return input_tensors, dim
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.cat.default,
+    supports_dynamic_shapes=True,
+)
def aten_ops_cat(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    inputs, dim = parse_cat_args(args, kwargs)
    return impl.cat.cat(
        ctx,
        target,
        SourceIR.ATEN,
        name,
-        input=args[0],
-        dim=args_bounds_check(args, 1, 0),
+        input=inputs,
+        dim=dim,
    )


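For reference, a minimal sketch (not part of the diff) of how the new parse_cat_args helper normalizes the two call patterns described in its docstring; the string stand-ins for traced tensor arguments are hypothetical and only illustrate the returned (input_tensors, dim) pairs:

# Hypothetical stand-ins for traced tensor arguments, for illustration only.
t1, t2 = "tensor_a", "tensor_b"

# Pattern 1: dim passed positionally after the tensor tuple.
assert parse_cat_args(((t1, t2), 1), {}) == ([t1, t2], 1)

# Pattern 2: dim passed via kwargs, or omitted entirely (defaults to 0).
assert parse_cat_args(((t1, t2),), {"dim": -2}) == ([t1, t2], -2)
assert parse_cat_args(((t1, t2),), {}) == ([t1, t2], 0)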
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/cat.py
@@ -1,3 +1,4 @@
+import logging
from typing import Optional, Sequence, Union

import numpy as np
@@ -15,6 +16,8 @@
    set_layer_name,
)

+logger = logging.getLogger(__name__)


def cat(
    ctx: ConversionContext,
35 changes: 35 additions & 0 deletions tests/py/dynamo/conversion/test_cat_aten.py
@@ -25,6 +25,41 @@ def forward(self, x, y, z):
            inputs,
        )

+    @parameterized.expand(
+        [
+            ("pos", 1),
+            ("neg", -2),
+        ]
+    )
+    def test_cat_dim_in_kwargs(self, _, dim):
+        class Cat(nn.Module):
+            def forward(self, x, y, z):
+                return torch.ops.aten.cat.default((x, y, z), dim=dim)
+
+        inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)]
+        self.run_test(
+            Cat(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ("pos", 0),
+            ("neg", -3),
+        ]
+    )
+    def test_cat_with_scalar_inputs(self, _, dim):
+        # Ensure constant, scalar-like tensor inputs are handled
+        class Cat(nn.Module):
+            def forward(self, x, y):
+                # y is a constant-filled tensor standing in for a scalar input
+                return torch.ops.aten.cat.default((x, y), dim)
+
+        x = torch.randn(1, 2, 3, device="cuda")
+        y = torch.ones_like(x) * 5.0  # simulate scalar broadcast
+        inputs = [x, y]
+        self.run_test(Cat(), inputs)
+
    @parameterized.expand(
        [
            ("pos", 1),