
Commit 1dcae5c

CVCUDA backend design (#9259)
Summary
-------

This PR provides the first building blocks for CV-CUDA integration in torchvision. We add the functionals `to_cvcuda_tensor` and `cvcuda_to_tensor` to convert from `torch.Tensor` to `cvcuda.Tensor` and back, and we implement the corresponding class transforms `ToCVCUDATensor` and `CVCUDAToTensor`.

**Key features:**

* **3-channel RGB and 1-channel grayscale support only**: simplified API focusing on the most common use cases
* **Supported data types**: `torch.uint8` (RGB8 format) and `torch.float32` (RGBf32 format)
* **Lossless round-trip conversions**: exact data preservation when converting PyTorch ↔ CV-CUDA
* **Informative error messages**: helpful installation instructions when CV-CUDA is not available
* **Batch-aware**: handles both unbatched (CHW) and batched (NCHW) tensors

Users must explicitly opt in to these transforms, which require CV-CUDA to be installed.

How to use
----------

```python
from PIL import Image

import torchvision.transforms.v2.functional as F

# Load and convert image to PyTorch tensor
orig_img = Image.open("leaning_tower.jpg")
img_tensor = F.pil_to_tensor(orig_img)

# Convert to CV-CUDA tensor (must be 3-channel RGB on CUDA)
cvcuda_tensor = F.to_cvcuda_tensor(img_tensor.cuda())

# Convert back to PyTorch tensor
img_tensor = F.cvcuda_to_tensor(cvcuda_tensor)
```

> [!NOTE]
>
> * NVCV tensors are automatically converted to NHWC layout, contrary to torchvision's NCHW default
> * Only 3-channel RGB images and 1-channel grayscale are supported for now
> * Input tensors will be uploaded to the CUDA device when converting to CV-CUDA tensors
> * CV-CUDA must be installed: `pip install cvcuda-cu12` (CUDA 12) or `pip install cvcuda-cu11` (CUDA 11)

Run unit tests
--------------

```bash
pytest test/test_transforms_v2.py -k "cvcuda"
...
35 passed, 4 skipped, 9774 deselected in 1.12s
```

Test Plan:

```python
from torchvision import _is_cvcuda_available

_is_cvcuda_available()
```

## Run tests

```bash
buck test fbcode//mode/opt fbcode//pytorch/vision/test:torchvision_transforms_v2
...
Tests finished: Pass 38. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Differential Revision: D85862362

Pulled By: AntoineSimoulin
1 parent: acccf86 · commit: 1dcae5c
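To complement the "How to use" example in the commit message above, here is a minimal sketch of the batched round trip with the new functionals. It is not part of the diff and only assumes that CV-CUDA and a CUDA device are available:

```python
# Batched round trip with the new functionals (sketch; assumes CV-CUDA and a
# CUDA device are available). Exact equality reflects the lossless-conversion claim.
import torch
import torchvision.transforms.v2.functional as F

batch = torch.randint(0, 256, (4, 3, 224, 224), dtype=torch.uint8, device="cuda")  # NCHW uint8
cvcuda_batch = F.to_cvcuda_tensor(batch)     # cvcuda.Tensor in NHWC layout
restored = F.cvcuda_to_tensor(cvcuda_batch)  # back to torch.Tensor in NCHW layout
assert torch.equal(restored, batch)          # round trip preserves the data exactly
```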

File tree: 6 files changed, +421 −5 lines changed

test/common_utils.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -20,7 +20,7 @@
 from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
 from torchvision import io, tv_tensors
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
-from torchvision.transforms.v2.functional import to_image, to_pil_image
+from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image
 from torchvision.utils import _Image_fromarray
 
 
@@ -400,6 +400,10 @@ def make_image_pil(*args, **kwargs):
     return to_pil_image(make_image(*args, **kwargs))
 
 
+def make_image_cvcuda(*args, **kwargs):
+    return to_cvcuda_tensor(make_image(*args, **kwargs))
+
+
 def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"):
     y = torch.randint(0, canvas_size[0], size=(num_points, 1), dtype=dtype, device=device)
     x = torch.randint(0, canvas_size[1], size=(num_points, 1), dtype=dtype, device=device)
```

test/test_transforms_v2.py

Lines changed: 187 additions & 0 deletions
```diff
@@ -29,6 +29,7 @@
     make_bounding_boxes,
     make_detection_masks,
     make_image,
+    make_image_cvcuda,
     make_image_pil,
     make_image_tensor,
     make_keypoints,
@@ -51,8 +52,16 @@
 from torchvision.transforms.v2 import functional as F
 from torchvision.transforms.v2._utils import check_type, is_pure_tensor
 from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs, _parallelogram_to_bounding_boxes
+from torchvision.transforms.v2.functional._type_conversion import _import_cvcuda_modules
 from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal
 
+try:
+    _import_cvcuda_modules()
+    CVCUDA_AVAILABLE = True
+except ImportError:
+    CVCUDA_AVAILABLE = False
+CUDA_AVAILABLE = torch.cuda.is_available()
+
 
 # turns all warnings into errors for this module
 pytestmark = [pytest.mark.filterwarnings("error")]
@@ -6733,6 +6742,184 @@ def test_functional_error(self):
             F.pil_to_tensor(object())
 
 
+@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+@pytest.mark.skipif(not CUDA_AVAILABLE, reason="test requires CUDA")
+class TestToCVCUDATensor:
+    """Tests for to_cvcuda_tensor function following patterns from TestToPil"""
+
+    def test_1_channel_uint8_tensor_to_cvcuda_tensor(self):
+        # Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
+        img_data = torch.randint(0, 256, (1, 4, 4), dtype=torch.uint8)
+        img_data = img_data.cuda()
+        cvcuda_img = F.to_cvcuda_tensor(img_data)
+        # Check that the conversion succeeded and format is correct
+        assert cvcuda_img is not None
+
+    def test_1_channel_int16_tensor_to_cvcuda_tensor(self):
+        # Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
+        img_data = torch.randint(0, 256, (1, 4, 4), dtype=torch.int16)
+        img_data = img_data.cuda()
+        cvcuda_img = F.to_cvcuda_tensor(img_data)
+        assert cvcuda_img is not None
+
+    def test_1_channel_int32_tensor_to_cvcuda_tensor(self):
+        # Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
+        img_data = torch.randint(0, 256, (1, 4, 4), dtype=torch.int32)
+        img_data = img_data.cuda()
+        cvcuda_img = F.to_cvcuda_tensor(img_data)
+        assert cvcuda_img is not None
+
+    def test_1_channel_float32_tensor_to_cvcuda_tensor(self):
+        # Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
+        img_data = torch.rand(1, 4, 4)
+        img_data = img_data.cuda()
+        cvcuda_img = F.to_cvcuda_tensor(img_data)
+        assert cvcuda_img is not None
+
+    def test_3_channel_uint8_tensor_to_cvcuda_tensor(self):
+        # Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
+        img_data = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
+        img_data = img_data.cuda()
+        cvcuda_img = F.to_cvcuda_tensor(img_data)
+        assert cvcuda_img is not None
+
+    def test_3_channel_float32_tensor_to_cvcuda_tensor(self):
+        # Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
+        img_data = torch.rand(3, 4, 4)
+        img_data = img_data.cuda()
+        cvcuda_img = F.to_cvcuda_tensor(img_data)
+        assert cvcuda_img is not None
+
+    def test_unsupported_num_channels(self):
+        # Test 2-channel image (not supported)
+        # Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
+        img_data = torch.rand(2, 5, 5)
+        img_data = img_data.cuda()
+        with pytest.raises(ValueError, match="Only 1 and 3 channel images are supported"):
+            F.to_cvcuda_tensor(img_data)
+
+        # Test 4-channel image (not supported)
+        img_data = torch.randint(0, 256, (4, 5, 5), dtype=torch.uint8)
+        img_data = img_data.cuda()
+        with pytest.raises(ValueError, match="Only 1 and 3 channel images are supported"):
+            F.to_cvcuda_tensor(img_data)
+
+    def test_invalid_input_type(self):
+        with pytest.raises(TypeError, match=r"pic should be `torch.Tensor`"):
+            F.to_cvcuda_tensor("invalid_input")
+
+    def test_invalid_dimensions(self):
+        # Test 1D array (too few dimensions)
+        # Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
+        with pytest.raises(ValueError, match=r"pic should be 3 or 4 dimensional"):
+            img_data = torch.randint(0, 256, (4,), dtype=torch.uint8)
+            img_data = img_data.cuda()
+            F.to_cvcuda_tensor(img_data)
+
+        # Test 2D array (no longer supported)
+        with pytest.raises(ValueError, match=r"pic should be 3 or 4 dimensional"):
+            img_data = torch.randint(0, 256, (4, 4), dtype=torch.uint8)
+            img_data = img_data.cuda()
+            F.to_cvcuda_tensor(img_data)
+
+        # Test 5D array (too many dimensions)
+        with pytest.raises(ValueError, match=r"pic should be 3 or 4 dimensional"):
+            img_data = torch.randint(0, 256, (1, 1, 3, 4, 4), dtype=torch.uint8)
+            img_data = img_data.cuda()
+            F.to_cvcuda_tensor(img_data)
+
+    def test_float64_tensor_to_cvcuda_tensor(self):
+        # Test single channel float64 (F64 format is supported)
+        # Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
+        img_data = torch.rand(1, 4, 4, dtype=torch.float64)
+        img_data = img_data.cuda()
+        cvcuda_img = F.to_cvcuda_tensor(img_data)
+        assert cvcuda_img is not None
+
+    def test_float64_rgb_not_supported(self):
+        # Test 3-channel float64 is NOT supported (no RGBf64 format in CV-CUDA)
+        # Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
+        img_data = torch.rand(3, 4, 4, dtype=torch.float64)
+        img_data = img_data.cuda()
+        with pytest.raises(TypeError, match=r"Unsupported dtype"):
+            F.to_cvcuda_tensor(img_data)
+
+    @pytest.mark.parametrize("num_channels", [1, 3])
+    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32, torch.float64])
+    def test_round_trip(self, num_channels, dtype):
+        # Skip float64 for 3-channel (not supported by CV-CUDA)
+        if num_channels == 3 and dtype == torch.float64:
+            pytest.skip("float64 is not supported for 3-channel RGB images")
+
+        # Setup: Create a tensor in CHW format (PyTorch standard)
+        # Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
+        if dtype == torch.uint8:
+            original_tensor = torch.randint(0, 256, (num_channels, 4, 4), dtype=dtype)
+        else:
+            original_tensor = torch.rand(num_channels, 4, 4, dtype=dtype)
+        original_tensor = original_tensor.cuda()
+
+        # Execute: Convert to CV-CUDA and back to tensor
+        # CHW -> (to_cvcuda_tensor) -> CV-CUDA NHWC -> (cvcuda_to_tensor) -> NCHW
+        cvcuda_tensor = F.to_cvcuda_tensor(original_tensor)
+        result_tensor = F.cvcuda_to_tensor(cvcuda_tensor)
+
+        # Remove batch dimension that was added during conversion since original was unbatched
+        result_tensor = result_tensor.squeeze(0)
+
+        # Assert: The round-trip conversion preserves the original tensor exactly
+        torch.testing.assert_close(result_tensor, original_tensor, rtol=0, atol=0)
+
+    @pytest.mark.parametrize("num_channels", [1, 3])
+    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32, torch.float64])
+    @pytest.mark.parametrize("batch_size", [1, 2, 4])
+    def test_round_trip_batched(self, num_channels, dtype, batch_size):
+        # Skip float64 for 3-channel (not supported by CV-CUDA)
+        if num_channels == 3 and dtype == torch.float64:
+            pytest.skip("float64 is not supported for 3-channel RGB images")
+
+        # Setup: Create a batched tensor in NCHW format
+        # Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
+        if dtype == torch.uint8:
+            original_tensor = torch.randint(0, 256, (batch_size, num_channels, 4, 4), dtype=dtype)
+        else:
+            original_tensor = torch.rand(batch_size, num_channels, 4, 4, dtype=dtype)
+        original_tensor = original_tensor.cuda()
+
+        # Execute: Convert to CV-CUDA and back to tensor
+        # NCHW -> (to_cvcuda_tensor) -> CV-CUDA NHWC -> (cvcuda_to_tensor) -> NCHW
+        cvcuda_tensor = F.to_cvcuda_tensor(original_tensor)
+        result_tensor = F.cvcuda_to_tensor(cvcuda_tensor)
+
+        # Assert: The round-trip conversion preserves the original batched tensor exactly
+        torch.testing.assert_close(result_tensor, original_tensor, rtol=0, atol=0)
+        # Also verify batch size is preserved
+        assert result_tensor.shape[0] == batch_size
+
+
+@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+@pytest.mark.skipif(not CUDA_AVAILABLE, reason="test requires CUDA")
+class TestCVCUDAToTensor:
+    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
+    @pytest.mark.parametrize(
+        "fn",
+        [F.cvcuda_to_tensor, transform_cls_to_functional(transforms.CVCUDAToTensor)],
+    )
+    def test_functional_and_transform(self, color_space, fn):
+        input = make_image_cvcuda(color_space=color_space)
+
+        output = fn(input)
+
+        assert isinstance(output, torch.Tensor)
+        # Convert input to tensor to compare sizes
+        input_tensor = F.cvcuda_to_tensor(input)
+        assert F.get_size(output) == F.get_size(input_tensor)
+
+    def test_functional_error(self):
+        with pytest.raises(TypeError, match="cvcuda_img should be `cvcuda.Tensor`"):
+            F.cvcuda_to_tensor(object())
+
+
 class TestLambda:
     @pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0])
     @pytest.mark.parametrize("types", [(), (torch.Tensor, np.ndarray)])
```

torchvision/transforms/v2/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -55,7 +55,7 @@
     ToDtype,
 )
 from ._temporal import UniformTemporalSubsample
-from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor
+from ._type_conversion import CVCUDAToTensor, PILToTensor, ToCVCUDATensor, ToImage, ToPILImage, ToPureTensor
 from ._utils import check_type, get_bounding_boxes, get_keypoints, has_all, has_any, query_chw, query_size
 
 from ._deprecated import ToTensor  # usort: skip
```

torchvision/transforms/v2/_type_conversion.py

Lines changed: 69 additions & 1 deletion
```diff
@@ -6,8 +6,8 @@
 
 from torchvision import tv_tensors
 from torchvision.transforms.v2 import functional as F, Transform
-
 from torchvision.transforms.v2._utils import is_pure_tensor
+from torchvision.utils import _log_api_usage_once
 
 
 class PILToTensor(Transform):
@@ -90,3 +90,71 @@ class ToPureTensor(Transform):
 
     def transform(self, inpt: Any, params: dict[str, Any]) -> torch.Tensor:
         return inpt.as_subclass(torch.Tensor)
+
+
+class ToCVCUDATensor:
+    """Convert a torch.Tensor to cvcuda.Tensor
+
+    This transform does not support torchscript.
+
+    Converts a torch.*Tensor of shape C x H x W to a cvcuda.Tensor.
+    Only 1-channel and 3-channel images are supported.
+    """
+
+    def __init__(self):
+        _log_api_usage_once(self)
+
+    def __call__(self, pic):
+        """
+        Args:
+            pic (torch.Tensor): Image to be converted to cvcuda.Tensor.
+
+        Returns:
+            cvcuda.Tensor: Image converted to cvcuda.Tensor.
+
+        """
+        return F.to_cvcuda_tensor(pic)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}()"
+
+
+class CVCUDAToTensor:
+    """Convert a `cvcuda.Tensor` to a `torch.Tensor` of the same type - this does not scale values.
+
+    This transform does not support torchscript.
+
+    Converts a `cvcuda.Tensor` to a `torch.Tensor`. Supports both batched and unbatched inputs:
+    - Unbatched: (H, W, C) or (H, W) → (C, H, W) or (1, H, W)
+    - Batched: (N, H, W, C) or (N, H, W) → (N, C, H, W) or (N, 1, H, W)
+
+    The conversion happens directly on GPU when the `cvcuda.Tensor` is stored on GPU,
+    avoiding unnecessary data transfers.
+
+    Example:
+        >>> import cvcuda
+        >>> import torchvision.transforms.v2 as T
+        >>> # Create a CV-CUDA tensor (320x240 RGB)
+        >>> # Note: In CV-CUDA 0.16.0+, Image/Tensor creation uses cvcuda module
+        >>> img_tensor = torch.randint(0, 255, (1, 240, 320, 3), dtype=torch.uint8, device="cuda")
+        >>> cvcuda_tensor = cvcuda.as_tensor(img_tensor, cvcuda.TensorLayout.NHWC)
+        >>> tensor = T.CVCUDAToTensor()(cvcuda_tensor)
+        >>> print(tensor.shape)
+        torch.Size([1, 3, 240, 320])
+    """
+
+    def __init__(self) -> None:
+        _log_api_usage_once(self)
+
+    def __call__(self, pic):
+        """
+        Args:
+            pic (cvcuda.Tensor): CV-CUDA Tensor to be converted to tensor.
+
+        Returns:
+            Tensor: Converted image in CHW format.
+        """
+        return F.cvcuda_to_tensor(pic)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}()"
```

torchvision/transforms/v2/functional/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -162,6 +162,6 @@
     to_dtype_video,
 )
 from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video
-from ._type_conversion import pil_to_tensor, to_image, to_pil_image
+from ._type_conversion import cvcuda_to_tensor, pil_to_tensor, to_cvcuda_tensor, to_image, to_pil_image
 
 from ._deprecated import get_image_size, to_tensor  # usort: skip
```
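As a quick smoke check of the re-exports in the two `__init__.py` changes above, the new names should be importable from the public namespaces (sketch, assuming the branch with this commit is installed):

```python
# Import-surface smoke check for the newly exported names.
from torchvision.transforms.v2 import CVCUDAToTensor, ToCVCUDATensor
from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor

print(ToCVCUDATensor(), CVCUDAToTensor())  # ToCVCUDATensor() CVCUDAToTensor()
```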
