
Commit eea836c

CVCUDA backend design (#9259)
Summary: Users have to explicitly opt in to these transforms. Here we provide the first building block for this interface. We add the functionals `to_nvcv_tensor` and `nvcv_to_tensor` to convert between `torch.Tensor` and `nvcv.Tensor`. We also implement the corresponding class transforms `ToNVCVTensor` and `NVCVToTensor`.

## How to use

```python
from PIL import Image

import torchvision.transforms.v2.functional as F

orig_img = Image.open("leaning_tower.jpg")
img_tensor = F.pil_to_tensor(orig_img)
nvcv_tensor = F.to_nvcv_tensor(img_tensor.cuda())
img_tensor = F.nvcv_to_tensor(nvcv_tensor)
```

> [!NOTE]
> NVCV tensors are automatically converted to NHWC format, contrary to the torchvision convention, which relies on NCHW format.

## Run unit tests

```bash
pytest test/test_cvcuda.py
...
37 passed in 0.15s
```

Test Plan:

```python
from torchvision import _is_cvcuda_available

_is_cvcuda_available()
```

## Run tests

```bash
buck test fbcode//mode/opt fbcode//pytorch/vision/test:torchvision_cvcuda
...
Tests finished: Pass 38. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Differential Revision: D85862362

Pulled By: AntoineSimoulin
1 parent 167730b commit eea836c
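
As a rough illustration of the NHWC note in the summary, here is a minimal round-trip sketch through the new functionals. It assumes CV-CUDA and a CUDA device are available; the `.layout`/`.shape` attributes on the returned `nvcv.Tensor` are an assumption based on the CV-CUDA Python API, not something this commit documents.

```python
import torch

import torchvision.transforms.v2.functional as F
from torchvision import _is_cvcuda_available

if _is_cvcuda_available() and torch.cuda.is_available():
    # CHW uint8 image, torchvision's usual convention.
    chw = torch.randint(0, 256, (3, 240, 320), dtype=torch.uint8, device="cuda")

    nvcv_tensor = F.to_nvcv_tensor(chw)
    # Assumption: nvcv.Tensor exposes .layout/.shape; the NHWC note above suggests
    # something like a (1, 240, 320, 3) NHWC tensor here.
    print(getattr(nvcv_tensor, "layout", None), getattr(nvcv_tensor, "shape", None))

    back = F.nvcv_to_tensor(nvcv_tensor)  # back to a torch.Tensor in the CHW convention
    print(back.shape)
```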

File tree

6 files changed: +544 −0 lines changed

test/test_cvcuda.py

Lines changed: 196 additions & 0 deletions
@@ -0,0 +1,196 @@
```python
import pytest
import torch
from torchvision import _is_cvcuda_available
from torchvision.transforms.v2 import functional as F

CVCUDA_AVAILABLE = _is_cvcuda_available()
CUDA_AVAILABLE = torch.cuda.is_available()


if CVCUDA_AVAILABLE:
    import nvcv


@pytest.mark.skipif(CVCUDA_AVAILABLE is False, reason="test requires CVCUDA")
@pytest.mark.skipif(CUDA_AVAILABLE is False, reason="test requires CUDA")
class TestToNvcvTensor:
    """Tests for to_nvcv_tensor function following patterns from TestToPil"""

    def test_1_channel_uint8_tensor_to_nvcv_tensor(self):
        img_data = torch.ByteTensor(1, 4, 4).random_(0, 255).cuda()
        nvcv_img = F.to_nvcv_tensor(img_data)
        # Check that the conversion succeeded and format is correct
        assert nvcv_img is not None

    def test_1_channel_int16_tensor_to_nvcv_tensor(self):
        img_data = torch.ShortTensor(1, 4, 4).random_().cuda()
        nvcv_img = F.to_nvcv_tensor(img_data)
        assert nvcv_img is not None

    def test_1_channel_int32_tensor_to_nvcv_tensor(self):
        img_data = torch.IntTensor(1, 4, 4).random_().cuda()
        nvcv_img = F.to_nvcv_tensor(img_data)
        assert nvcv_img is not None

    def test_1_channel_float32_tensor_to_nvcv_tensor(self):
        img_data = torch.Tensor(1, 4, 4).uniform_().cuda()
        nvcv_img = F.to_nvcv_tensor(img_data)
        assert nvcv_img is not None

    def test_2_channel_uint8_tensor_to_nvcv_tensor(self):
        img_data = torch.ByteTensor(2, 4, 4).random_(0, 255).cuda()
        # NVCV doesn't support 2-channel uint8 images
        with pytest.raises(TypeError, match="Unsupported dtype.*for 2-channel image"):
            F.to_nvcv_tensor(img_data)

    def test_2_channel_float32_tensor_to_nvcv_tensor(self):
        img_data = torch.Tensor(2, 4, 4).uniform_().cuda()
        nvcv_img = F.to_nvcv_tensor(img_data)
        assert nvcv_img is not None

    def test_3_channel_uint8_tensor_to_nvcv_tensor(self):
        img_data = torch.ByteTensor(3, 4, 4).random_(0, 255).cuda()
        nvcv_img = F.to_nvcv_tensor(img_data)
        assert nvcv_img is not None

    def test_3_channel_float32_tensor_to_nvcv_tensor(self):
        img_data = torch.Tensor(3, 4, 4).uniform_().cuda()
        nvcv_img = F.to_nvcv_tensor(img_data)
        assert nvcv_img is not None

    def test_4_channel_uint8_tensor_to_nvcv_tensor(self):
        img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).cuda()
        nvcv_img = F.to_nvcv_tensor(img_data)
        assert nvcv_img is not None

    def test_4_channel_float32_tensor_to_nvcv_tensor(self):
        img_data = torch.Tensor(4, 4, 4).uniform_().cuda()
        nvcv_img = F.to_nvcv_tensor(img_data)
        assert nvcv_img is not None

    def test_2d_uint8_tensor_to_nvcv_tensor(self):
        img_data = torch.ByteTensor(4, 4).random_(0, 255).cuda()
        nvcv_img = F.to_nvcv_tensor(img_data)
        assert nvcv_img is not None

    def test_2d_float32_tensor_to_nvcv_tensor(self):
        img_data = torch.Tensor(4, 4).uniform_().cuda()
        nvcv_img = F.to_nvcv_tensor(img_data)
        assert nvcv_img is not None

    def test_1_channel_uint8_ndarray_to_nvcv_tensor(self):
        img_data = torch.ByteTensor(4, 4, 1).random_(0, 255).numpy()
        nvcv_img = F.to_nvcv_tensor(img_data)
        assert nvcv_img is not None

    def test_3_channel_uint8_ndarray_to_nvcv_tensor(self):
        img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
        nvcv_img = F.to_nvcv_tensor(img_data)
        assert nvcv_img is not None

    def test_4_channel_uint8_ndarray_to_nvcv_tensor(self):
        img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy()
        nvcv_img = F.to_nvcv_tensor(img_data)
        assert nvcv_img is not None

    def test_explicit_format_rgb8(self):
        img_data = torch.ByteTensor(3, 4, 4).random_(0, 255).cuda()
        nvcv_img = F.to_nvcv_tensor(img_data, format=nvcv.Format.RGB8)
        assert nvcv_img is not None

    def test_explicit_format_bgr8(self):
        img_data = torch.ByteTensor(3, 4, 4).random_(0, 255).cuda()
        nvcv_img = F.to_nvcv_tensor(img_data, format=nvcv.Format.BGR8)
        assert nvcv_img is not None

    def test_explicit_format_hsv8(self):
        img_data = torch.ByteTensor(3, 4, 4).random_(0, 255).cuda()
        # HSV8 should work for 3-channel images
        nvcv_img = F.to_nvcv_tensor(img_data, format=nvcv.Format.HSV8)
        assert nvcv_img is not None

    def test_explicit_format_rgba8(self):
        img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).cuda()
        nvcv_img = F.to_nvcv_tensor(img_data, format=nvcv.Format.RGBA8)
        assert nvcv_img is not None

    def test_explicit_format_bgra8(self):
        img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).cuda()
        # BGRA8 should work for 4-channel images
        nvcv_img = F.to_nvcv_tensor(img_data, format=nvcv.Format.BGRA8)
        assert nvcv_img is not None

    def test_invalid_input_type(self):
        with pytest.raises(TypeError, match=r"pic should be Tensor or ndarray"):
            F.to_nvcv_tensor("invalid_input")

    def test_invalid_dimensions(self):
        # Test 1D array (too few dimensions)
        with pytest.raises(ValueError, match=r"pic should be 2/3/4 dimensional"):
            F.to_nvcv_tensor(torch.ByteTensor(4).cuda())

        # Test 5D array (too many dimensions)
        with pytest.raises(ValueError, match=r"pic should be 2/3/4 dimensional"):
            F.to_nvcv_tensor(torch.ByteTensor(1, 1, 3, 4, 4).cuda())

    def test_too_many_channels(self):
        with pytest.raises(ValueError, match=r"pic should not have > 4 channels"):
            F.to_nvcv_tensor(torch.ByteTensor(5, 4, 4).random_(0, 255).cuda())

    def test_unsupported_dtype_for_channels(self):
        # Float64 is not supported
        img_data = torch.DoubleTensor(3, 4, 4).uniform_().cuda()
        with pytest.raises(TypeError, match=r"Unsupported dtype"):
            F.to_nvcv_tensor(img_data)


def make_nvcv_image(num_channels=3, dtype=torch.uint8):
    """Helper function to create NVCV Tensor for testing"""
    if dtype == torch.uint8:
        img_data = torch.ByteTensor(num_channels, 4, 4).random_(0, 255).cuda()
    else:
        img_data = torch.Tensor(num_channels, 4, 4).uniform_().cuda()
    return F.to_nvcv_tensor(img_data)


def transform_cls_to_functional(get_transform_cls):
    def wrapper(inpt):
        transform_cls = get_transform_cls()
        return transform_cls()(inpt)

    return wrapper


@pytest.mark.skipif(CVCUDA_AVAILABLE is False, reason="test requires CVCUDA")
@pytest.mark.skipif(CUDA_AVAILABLE is False, reason="test requires CUDA")
class TestNVCVToTensor:
    @pytest.mark.parametrize("num_channels", [1, 3, 4])
    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
    @pytest.mark.parametrize(
        "fn",
        [
            "functional",
            "transform",
        ],
    )
    def test_functional_and_transform(self, num_channels, dtype, fn):
        input = make_nvcv_image(num_channels=num_channels, dtype=dtype)

        # Delay function reference until test execution time
        if fn == "functional":
            fn_ref = F.nvcv_to_tensor
        else:  # fn == "transform"
            fn_ref = transform_cls_to_functional(
                lambda: __import__("torchvision.transforms.v2", fromlist=["NVCVToTensor"]).NVCVToTensor
            )

        output = fn_ref(input)

        assert isinstance(output, torch.Tensor)
        # Convert input to tensor to compare sizes
        input_tensor = F.nvcv_to_tensor(input)
        assert F.get_size(output) == F.get_size(input_tensor)

    def test_functional_error(self):
        with pytest.raises(TypeError, match="nvcv_img should be NVCV Tensor"):
            F.nvcv_to_tensor(object())
```

torchvision/__init__.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -99,6 +99,15 @@ def _is_tracing():
     return torch._C._get_tracing_state()
 
 
+def _is_cvcuda_available() -> bool:
+    try:
+        import cvcuda  # type: ignore[import-not-found]
+        import nvcv  # type: ignore[import-not-found]
+    except ImportError:
+        return False
+    return True
+
+
 def disable_beta_transforms_warning():
     # Noop, only exists to avoid breaking existing code.
     # See https://github.com/pytorch/vision/issues/7896
```
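
A minimal sketch of how callers are meant to use this availability guard (mirroring the conditional imports later in this commit); nothing beyond `_is_cvcuda_available` itself is added by the diff.

```python
from torchvision import _is_cvcuda_available

# The helper only reports whether the optional cvcuda/nvcv packages import cleanly,
# so downstream code can branch without taking a hard dependency on CV-CUDA.
if _is_cvcuda_available():
    import nvcv  # safe here: the availability check already imported it once

    print("CV-CUDA backend available")
else:
    print("CV-CUDA backend not installed; NVCV conversions are unavailable")
```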

torchvision/transforms/v2/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -59,3 +59,7 @@
 from ._utils import check_type, get_bounding_boxes, get_keypoints, has_all, has_any, query_chw, query_size
 
 from ._deprecated import ToTensor  # usort: skip
+from torchvision import _is_cvcuda_available
+
+if _is_cvcuda_available():
+    from ._cvcuda import NVCVToTensor, ToNVCVTensor
```
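
For illustration, a hedged sketch of using the conditionally exported class transforms; it assumes CV-CUDA and a CUDA device are available and uses only names introduced in this commit.

```python
import torch

from torchvision import _is_cvcuda_available

# Sketch only: ToNVCVTensor / NVCVToTensor exist on torchvision.transforms.v2
# only when CV-CUDA is installed, hence the guard.
if _is_cvcuda_available() and torch.cuda.is_available():
    import torchvision.transforms.v2 as T

    to_nvcv = T.ToNVCVTensor()  # format=None -> inferred from dtype and channel count
    to_torch = T.NVCVToTensor()

    img = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8, device="cuda")
    nvcv_tensor = to_nvcv(img)
    round_trip = to_torch(nvcv_tensor)  # back to a CHW torch.Tensor
```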
torchvision/transforms/v2/_cvcuda.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
```python
from torchvision.transforms.v2 import functional as F
from torchvision.utils import _log_api_usage_once


class ToNVCVTensor:
    """Convert a tensor or an ndarray to NVCV Tensor

    This transform does not support torchscript.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to an NVCV Tensor.

    Args:
        format (`nvcv.Format`_): color format specification from nvcv.Format enum (optional).
            If ``format`` is ``None`` (default) the format is inferred from the input data:

            - **1 channel images**: Inferred based on dtype
              - uint8 → U8, int16 → S16, int32 → S32, float32 → F32
            - **2 channel images**: float32 → _2F32 (only float32 is supported for 2-channel images)
            - **3 channel images**: Defaults to RGB-based formats
              - uint8 → RGB8, float32 → RGBf32
            - **4 channel images**: Defaults to RGBA-based formats
              - uint8 → RGBA8, float32 → RGBAf32

        Explicit format examples: nvcv.Format.RGB8, nvcv.Format.BGR8, nvcv.Format.HSV8,
        nvcv.Format.RGBA8, nvcv.Format.BGRA8

    .. _nvcv.Format: https://cvcuda.github.io/CV-CUDA/_python_api/nvcv/format.html
    """

    def __init__(self, format=None):
        _log_api_usage_once(self)
        self.format = format

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to NVCV Tensor.

        Returns:
            NVCV Tensor: Image converted to NVCV Tensor.

        """
        return F.to_nvcv_tensor(pic, self.format)

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        if self.format is not None:
            format_string += f"format={self.format}"
        format_string += ")"
        return format_string


class NVCVToTensor:
    """Convert an NVCV Image to a tensor of the same type - this does not scale values.

    This transform does not support torchscript.

    Converts an NVCV Image with H height, W width, and C channels to a PyTorch Tensor
    of shape (C x H x W). The conversion happens directly on GPU when the NVCV Image
    is stored on GPU, avoiding unnecessary data transfers.

    Example:
        >>> import nvcv
        >>> import torchvision.transforms.v2 as T
        >>> # Create an NVCV Image (320x240 RGB)
        >>> nvcv_img = nvcv.Image(nvcv.Size2D(320, 240), nvcv.Format.RGB8)
        >>> tensor = T.NVCVToTensor()(nvcv_img)
        >>> print(tensor.shape)
        torch.Size([3, 240, 320])
    """

    def __init__(self) -> None:
        _log_api_usage_once(self)

    def __call__(self, pic):
        """
        Args:
            pic (nvcv.Image): NVCV Image to be converted to tensor.

        Returns:
            Tensor: Converted image in CHW format.
        """
        return F.nvcv_to_tensor(pic)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
```
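
A short sketch of the format-inference rules described in the `ToNVCVTensor` docstring versus passing an explicit `nvcv.Format`; the enum members used here (RGB8, BGR8) are the ones exercised in the tests, and the sketch assumes CV-CUDA plus a CUDA device.

```python
import torch

from torchvision import _is_cvcuda_available

if _is_cvcuda_available() and torch.cuda.is_available():
    import nvcv

    from torchvision.transforms.v2 import ToNVCVTensor

    rgb_u8 = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8, device="cuda")

    # Inferred format: per the docstring above, a 3-channel uint8 image defaults to RGB8.
    inferred = ToNVCVTensor()(rgb_u8)

    # Explicit format: override the inference, e.g. reinterpret the same data as BGR.
    explicit = ToNVCVTensor(format=nvcv.Format.BGR8)(rgb_u8)

    print(ToNVCVTensor(format=nvcv.Format.BGR8))  # repr shows the configured format
```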

torchvision/transforms/v2/functional/__init__.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -165,3 +165,8 @@
 from ._type_conversion import pil_to_tensor, to_image, to_pil_image
 
 from ._deprecated import get_image_size, to_tensor  # usort: skip
+
+from torchvision import _is_cvcuda_available
+
+if _is_cvcuda_available():
+    from ._cvcuda import nvcv_to_tensor, to_nvcv_tensor
```
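
Since these functionals are only exported when CV-CUDA is importable, a guarded call site could look like the sketch below. The H x W x C ndarray layout follows the `ToNVCVTensor` docstring; whether the returned tensor lands on CPU or GPU is not specified by this diff, so treat that as an open assumption.

```python
import numpy as np
import torch

import torchvision.transforms.v2.functional as F
from torchvision import _is_cvcuda_available

# Sketch: ndarray inputs are expected in H x W x C layout, while torch.Tensor
# inputs follow the C x H x W convention. Guarded the same way as the tests.
if _is_cvcuda_available() and torch.cuda.is_available():
    hwc = np.random.randint(0, 256, size=(4, 4, 3), dtype=np.uint8)
    nvcv_tensor = F.to_nvcv_tensor(hwc)
    chw = F.nvcv_to_tensor(nvcv_tensor)  # back to a torch.Tensor
```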
