6 changes: 5 additions & 1 deletion test/common_utils.py
@@ -20,7 +20,7 @@
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision import io, tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import to_image, to_pil_image
from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image
from torchvision.utils import _Image_fromarray


@@ -400,6 +400,10 @@ def make_image_pil(*args, **kwargs):
return to_pil_image(make_image(*args, **kwargs))


def make_image_cvcuda(*args, **kwargs):
return to_cvcuda_tensor(make_image(*args, **kwargs))


def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"):
y = torch.randint(0, canvas_size[0], size=(num_points, 1), dtype=dtype, device=device)
x = torch.randint(0, canvas_size[1], size=(num_points, 1), dtype=dtype, device=device)
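For context, a minimal sketch of how the new helper might be exercised in a test (not part of the diff; the test name is hypothetical, and it assumes CV-CUDA plus a CUDA device are available):

import cvcuda  # type: ignore[import-not-found]
from common_utils import make_image_cvcuda


def test_make_image_cvcuda_sketch():
    # make_image_cvcuda builds a tv_tensors.Image via make_image and converts it
    # with to_cvcuda_tensor, so the result is a batched NHWC cvcuda.Tensor on the GPU.
    img = make_image_cvcuda(color_space="RGB", batch_dims=(2,))
    assert isinstance(img, cvcuda.Tensor)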
99 changes: 98 additions & 1 deletion test/test_transforms_v2.py
@@ -29,6 +29,7 @@
make_bounding_boxes,
make_detection_masks,
make_image,
make_image_cvcuda,
make_image_pil,
make_image_tensor,
make_keypoints,
@@ -51,8 +52,17 @@
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2._utils import check_type, is_pure_tensor
from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs, _parallelogram_to_bounding_boxes
from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal
from torchvision.transforms.v2.functional._utils import (
_get_kernel,
_import_cvcuda,
_is_cvcuda_available,
_register_kernel_internal,
)


CVCUDA_AVAILABLE = _is_cvcuda_available()
if CVCUDA_AVAILABLE:
cvcuda = _import_cvcuda()

# turns all warnings into errors for this module
pytestmark = [pytest.mark.filterwarnings("error")]
@@ -6732,6 +6742,93 @@ def test_functional_error(self):
F.pil_to_tensor(object())


@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
@needs_cuda
class TestToCVCUDATensor:
@pytest.mark.parametrize("image_type", (torch.Tensor, tv_tensors.Image))
@pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
@pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)])
@pytest.mark.parametrize(
"fn",
[F.to_cvcuda_tensor, transform_cls_to_functional(transforms.ToCVCUDATensor)],
)
def test_functional_and_transform(self, image_type, dtype, device, color_space, batch_dims, fn):
image = make_image(dtype=dtype, device=device, color_space=color_space, batch_dims=batch_dims)
if image_type is torch.Tensor:
image = image.as_subclass(torch.Tensor)
assert is_pure_tensor(image)
output = fn(image)

assert isinstance(output, cvcuda.Tensor)
assert F.get_size(output) == F.get_size(image)
assert output is not None

def test_invalid_input_type(self):
with pytest.raises(TypeError, match=r"inpt should be ``torch.Tensor``"):
F.to_cvcuda_tensor("invalid_input")

def test_invalid_dimensions(self):
with pytest.raises(ValueError, match=r"pic should be 4 dimensional"):
img_data = torch.randint(0, 256, (3, 1, 3), dtype=torch.uint8)
img_data = img_data.cuda()
F.to_cvcuda_tensor(img_data)

with pytest.raises(ValueError, match=r"pic should be 4 dimensional"):
img_data = torch.randint(0, 256, (4,), dtype=torch.uint8)
img_data = img_data.cuda()
F.to_cvcuda_tensor(img_data)

with pytest.raises(ValueError, match=r"pic should be 4 dimensional"):
img_data = torch.randint(0, 256, (4, 4), dtype=torch.uint8)
img_data = img_data.cuda()
F.to_cvcuda_tensor(img_data)

with pytest.raises(ValueError, match=r"pic should be 4 dimensional"):
img_data = torch.randint(0, 256, (1, 1, 3, 4, 4), dtype=torch.uint8)
img_data = img_data.cuda()
F.to_cvcuda_tensor(img_data)

@pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
@pytest.mark.parametrize("batch_size", [1, 2, 4])
def test_round_trip(self, dtype, device, color_space, batch_size):
original_tensor = make_image_tensor(
dtype=dtype, device=device, color_space=color_space, batch_dims=(batch_size,)
)
cvcuda_tensor = F.to_cvcuda_tensor(original_tensor)
result_tensor = F.cvcuda_to_tensor(cvcuda_tensor)
torch.testing.assert_close(result_tensor.to(device), original_tensor, rtol=0, atol=0)
assert result_tensor.shape[0] == batch_size


@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
@needs_cuda
class TestCVCUDAToTensor:
@pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
@pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)])
@pytest.mark.parametrize(
"fn",
[F.cvcuda_to_tensor, transform_cls_to_functional(transforms.CVCUDAToTensor)],
)
def test_functional_and_transform(self, dtype, device, color_space, batch_dims, fn):
input = make_image_cvcuda(dtype=dtype, device=device, color_space=color_space, batch_dims=batch_dims)

output = fn(input)

assert isinstance(output, torch.Tensor)
input_tensor = F.cvcuda_to_tensor(input)
assert F.get_size(output) == F.get_size(input_tensor)

def test_functional_error(self):
with pytest.raises(TypeError, match="cvcuda_img should be `cvcuda.Tensor`"):
F.cvcuda_to_tensor(object())


class TestLambda:
@pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0])
@pytest.mark.parametrize("types", [(), (torch.Tensor, np.ndarray)])
Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/__init__.py
@@ -55,7 +55,7 @@
ToDtype,
)
from ._temporal import UniformTemporalSubsample
from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor
from ._type_conversion import CVCUDAToTensor, PILToTensor, ToCVCUDATensor, ToImage, ToPILImage, ToPureTensor
from ._utils import check_type, get_bounding_boxes, get_keypoints, has_all, has_any, query_chw, query_size

from ._deprecated import ToTensor # usort: skip
35 changes: 33 additions & 2 deletions torchvision/transforms/v2/_type_conversion.py
@@ -1,13 +1,16 @@
from typing import Any, Optional, Union
from typing import Any, Optional, TYPE_CHECKING, Union

import numpy as np
import PIL.Image
import torch

from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F, Transform

from torchvision.transforms.v2._utils import is_pure_tensor
from torchvision.transforms.v2.functional._utils import _import_cvcuda

if TYPE_CHECKING:
import cvcuda # type: ignore[import-not-found]


class PILToTensor(Transform):
@@ -90,3 +93,31 @@ class ToPureTensor(Transform):

def transform(self, inpt: Any, params: dict[str, Any]) -> torch.Tensor:
return inpt.as_subclass(torch.Tensor)


class ToCVCUDATensor(Transform):
"""Convert a ``torch.Tensor`` with NCHW shape to a ``cvcuda.Tensor``.
If the input tensor is on CPU, it will automatically be transferred to GPU.
Only 1-channel and 3-channel images are supported.
This transform does not support torchscript.
"""

def transform(self, inpt: torch.Tensor, params: dict[str, Any]) -> "cvcuda.Tensor":
return F.to_cvcuda_tensor(inpt)


class CVCUDAToTensor(Transform):
"""Convert a ``cvcuda.Tensor`` to a ``torch.Tensor`` with NCHW shape.
This transform does not support torchscript.
"""

try:
cvcuda = _import_cvcuda()
_transformed_types = (cvcuda.Tensor,)
except ImportError:
pass

def transform(self, inpt: Any, params: dict[str, Any]) -> torch.Tensor:
return F.cvcuda_to_tensor(inpt)
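As a rough usage sketch (not part of the diff; assumes CV-CUDA and a CUDA device are available), the two transforms bracket CV-CUDA processing of a batched NCHW tensor:

import torch
from torchvision.transforms import v2 as transforms

to_cv = transforms.ToCVCUDATensor()
to_torch = transforms.CVCUDAToTensor()

images = torch.randint(0, 256, (4, 3, 224, 224), dtype=torch.uint8)  # batched NCHW, CPU
cv_batch = to_cv(images)       # cvcuda.Tensor in NHWC layout, moved to the GPU
restored = to_torch(cv_batch)  # torch.Tensor back in NCHW layout, on the GPU
assert restored.shape == (4, 3, 224, 224)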
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/functional/__init__.py
@@ -162,6 +162,6 @@
to_dtype_video,
)
from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video
from ._type_conversion import pil_to_tensor, to_image, to_pil_image
from ._type_conversion import cvcuda_to_tensor, pil_to_tensor, to_cvcuda_tensor, to_image, to_pil_image

from ._deprecated import get_image_size, to_tensor # usort: skip
25 changes: 23 additions & 2 deletions torchvision/transforms/v2/functional/_meta.py
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional, TYPE_CHECKING, Union

import PIL.Image
import torch
@@ -9,7 +9,14 @@

from torchvision.utils import _log_api_usage_once

from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor

CVCUDA_AVAILABLE = _is_cvcuda_available()

if TYPE_CHECKING:
import cvcuda # type: ignore[import-not-found]
if CVCUDA_AVAILABLE:
cvcuda = _import_cvcuda() # noqa: F811


def get_dimensions(inpt: torch.Tensor) -> list[int]:
@@ -107,6 +114,20 @@ def _get_size_image_pil(image: PIL.Image.Image) -> list[int]:
return [height, width]


def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]:
"""Get size of a ``cvcuda.Tensor`` with NHWC layout."""
hw = list(image.shape[-3:-1])
if len(hw) == 2:
return hw
else:
raise TypeError(f"Input tensor should have at least three dimensions (..., H, W, C), but got {image.ndim}")


if CVCUDA_AVAILABLE:
_get_size_image_cvcuda = _register_kernel_internal(get_size, cvcuda.Tensor)(get_size_image_cvcuda)


@_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False)
def get_size_video(video: torch.Tensor) -> list[int]:
return get_size_image(video)
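A small sketch of what the newly registered kernel reports (not part of the diff; assumes CV-CUDA and a CUDA device are available): for an NHWC cvcuda.Tensor it returns [H, W], matching get_size on the equivalent NCHW torch.Tensor.

import torch
from torchvision.transforms.v2 import functional as F

nchw = torch.zeros(2, 3, 480, 640, dtype=torch.uint8, device="cuda")
cv = F.to_cvcuda_tensor(nchw)            # NHWC cvcuda.Tensor of shape (2, 480, 640, 3)
assert F.get_size(cv) == [480, 640]      # dispatches to get_size_image_cvcuda
assert F.get_size(cv) == F.get_size(nchw)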
39 changes: 38 additions & 1 deletion torchvision/transforms/v2/functional/_type_conversion.py
@@ -1,10 +1,16 @@
from typing import Union
from typing import TYPE_CHECKING, Union

import numpy as np
import PIL.Image
import torch
from torchvision import tv_tensors
from torchvision.transforms import functional as _F
from torchvision.utils import _log_api_usage_once

from ._utils import _import_cvcuda

if TYPE_CHECKING:
import cvcuda # type: ignore[import-not-found]


@torch.jit.unused
@@ -25,3 +31,34 @@ def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tenso

to_pil_image = _F.to_pil_image
pil_to_tensor = _F.pil_to_tensor


@torch.jit.unused
def to_cvcuda_tensor(inpt: torch.Tensor) -> "cvcuda.Tensor":
"""See :class:``~torchvision.transforms.v2.ToCVCUDATensor`` for details."""
cvcuda = _import_cvcuda()
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(to_cvcuda_tensor)
if not isinstance(inpt, (torch.Tensor, tv_tensors.Image)):
raise TypeError(f"inpt should be ``torch.Tensor`` or ``tv_tensors.Image``. Got {type(inpt)}.")
if inpt.ndim != 4:
raise ValueError(f"pic should be 4 dimensional. Got {inpt.ndim} dimensions.")
# Convert to NHWC as CVCUDA transforms do not support NCHW
inpt = inpt.permute(0, 2, 3, 1)
return cvcuda.as_tensor(inpt.cuda().contiguous(), cvcuda.TensorLayout.NHWC)


@torch.jit.unused
def cvcuda_to_tensor(cvcuda_img: "cvcuda.Tensor") -> torch.Tensor:
"""See :class:``~torchvision.transforms.v2.CVCUDAToTensor`` for details."""
cvcuda = _import_cvcuda()
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(cvcuda_to_tensor)
if not isinstance(cvcuda_img, cvcuda.Tensor):
raise TypeError(f"cvcuda_img should be ``cvcuda.Tensor``. Got {type(cvcuda_img)}.")
if cvcuda_img.ndim != 4:
raise ValueError(f"Image should be 4 dimensional. Got {cvcuda_img.ndim} dimensions.")
cuda_tensor = torch.as_tensor(cvcuda_img.cuda(), device="cuda")
# Convert to NCHW shape from CVCUDA default NHWC
img = cuda_tensor.permute(0, 3, 1, 2)
return img
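A brief sketch of the layout contract these two functionals implement (not part of the diff; assumes CV-CUDA and a CUDA device are available): to_cvcuda_tensor permutes NCHW to NHWC and moves the data to the GPU, and cvcuda_to_tensor permutes it back.

import torch
from torchvision.transforms.v2 import functional as F

nchw = torch.randint(0, 256, (1, 3, 32, 64), dtype=torch.uint8)
cv = F.to_cvcuda_tensor(nchw)           # NHWC cvcuda.Tensor on the GPU
assert list(cv.shape) == [1, 32, 64, 3]
back = F.cvcuda_to_tensor(cv)           # NCHW torch.Tensor on the GPU
torch.testing.assert_close(back.cpu(), nchw, rtol=0, atol=0)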
29 changes: 29 additions & 0 deletions torchvision/transforms/v2/functional/_utils.py
@@ -140,3 +140,32 @@ def decorator(kernel):
return kernel

return decorator


def _import_cvcuda():
"""Import CV-CUDA modules with informative error message if not installed.
Returns:
cvcuda module.
Raises:
RuntimeError: If CV-CUDA is not installed.
"""
try:
import cvcuda # type: ignore[import-not-found]

return cvcuda
except ImportError as e:
raise ImportError(
"CV-CUDA is required but not installed. "
"Please install it following the instructions at "
"https://github.com/CVCUDA/CV-CUDA."
) from e


def _is_cvcuda_available():
try:
_ = _import_cvcuda()
return True
except ImportError:
return False
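Finally, a short sketch of the guard pattern callers of these helpers use (not part of the diff; it mirrors the module-level guards in _meta.py and test_transforms_v2.py above, and wrap_nhwc is a hypothetical helper):

CVCUDA_AVAILABLE = _is_cvcuda_available()
if CVCUDA_AVAILABLE:
    cvcuda = _import_cvcuda()


def wrap_nhwc(nhwc_cuda_tensor):
    # Hypothetical helper: zero-copy wrap of a contiguous NHWC CUDA torch.Tensor.
    if not CVCUDA_AVAILABLE:
        raise RuntimeError("CV-CUDA is not installed")
    return cvcuda.as_tensor(nhwc_cuda_tensor.contiguous(), cvcuda.TensorLayout.NHWC)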