From 1a12a5e36151d53e8acb4208dd4720073e72a91e Mon Sep 17 00:00:00 2001
From: apbose
Date: Thu, 6 Nov 2025 13:50:18 -0800
Subject: [PATCH] Empty tensor handling

---
 core/runtime/TRTEngine.cpp                    |   6 +
 core/runtime/TRTEngine.h                      |   4 +
 core/runtime/execute_engine.cpp               |  30 ++-
 .../dynamo/conversion/aten_ops_converters.py  |  61 +++++-
 .../runtime/_PythonTorchTensorRTModule.py     |  17 +-
 tests/py/dynamo/runtime/test_empty_input.py   | 186 ++++++++++++++++++
 6 files changed, 294 insertions(+), 10 deletions(-)
 create mode 100644 tests/py/dynamo/runtime/test_empty_input.py

diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp
--- a/core/runtime/TRTEngine.cpp
+++ b/core/runtime/TRTEngine.cpp
@@ -249,6 +249,12 @@ TRTEngine::~TRTEngine() {
   trt_engine_profiler.reset();
   exec_ctx.reset();
   cuda_engine.reset();
+  // Release the placeholder device buffers that were bound for empty inputs
+  for (void* ptr : empty_input_ptrs) {
+    if (ptr)
+      cudaFree(ptr);
+  }
+  empty_input_ptrs.clear();
   rt.reset();
 }
 
diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h
--- a/core/runtime/TRTEngine.h
+++ b/core/runtime/TRTEngine.h
@@ -177,6 +177,10 @@ struct TRTEngine : torch::CustomClassHolder {
   bool use_pre_allocated_outputs = false;
   std::vector<at::Tensor> pre_allocated_outputs;
 
+  // Placeholder device buffers bound to TensorRT in place of empty (numel() == 0) inputs.
+  // Allocated on demand in setup_input_tensors and released in ~TRTEngine.
+  std::vector<void*> empty_input_ptrs = {};
+
   // Output Allocator-Related Functionality
   bool requires_output_allocator = false; // engine requires output allocator
   bool use_output_allocator_outputs = false; // users specify to use output allocator
diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp
--- a/core/runtime/execute_engine.cpp
+++ b/core/runtime/execute_engine.cpp
@@ -152,15 +152,29 @@ void setup_input_tensors(
       if (cudagraphs_enabled) {
         // If using CUDAGraphs copy formatted input to the corresponding persistent input buffer
         compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true);
-        TORCHTRT_CHECK(
-            compiled_engine->exec_ctx->setTensorAddress(name.c_str(), compiled_engine->input_buffers[i].data_ptr()),
-            "Error while setting the input tensor address for inputs");
-      } else {
-        // Otherwise use the formatted buffer directly
-        TORCHTRT_CHECK(
-            compiled_engine->exec_ctx->setTensorAddress(name.c_str(), formatted_inputs.back().data_ptr()),
-            "Error while setting the input tensor address for inputs");
       }
+      // With CUDAGraphs bind the persistent input buffer, otherwise use the formatted buffer directly
+      void* tensor_addr = cudagraphs_enabled ? compiled_engine->input_buffers[i].data_ptr()
+                                             : formatted_inputs.back().data_ptr();
+
+      // TensorRT requires a non-null address even for inputs with numel() == 0
+      if (formatted_inputs.back().numel() == 0 || tensor_addr == nullptr) {
+        // Allocate a single 1-byte placeholder per engine and reuse it on every call: a fresh
+        // cudaMalloc per execution would accumulate until the engine is destroyed and would
+        // change the bound address between calls.
+        if (compiled_engine->empty_input_ptrs.empty()) {
+          void* placeholder = nullptr;
+          TORCHTRT_CHECK(
+              cudaMalloc(&placeholder, 1) == cudaSuccess && placeholder != nullptr,
+              "Failed to allocate placeholder buffer for empty input " << name);
+          compiled_engine->empty_input_ptrs.push_back(placeholder);
+        }
+        tensor_addr = compiled_engine->empty_input_ptrs.front();
+      }
+
+      TORCHTRT_CHECK(
+          compiled_engine->exec_ctx->setTensorAddress(name.c_str(), tensor_addr),
+          "Failed to bind tensor address for " << name);
     }
   }
 }
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -217,6 +217,65 @@ def aten_ops_native_group_norm(
     )
 
 
+def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool:
+    """
+    Validator for torch.cat operation with empty tensor handling.
+
+    PyTorch allows torch.tensor([]) (shape (0,)) to be concatenated with higher-dimensional
+    tensors, but TensorRT requires all inputs to have the same rank. This validator catches
+    this specific edge case.
+
+    Example valid case: cat([(3, 4), (0, 4)], dim=0) - same rank, properly shaped empty tensor for TRT
+    Example invalid case: cat([(3, 4), (0,)], dim=0) - torch.tensor([]) with rank mismatch
+
+    Args:
+        node: The cat node to validate; node.args[0] holds the tensors being concatenated.
+        settings: Compilation settings (not used by this validator).
+
+    Returns:
+        False when a 1D empty tensor is mixed with inputs of a different rank, True otherwise
+        (including when an input has no shape metadata to check).
+    """
+    inputs = node.args[0]
+
+    if len(inputs) < 2:
+        return True
+
+    # Collect the shape of every input
+    input_shapes = []
+    for inp in inputs:
+        if isinstance(inp, TRTTensor):
+            # TRTTensor has shape directly
+            input_shapes.append(inp.shape)
+            continue
+        # For nodes, the shape comes from the tensor metadata
+        meta = getattr(inp, "meta", {}).get("tensor_meta")
+        if meta is None:
+            # Can't validate without metadata, allow it
+            return True
+        input_shapes.append(tuple(meta.shape))
+
+    # If all ranks are the same, it's fine (PyTorch and TensorRT both handle this)
+    if len({len(shape) for shape in input_shapes}) == 1:
+        return True
+
+    # Ranks differ: look for the torch.tensor([]) case that PyTorch allows but TensorRT doesn't
+    for i, shape in enumerate(input_shapes):
+        if len(shape) == 1 and shape[0] == 0:
+            _LOGGER.debug(
+                f"Concatenation rejected by TRT, torch.tensor([]) or 1D empty tensor at position {i}. "
+                "PyTorch allows this but TensorRT requires all inputs to have the same rank. "
+                "Use torch.empty((0, ...)) with explicit dimensions matching other inputs instead. "
+                "Falling back to PyTorch"
+            )
+            return False
+    return True
+
+
-@dynamo_tensorrt_converter(torch.ops.aten.cat.default, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.cat.default,
+    capability_validator=cat_validator,
+    supports_dynamic_shapes=True,
+)
 def aten_ops_cat(
     ctx: ConversionContext,
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -392,6 +392,21 @@ def setup_input_tensors(
                 self.context.set_input_shape(
                     input_name, tuple(contiguous_inputs[i].shape)
                 )
-                if cudagraphs_enabled:
+                if contiguous_inputs[i].numel() == 0:
+                    # TensorRT needs a valid (non-null) address even for inputs with no
+                    # elements, so bind a 1-element placeholder instead. One placeholder is
+                    # kept per (dtype, device) and reused, so repeated calls neither grow
+                    # memory nor change the bound address.
+                    if not hasattr(self, "_empty_input_buffers"):
+                        self._empty_input_buffers = {}
+                    key = (contiguous_inputs[i].dtype, contiguous_inputs[i].device)
+                    if key not in self._empty_input_buffers:
+                        self._empty_input_buffers[key] = torch.empty(
+                            1, dtype=key[0], device=key[1]
+                        )
+                    self.context.set_tensor_address(
+                        input_name, self._empty_input_buffers[key].data_ptr()
+                    )
+                elif cudagraphs_enabled:
                     self._input_buffers[i].copy_(contiguous_inputs[i])
                     self.context.set_tensor_address(
diff --git a/tests/py/dynamo/runtime/test_empty_input.py b/tests/py/dynamo/runtime/test_empty_input.py
new file mode 100644
index 0000000000..49a6333cec
--- /dev/null
+++ b/tests/py/dynamo/runtime/test_empty_input.py
@@ -0,0 +1,186 @@
+import pytest
+import torch
+import torch.nn as nn
+import torch_tensorrt as torchtrt
+from parameterized import parameterized
+from torch.testing._internal.common_utils import TestCase, run_tests
+
+DECIMALS_OF_AGREEMENT = 5  # for output comparison
+
+
+# We provide non null address to TRT
+class ConcatEmptyModel(nn.Module):
+    def __init__(self, dim=0):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x, y):
+        return torch.cat([x, y], dim=self.dim)
+
+
+# TRT will handle
+class ConcatEmptyModelEmptyConstant(nn.Module):
+    def __init__(self, dim=0):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        y = torch.empty((0, 4), dtype=torch.float).cuda()
+        return torch.cat([x, y], dim=self.dim)
+
+
+# makes use of validator
+class ConcatEmptyModelEmptyConstantMisMatchDim(nn.Module):
+    def __init__(self, dim=0):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        y = torch.tensor([], device="cuda")
+        return torch.cat([x, y], dim=self.dim)
+
+
+class TestConcatEmptyTensor(TestCase):
+
+    @parameterized.expand(
+        [
+            (
+                "python_runtime_model_one_empty_0",
+                True,
+                ConcatEmptyModel,
+                "two_inputs",
+                (0,),
+            ),
+            (
+                "cpp_runtime_model_one_empty_0",
+                False,
+                ConcatEmptyModel,
+                "two_inputs",
+                (0,),
+            ),
+            (
+                "python_runtime_model_one_empty_0_4",
+                True,
+                ConcatEmptyModel,
+                "two_inputs",
+                (0, 4),
+            ),
+            (
+                "cpp_runtime_model_one_empty_0_4",
+                False,
+                ConcatEmptyModel,
+                "two_inputs",
+                (0, 4),
+            ),
+            (
+                "python_runtime_model_two_empty_0_4",
+                True,
+                ConcatEmptyModelEmptyConstant,
+                "one_input",
+                (0, 4),
+            ),
+            (
+                "cpp_runtime_model_two_empty_0_4",
+                False,
+                ConcatEmptyModelEmptyConstant,
+                "one_input",
+                (0, 4),
+            ),
+            (
+                "python_runtime_model_three_empty_0",
+                True,
+                ConcatEmptyModelEmptyConstantMisMatchDim,
+                "one_input",
+                (0,),
+            ),
+            (
+                "cpp_runtime_model_three_empty_0",
+                False,
+                ConcatEmptyModelEmptyConstantMisMatchDim,
+                "one_input",
+                (0,),
+            ),
+        ]
+    )
+    def test_concat_empty_with_nonempty(
+        self, _, use_python_runtime, model_class, input_type, empty_shape
+    ):
+        """
+        Test concatenation of empty tensor with non-empty tensor
+        along a specific dimension using Torch-TensorRT compiled model.
+        """
+        # Create model
+        model = model_class(dim=0).eval().cuda()
+
+        # Inputs: prepare based on model requirements
+        empty_input = torch.empty(empty_shape, dtype=torch.float).cuda()
+        non_empty_input = torch.randn((3, 4), dtype=torch.float).cuda()
+
+        if input_type == "two_inputs":
+            inputs = [empty_input, non_empty_input]
+        else:  # one_input
+            inputs = [non_empty_input]
+
+        # Compile with Torch-TensorRT
+        compiled_model = torchtrt.compile(
+            model,
+            "dynamo",
+            inputs,
+            min_block_size=5,
+            use_python_runtime=use_python_runtime,
+        )
+
+        # Run reference model
+        ref_out = model(*inputs)
+        # Run compiled model
+        trt_out = compiled_model(*inputs)
+
+        # Assertions
+        self.assertEqual(ref_out.shape, trt_out.shape)
+        self.assertAlmostEqual(
+            float(torch.max(torch.abs(ref_out - trt_out))),
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg="Concat with empty tensor output mismatch",
+        )
+
+    @parameterized.expand(
+        [
+            ("python_runtime_empty_0", True, (0,)),
+            ("cpp_runtime_empty_0", False, (0,)),
+            ("python_runtime_empty_0_4", True, (0, 4)),
+            ("cpp_runtime_empty_0_4", False, (0, 4)),
+        ]
+    )
+    def test_concat_nonempty_with_empty(self, _, use_python_runtime, empty_shape):
+        """
+        Concatenate non-empty tensor with empty tensor (opposite order)
+        """
+        model = ConcatEmptyModel(dim=0).eval().cuda()
+
+        non_empty_input = torch.randn((3, 4), dtype=torch.float).cuda()
+        empty_input = torch.empty(empty_shape, dtype=torch.float).cuda()
+        inputs = [non_empty_input, empty_input]
+
+        compiled_model = torchtrt.compile(
+            model,
+            "dynamo",
+            inputs,
+            min_block_size=5,
+            use_python_runtime=use_python_runtime,
+        )
+
+        ref_out = model(*inputs)
+        trt_out = compiled_model(*inputs)
+
+        self.assertEqual(ref_out.shape, trt_out.shape)
+        self.assertAlmostEqual(
+            float(torch.max(torch.abs(ref_out - trt_out))),
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg="Concat with empty tensor (opposite order) output mismatch",
+        )
+
+
+if __name__ == "__main__":
+    run_tests()