From 49ca2f98fb87cf16574f3c7d116742975de53f7e Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Fri, 14 Nov 2025 10:48:24 -0800 Subject: [PATCH] Introducing CVCUDA Backend (#9259) Summary: Summary ------- This PR provides the first building blocks for CV-CUDA integration in torchvision. We add the functionals `to_cvcuda_tensor` and `cvcuda_to_tensor` to transform from `torch.Tensor` to `cvcuda.Tensor` and back. We also implement the corresponding class transforms `ToCVCUDATensor` and `CVCUDAToTensor`. These transforms require CV-CUDA to be installed. How to use ---------- ```python from PIL import Image import torchvision.transforms.v2.functional as F # Create rand ``torch.Tensor`` image (must be 3-channel RGB/Gray and have batch dimension) img_tensor = torch.randint(0, 256, (1, 3, 320, 240), dtype=torch.uint8) # Convert to ``cvcuda.Tensor`` (will be uploaded to CUDA) cvcuda_tensor = F.to_cvcuda_tensor(img_tensor) # Convert back to ``torch.Tensor`` img_tensor = F.cvcuda_to_tensor(cvcuda_tensor) ``` > [!NOTE] > > * ``cvcuda.Tensor`` are automatically converted to NHWC shape (since most CV-CUDA transforms only support this shape) > * Only 3-channel RGB images and 1-channel grayscale are supported for now > * Input tensors will be uploaded to CUDA device when converting to CV-CUDA tensors > * CV-CUDA must be installed: `pip install cvcuda-cu12` (CUDA 12) or `pip install cvcuda-cu11` (CUDA 11) Run unit tests -------------- ```bash pytest test/test_transforms_v2.py -k "cvcuda" ... 338 passed, 9774 deselected in 1.65s ``` Differential Revision: D85862362 Pulled By: AntoineSimoulin --- test/common_utils.py | 6 +- test/test_transforms_v2.py | 99 ++++++++++++++++++- torchvision/transforms/v2/__init__.py | 2 +- torchvision/transforms/v2/_type_conversion.py | 35 ++++++- .../transforms/v2/functional/__init__.py | 2 +- torchvision/transforms/v2/functional/_meta.py | 25 ++++- .../v2/functional/_type_conversion.py | 39 +++++++- .../transforms/v2/functional/_utils.py | 29 ++++++ 8 files changed, 228 insertions(+), 9 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 74ad31fea72..8c3c9dd58a8 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -20,7 +20,7 @@ from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair from torchvision import io, tv_tensors from torchvision.transforms._functional_tensor import _max_value as get_max_value -from torchvision.transforms.v2.functional import to_image, to_pil_image +from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image from torchvision.utils import _Image_fromarray @@ -400,6 +400,10 @@ def make_image_pil(*args, **kwargs): return to_pil_image(make_image(*args, **kwargs)) +def make_image_cvcuda(*args, **kwargs): + return to_cvcuda_tensor(make_image(*args, **kwargs)) + + def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"): y = torch.randint(0, canvas_size[0], size=(num_points, 1), dtype=dtype, device=device) x = torch.randint(0, canvas_size[1], size=(num_points, 1), dtype=dtype, device=device) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 00ed6a8aef1..670a9d00ffb 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -29,6 +29,7 @@ make_bounding_boxes, make_detection_masks, make_image, + make_image_cvcuda, make_image_pil, make_image_tensor, make_keypoints, @@ -51,8 +52,17 @@ from torchvision.transforms.v2 import functional as F from torchvision.transforms.v2._utils import check_type, is_pure_tensor from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs, _parallelogram_to_bounding_boxes -from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal +from torchvision.transforms.v2.functional._utils import ( + _get_kernel, + _import_cvcuda, + _is_cvcuda_available, + _register_kernel_internal, +) + +CVCUDA_AVAILABLE = _is_cvcuda_available() +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # turns all warnings into errors for this module pytestmark = [pytest.mark.filterwarnings("error")] @@ -6732,6 +6742,93 @@ def test_functional_error(self): F.pil_to_tensor(object()) +@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") +@needs_cuda +class TestToCVCUDATensor: + @pytest.mark.parametrize("image_type", (torch.Tensor, tv_tensors.Image)) + @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("color_space", ["RGB", "GRAY"]) + @pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)]) + @pytest.mark.parametrize( + "fn", + [F.to_cvcuda_tensor, transform_cls_to_functional(transforms.ToCVCUDATensor)], + ) + def test_functional_and_transform(self, image_type, dtype, device, color_space, batch_dims, fn): + image = make_image(dtype=dtype, device=device, color_space=color_space, batch_dims=batch_dims) + if image_type is torch.Tensor: + image = image.as_subclass(torch.Tensor) + assert is_pure_tensor(image) + output = fn(image) + + assert isinstance(output, cvcuda.Tensor) + assert F.get_size(output) == F.get_size(image) + assert output is not None + + def test_invalid_input_type(self): + with pytest.raises(TypeError, match=r"inpt should be ``torch.Tensor``"): + F.to_cvcuda_tensor("invalid_input") + + def test_invalid_dimensions(self): + with pytest.raises(ValueError, match=r"pic should be 4 dimensional"): + img_data = torch.randint(0, 256, (3, 1, 3), dtype=torch.uint8) + img_data = img_data.cuda() + F.to_cvcuda_tensor(img_data) + + with pytest.raises(ValueError, match=r"pic should be 4 dimensional"): + img_data = torch.randint(0, 256, (4,), dtype=torch.uint8) + img_data = img_data.cuda() + F.to_cvcuda_tensor(img_data) + + with pytest.raises(ValueError, match=r"pic should be 4 dimensional"): + img_data = torch.randint(0, 256, (4, 4), dtype=torch.uint8) + img_data = img_data.cuda() + F.to_cvcuda_tensor(img_data) + + with pytest.raises(ValueError, match=r"pic should be 4 dimensional"): + img_data = torch.randint(0, 256, (1, 1, 3, 4, 4), dtype=torch.uint8) + img_data = img_data.cuda() + F.to_cvcuda_tensor(img_data) + + @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("color_space", ["RGB", "GRAY"]) + @pytest.mark.parametrize("batch_size", [1, 2, 4]) + def test_round_trip(self, dtype, device, color_space, batch_size): + original_tensor = make_image_tensor( + dtype=dtype, device=device, color_space=color_space, batch_dims=(batch_size,) + ) + cvcuda_tensor = F.to_cvcuda_tensor(original_tensor) + result_tensor = F.cvcuda_to_tensor(cvcuda_tensor) + torch.testing.assert_close(result_tensor.to(device), original_tensor, rtol=0, atol=0) + assert result_tensor.shape[0] == batch_size + + +@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") +@needs_cuda +class TestCVDUDAToTensor: + @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("color_space", ["RGB", "GRAY"]) + @pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)]) + @pytest.mark.parametrize( + "fn", + [F.cvcuda_to_tensor, transform_cls_to_functional(transforms.CVCUDAToTensor)], + ) + def test_functional_and_transform(self, dtype, device, color_space, batch_dims, fn): + input = make_image_cvcuda(dtype=dtype, device=device, color_space=color_space, batch_dims=batch_dims) + + output = fn(input) + + assert isinstance(output, torch.Tensor) + input_tensor = F.cvcuda_to_tensor(input) + assert F.get_size(output) == F.get_size(input_tensor) + + def test_functional_error(self): + with pytest.raises(TypeError, match="cvcuda_img should be `cvcuda.Tensor`"): + F.cvcuda_to_tensor(object()) + + class TestLambda: @pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0]) @pytest.mark.parametrize("types", [(), (torch.Tensor, np.ndarray)]) diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py index 895bf6e2f71..d69a0f47a51 100644 --- a/torchvision/transforms/v2/__init__.py +++ b/torchvision/transforms/v2/__init__.py @@ -55,7 +55,7 @@ ToDtype, ) from ._temporal import UniformTemporalSubsample -from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor +from ._type_conversion import CVCUDAToTensor, PILToTensor, ToCVCUDATensor, ToImage, ToPILImage, ToPureTensor from ._utils import check_type, get_bounding_boxes, get_keypoints, has_all, has_any, query_chw, query_size from ._deprecated import ToTensor # usort: skip diff --git a/torchvision/transforms/v2/_type_conversion.py b/torchvision/transforms/v2/_type_conversion.py index 7cac62868b9..56647d4b856 100644 --- a/torchvision/transforms/v2/_type_conversion.py +++ b/torchvision/transforms/v2/_type_conversion.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any, Optional, TYPE_CHECKING, Union import numpy as np import PIL.Image @@ -6,8 +6,11 @@ from torchvision import tv_tensors from torchvision.transforms.v2 import functional as F, Transform - from torchvision.transforms.v2._utils import is_pure_tensor +from torchvision.transforms.v2.functional._utils import _import_cvcuda + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] class PILToTensor(Transform): @@ -90,3 +93,31 @@ class ToPureTensor(Transform): def transform(self, inpt: Any, params: dict[str, Any]) -> torch.Tensor: return inpt.as_subclass(torch.Tensor) + + +class ToCVCUDATensor(Transform): + """Convert a ``torch.Tensor`` with NCHW shape to a ``cvcuda.Tensor``. + If the input tensor is on CPU, it will automatically be transferred to GPU. + Only 1-channel and 3-channel images are supported. + + This transform does not support torchscript. + """ + + def transform(self, inpt: torch.Tensor, params: dict[str, Any]) -> "cvcuda.Tensor": + return F.to_cvcuda_tensor(inpt) + + +class CVCUDAToTensor(Transform): + """Convert a ``cvcuda.Tensor`` to a ``torch.Tensor`` with NCHW shape. + + This function does not support torchscript. + """ + + try: + cvcuda = _import_cvcuda() + _transformed_types = (cvcuda.Tensor,) + except ImportError: + pass + + def transform(self, inpt: Any, params: dict[str, Any]) -> torch.Tensor: + return F.cvcuda_to_tensor(inpt) diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 13fbaa588fe..032a993b1f0 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -162,6 +162,6 @@ to_dtype_video, ) from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video -from ._type_conversion import pil_to_tensor, to_image, to_pil_image +from ._type_conversion import cvcuda_to_tensor, pil_to_tensor, to_cvcuda_tensor, to_image, to_pil_image from ._deprecated import get_image_size, to_tensor # usort: skip diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 4568b39ab59..6b8f19f12f4 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, TYPE_CHECKING, Union import PIL.Image import torch @@ -9,7 +9,14 @@ from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def get_dimensions(inpt: torch.Tensor) -> list[int]: @@ -107,6 +114,20 @@ def _get_size_image_pil(image: PIL.Image.Image) -> list[int]: return [height, width] +def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: + """Get size of `cvcuda.Tensor` with NHWC layout.""" + hw = list(image.shape[-3:-1]) + ndims = len(hw) + if ndims == 2: + return hw + else: + raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}") + + +if CVCUDA_AVAILABLE: + _get_size_image_cvcuda = _register_kernel_internal(get_size, cvcuda.Tensor)(get_size_image_cvcuda) + + @_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False) def get_size_video(video: torch.Tensor) -> list[int]: return get_size_image(video) diff --git a/torchvision/transforms/v2/functional/_type_conversion.py b/torchvision/transforms/v2/functional/_type_conversion.py index c5a731fe143..239025bbe1b 100644 --- a/torchvision/transforms/v2/functional/_type_conversion.py +++ b/torchvision/transforms/v2/functional/_type_conversion.py @@ -1,10 +1,16 @@ -from typing import Union +from typing import TYPE_CHECKING, Union import numpy as np import PIL.Image import torch from torchvision import tv_tensors from torchvision.transforms import functional as _F +from torchvision.utils import _log_api_usage_once + +from ._utils import _import_cvcuda + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] @torch.jit.unused @@ -25,3 +31,34 @@ def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tenso to_pil_image = _F.to_pil_image pil_to_tensor = _F.pil_to_tensor + + +@torch.jit.unused +def to_cvcuda_tensor(inpt: torch.Tensor) -> "cvcuda.Tensor": + """See :class:``~torchvision.transforms.v2.ToCVCUDATensor`` for details.""" + cvcuda = _import_cvcuda() + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(to_cvcuda_tensor) + if not isinstance(inpt, (torch.Tensor, tv_tensors.Image)): + raise TypeError(f"inpt should be ``torch.Tensor`` or ``tv_tensors.Image``. Got {type(inpt)}.") + if inpt.ndim != 4: + raise ValueError(f"pic should be 4 dimensional. Got {inpt.ndim} dimensions.") + # Convert to NHWC as CVCUDA transforms do not support NCHW + inpt = inpt.permute(0, 2, 3, 1) + return cvcuda.as_tensor(inpt.cuda().contiguous(), cvcuda.TensorLayout.NHWC) + + +@torch.jit.unused +def cvcuda_to_tensor(cvcuda_img: "cvcuda.Tensor") -> torch.Tensor: + """See :class:``~torchvision.transforms.v2.CVCUDAToTensor`` for details.""" + cvcuda = _import_cvcuda() + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(cvcuda_to_tensor) + if not isinstance(cvcuda_img, cvcuda.Tensor): + raise TypeError(f"cvcuda_img should be ``cvcuda.Tensor``. Got {type(cvcuda_img)}.") + cuda_tensor = torch.as_tensor(cvcuda_img.cuda(), device="cuda") + if cvcuda_img.ndim != 4: + raise ValueError(f"Image should be 4 dimensional. Got {cuda_tensor.ndim} dimensions.") + # Convert to NCHW shape from CVCUDA default NHWC + img = cuda_tensor.permute(0, 3, 1, 2) + return img diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index b857285c891..ad1eddd258b 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -140,3 +140,32 @@ def decorator(kernel): return kernel return decorator + + +def _import_cvcuda(): + """Import CV-CUDA modules with informative error message if not installed. + + Returns: + cvcuda module. + + Raises: + RuntimeError: If CV-CUDA is not installed. + """ + try: + import cvcuda # type: ignore[import-not-found] + + return cvcuda + except ImportError as e: + raise ImportError( + "CV-CUDA is required but not installed. " + "Please install it following the instructions at " + "https://github.com/CVCUDA/CV-CUDA." + ) from e + + +def _is_cvcuda_available(): + try: + _ = _import_cvcuda() + return True + except ImportError: + return False