5 changes: 5 additions & 0 deletions core/runtime/TRTEngine.cpp
@@ -249,6 +249,11 @@ TRTEngine::~TRTEngine() {
trt_engine_profiler.reset();
exec_ctx.reset();
cuda_engine.reset();
for (void* ptr : empty_input_ptrs) {
if (ptr)
cudaFree(ptr);
}
empty_input_ptrs.clear();
rt.reset();
}

3 changes: 3 additions & 0 deletions core/runtime/TRTEngine.h
@@ -177,6 +177,9 @@ struct TRTEngine : torch::CustomClassHolder {
bool use_pre_allocated_outputs = false;
std::vector<at::Tensor> pre_allocated_outputs;

// Dummy device pointers allocated for zero-element inputs; freed in the destructor
std::vector<void*> empty_input_ptrs = {};

// Output Allocator-Related Functionality
bool requires_output_allocator = false; // engine requires output allocator
bool use_output_allocator_outputs = false; // users specify to use output allocator
21 changes: 15 additions & 6 deletions core/runtime/execute_engine.cpp
@@ -129,6 +129,7 @@ void setup_input_tensors(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()),
"Error while setting the tensor address for shape inputs");

void* tensor_addr = nullptr;
if (cudagraphs_enabled) {
// @peri044 I don't know if this makes sense since they are supposed to be GPU buffers
compiled_engine->input_buffers[i] = input_cpu;
@@ -152,15 +153,23 @@
if (cudagraphs_enabled) {
// If using CUDAGraphs copy formatted input to the corresponding persistent input buffer
compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true);
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), compiled_engine->input_buffers[i].data_ptr()),
"Error while setting the input tensor address for inputs");
tensor_addr = compiled_engine->input_buffers[i].data_ptr();
} else {
// Otherwise use the formatted buffer directly
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), formatted_inputs.back().data_ptr()),
"Error while setting the input tensor address for inputs");
tensor_addr = formatted_inputs.back().data_ptr();
}
// Handle empty tensors: TensorRT requires a non-null address even if numel() == 0
size_t nbytes = formatted_inputs.back().numel() * formatted_inputs.back().element_size();
if (nbytes == 0 || tensor_addr == nullptr) {
void* dummy = nullptr;
cudaMalloc(&dummy, 1); // allocate 1 byte GPU buffer to satisfy TRT and get a non-null address
tensor_addr = dummy;
compiled_engine->empty_input_ptrs.push_back(dummy); // track to free later
}

TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), tensor_addr),
"Failed to bind tensor address for " << name);
}
}
}
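For context, the zero-byte case the C++ change guards against can be reproduced in eager PyTorch (a minimal illustration, assuming a CUDA device is available; the exact pointer value is allocator-dependent):

```python
import torch

# A zero-element tensor owns no storage: there are 0 bytes to bind and its
# data pointer may be null, which is the situation the 1-byte cudaMalloc
# dummy above works around before calling setTensorAddress.
x = torch.empty((0, 4), dtype=torch.float32, device="cuda")
print(x.numel())                      # 0
print(x.numel() * x.element_size())   # 0 bytes -> nothing meaningful to bind
print(x.data_ptr())                   # may be 0 (null) for zero-element tensors
```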
51 changes: 51 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -217,6 +217,57 @@ def aten_ops_native_group_norm(
)


def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool:
"""
Validator for torch.cat operation with empty tensor handling.

PyTorch allows torch.tensor([]) (shape (0,)) to be concatenated with higher-dimensional
tensors, but TensorRT requires all inputs to have the same rank. This validator catches
this specific edge case.

Example valid case: cat([(3, 4), (0, 4)], dim=0) - same rank, properly shaped empty tensor for TRT
Example invalid case: cat([(3, 4), (0,)], dim=0) - torch.tensor([]) with rank mismatch
"""
inputs = node.args[0]

if len(inputs) < 2:
return True

# Collect metadata for all inputs
input_metas = []
for inp in inputs:
if isinstance(inp, TRTTensor):
# TRTTensor has shape directly
input_metas.append(inp.shape)
else:
# For nodes, get metadata
meta = getattr(inp, "meta", {}).get("tensor_meta")
if meta is None:
# Can't validate without metadata, allow it
return True
shape = tuple(meta.shape)
input_metas.append(shape)

# Check for the specific problematic case:
# 1D empty tensor (0,) being concatenated with higher-dimensional tensors
ranks = [len(shape) for shape in input_metas]
# If all ranks are the same, it's fine (PyTorch and TensorRT both handle this)
if len(set(ranks)) == 1:
return True
# If ranks differ, check if we have a 1D empty tensor (0,) in the mix
# This is the torch.tensor([]) case that PyTorch allows but TensorRT doesn't
for i, shape in enumerate(input_metas):
if len(shape) == 1 and shape[0] == 0:
# Found a 1D empty tensor with rank mismatch
_LOGGER.debug(
f"Concatenation falls back to PyTorch: found a 1D empty tensor (torch.tensor([])) at position {i} "
f"with a rank mismatch. PyTorch allows this, but TensorRT requires all concat inputs to have the same rank. "
f"Use torch.empty((0, ...)) with explicit dimensions matching the other inputs instead."
)
return False
return True


@dynamo_tensorrt_converter(torch.ops.aten.cat.default, supports_dynamic_shapes=True)
def aten_ops_cat(
ctx: ConversionContext,
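For reference, the eager-mode behavior the validator targets, mirroring the shapes in its docstring (a small sketch; only CPU PyTorch is needed):

```python
import torch

a = torch.randn(3, 4)

# Legacy special case: a 1-D zero-element tensor (shape (0,)) is accepted by
# torch.cat despite the rank mismatch -- TensorRT rejects this, so the
# validator falls back to PyTorch.
print(torch.cat([a, torch.tensor([])], dim=0).shape)     # torch.Size([3, 4])

# Rank-consistent alternative accepted by both PyTorch and TensorRT.
print(torch.cat([a, torch.empty((0, 4))], dim=0).shape)  # torch.Size([3, 4])
```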
14 changes: 13 additions & 1 deletion py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -392,14 +392,26 @@ def setup_input_tensors(
self.context.set_input_shape(
input_name, tuple(contiguous_inputs[i].shape)
)
tensor_to_bind = contiguous_inputs[i]
if tensor_to_bind.numel() == 0:
# Zero-element input: bind a 1-element dummy buffer so TRT receives a valid (non-null) address
dummy = torch.empty(
1,
dtype=tensor_to_bind.dtype,
device=torch.cuda.current_device(),
)
tensor_to_bind = dummy
if not hasattr(self, "_empty_input_buffers"):
self._empty_input_buffers = []
self._empty_input_buffers.append(dummy)
if cudagraphs_enabled:
self._input_buffers[i].copy_(contiguous_inputs[i])
self.context.set_tensor_address(
input_name, self._input_buffers[i].data_ptr()
)
else:
self.context.set_tensor_address(
input_name, contiguous_inputs[i].data_ptr()
input_name, tensor_to_bind.data_ptr()
)

def create_output_tensors(self) -> List[torch.Tensor]:
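A brief note on the Python-runtime change: the dummy tensor is appended to `self._empty_input_buffers` so that its CUDA allocation outlives the address handed to TensorRT. A minimal sketch of that design choice (hypothetical helper names, not part of the PR; assumes a CUDA device):

```python
import torch

def bind_dummy_unsafe() -> int:
    dummy = torch.empty(1, dtype=torch.float32, device="cuda")
    return dummy.data_ptr()  # allocation may be freed/reused once `dummy` goes out of scope

def bind_dummy_safe(keep_alive: list) -> int:
    dummy = torch.empty(1, dtype=torch.float32, device="cuda")
    keep_alive.append(dummy)  # mirrors self._empty_input_buffers.append(dummy)
    return dummy.data_ptr()   # stays valid while `keep_alive` references the tensor

buffers: list = []
addr = bind_dummy_safe(buffers)
```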
186 changes: 186 additions & 0 deletions tests/py/dynamo/runtime/test_empty_input.py
@@ -0,0 +1,186 @@
import pytest
import torch
import torch.nn as nn
import torch_tensorrt as torchtrt
from parameterized import parameterized
from torch.testing._internal.common_utils import TestCase, run_tests

DECIMALS_OF_AGREEMENT = 5 # for output comparison


# The runtime binds a non-null dummy address to TRT for the empty input
class ConcatEmptyModel(nn.Module):
def __init__(self, dim=0):
super().__init__()
self.dim = dim

def forward(self, x, y):
return torch.cat([x, y], dim=self.dim)


# TRT handles this case natively (same-rank empty constant)
class ConcatEmptyModelEmptyConstant(nn.Module):
def __init__(self, dim=0):
super().__init__()
self.dim = dim

def forward(self, x):
y = torch.empty((0, 4), dtype=torch.float).cuda()
return torch.cat([x, y], dim=self.dim)


# Exercises the cat validator: the rank-mismatched empty tensor falls back to PyTorch
class ConcatEmptyModelEmptyConstantMisMatchDim(nn.Module):
def __init__(self, dim=0):
super().__init__()
self.dim = dim

def forward(self, x):
y = torch.tensor([], device="cuda")
return torch.cat([x, y], dim=self.dim)


class TestConcatEmptyTensor(TestCase):

@parameterized.expand(
[
(
"python_runtime_model_one_empty_0",
True,
ConcatEmptyModel,
"two_inputs",
(0,),
),
(
"cpp_runtime_model_one_empty_0",
False,
ConcatEmptyModel,
"two_inputs",
(0,),
),
(
"python_runtime_model_one_empty_0_4",
True,
ConcatEmptyModel,
"two_inputs",
(0, 4),
),
(
"cpp_runtime_model_one_empty_0_4",
False,
ConcatEmptyModel,
"two_inputs",
(0, 4),
),
(
"python_runtime_model_two_empty_0_4",
True,
ConcatEmptyModelEmptyConstant,
"one_input",
(0, 4),
),
(
"cpp_runtime_model_two_empty_0_4",
False,
ConcatEmptyModelEmptyConstant,
"one_input",
(0, 4),
),
(
"python_runtime_model_three_empty_0",
True,
ConcatEmptyModelEmptyConstantMisMatchDim,
"one_input",
(0,),
),
(
"cpp_runtime_model_three_empty_0",
False,
ConcatEmptyModelEmptyConstantMisMatchDim,
"one_input",
(0,),
),
]
)
def test_concat_empty_with_nonempty(
self, _, use_python_runtime, model_class, input_type, empty_shape
):
"""
Test concatenation of empty tensor with non-empty tensor
along a specific dimension using Torch-TensorRT compiled model.
"""
# Create model
model = model_class(dim=0).eval().cuda()

# Inputs: prepare based on model requirements
empty_input = torch.empty(empty_shape, dtype=torch.float).cuda()
non_empty_input = torch.randn((3, 4), dtype=torch.float).cuda()

if input_type == "two_inputs":
inputs = [empty_input, non_empty_input]
else: # one_input
inputs = [non_empty_input]

# Compile with Torch-TensorRT
compiled_model = torchtrt.compile(
model,
"dynamo",
inputs,
min_block_size=5,
use_python_runtime=use_python_runtime,
)

# Run reference model
ref_out = model(*inputs)
# Run compiled model
trt_out = compiled_model(*inputs)

# Assertions
self.assertEqual(ref_out.shape, trt_out.shape)
self.assertAlmostEqual(
float(torch.max(torch.abs(ref_out - trt_out))),
0,
DECIMALS_OF_AGREEMENT,
msg="Concat with empty tensor output mismatch",
)

@parameterized.expand(
[
("python_runtime_empty_0", True, (0,)),
("cpp_runtime_empty_0", False, (0,)),
("python_runtime_empty_0_4", True, (0, 4)),
("cpp_runtime_empty_0_4", False, (0, 4)),
]
)
def test_concat_nonempty_with_empty(self, _, use_python_runtime, empty_shape):
"""
Concatenate non-empty tensor with empty tensor (opposite order)
"""
model = ConcatEmptyModel(dim=0).eval().cuda()

non_empty_input = torch.randn((3, 4), dtype=torch.float).cuda()
empty_input = torch.empty(empty_shape, dtype=torch.float).cuda()
inputs = [non_empty_input, empty_input]

compiled_model = torchtrt.compile(
model,
"dynamo",
inputs,
min_block_size=5,
use_python_runtime=use_python_runtime,
)

ref_out = model(*inputs)
trt_out = compiled_model(*inputs)

self.assertEqual(ref_out.shape, trt_out.shape)
self.assertAlmostEqual(
float(torch.max(torch.abs(ref_out - trt_out))),
0,
DECIMALS_OF_AGREEMENT,
msg="Concat with empty tensor (opposite order) output mismatch",
)


if __name__ == "__main__":
run_tests()