Skip to content

Commit 966b53a

Browse files
rename nvcv as cvcuda
1 parent 73ca541 commit 966b53a

File tree

7 files changed

+154
-124
lines changed

7 files changed

+154
-124
lines changed

test/common_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
2121
from torchvision import io, tv_tensors
2222
from torchvision.transforms._functional_tensor import _max_value as get_max_value
23-
from torchvision.transforms.v2.functional import to_image, to_nvcv_tensor, to_pil_image
23+
from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image
2424
from torchvision.utils import _Image_fromarray
2525

2626

@@ -400,8 +400,8 @@ def make_image_pil(*args, **kwargs):
400400
return to_pil_image(make_image(*args, **kwargs))
401401

402402

403-
def make_image_nvcv(*args, **kwargs):
404-
return to_nvcv_tensor(make_image(*args, **kwargs))
403+
def make_image_cvcuda(*args, **kwargs):
404+
return to_cvcuda_tensor(make_image(*args, **kwargs))
405405

406406

407407
def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"):

test/test_transforms_v2.py

Lines changed: 95 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
make_bounding_boxes,
3030
make_detection_masks,
3131
make_image,
32-
make_image_nvcv,
32+
make_image_cvcuda,
3333
make_image_pil,
3434
make_image_tensor,
3535
make_keypoints,
@@ -6740,79 +6740,105 @@ def test_functional_error(self):
67406740

67416741
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
67426742
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="test requires CUDA")
6743-
class TestToNVCVTensor:
6744-
"""Tests for to_nvcv_tensor function following patterns from TestToPil"""
6745-
6746-
def test_1_channel_uint8_tensor_to_nvcv_tensor(self):
6747-
img_data = torch.randint(0, 256, (1, 4, 4), dtype=torch.uint8, device="cuda")
6748-
nvcv_img = F.to_nvcv_tensor(img_data)
6743+
class TestToCVCUDATensor:
6744+
"""Tests for to_cvcuda_tensor function following patterns from TestToPil"""
6745+
6746+
def test_1_channel_uint8_tensor_to_cvcuda_tensor(self):
6747+
# Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
6748+
img_data = torch.randint(0, 256, (1, 4, 4), dtype=torch.uint8)
6749+
img_data = img_data.cuda()
6750+
cvcuda_img = F.to_cvcuda_tensor(img_data)
67496751
# Check that the conversion succeeded and format is correct
6750-
assert nvcv_img is not None
6751-
6752-
def test_1_channel_int16_tensor_to_nvcv_tensor(self):
6753-
img_data = torch.randint(0, 256, (1, 4, 4), dtype=torch.int16, device="cuda")
6754-
nvcv_img = F.to_nvcv_tensor(img_data)
6755-
assert nvcv_img is not None
6756-
6757-
def test_1_channel_int32_tensor_to_nvcv_tensor(self):
6758-
img_data = torch.randint(0, 256, (1, 4, 4), dtype=torch.int32, device="cuda")
6759-
nvcv_img = F.to_nvcv_tensor(img_data)
6760-
assert nvcv_img is not None
6761-
6762-
def test_1_channel_float32_tensor_to_nvcv_tensor(self):
6763-
img_data = torch.rand(1, 4, 4, device="cuda")
6764-
nvcv_img = F.to_nvcv_tensor(img_data)
6765-
assert nvcv_img is not None
6766-
6767-
def test_3_channel_uint8_tensor_to_nvcv_tensor(self):
6768-
img_data = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8, device="cuda")
6769-
nvcv_img = F.to_nvcv_tensor(img_data)
6770-
assert nvcv_img is not None
6771-
6772-
def test_3_channel_float32_tensor_to_nvcv_tensor(self):
6773-
img_data = torch.rand(3, 4, 4, device="cuda")
6774-
nvcv_img = F.to_nvcv_tensor(img_data)
6775-
assert nvcv_img is not None
6752+
assert cvcuda_img is not None
6753+
6754+
def test_1_channel_int16_tensor_to_cvcuda_tensor(self):
6755+
# Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
6756+
img_data = torch.randint(0, 256, (1, 4, 4), dtype=torch.int16)
6757+
img_data = img_data.cuda()
6758+
cvcuda_img = F.to_cvcuda_tensor(img_data)
6759+
assert cvcuda_img is not None
6760+
6761+
def test_1_channel_int32_tensor_to_cvcuda_tensor(self):
6762+
# Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
6763+
img_data = torch.randint(0, 256, (1, 4, 4), dtype=torch.int32)
6764+
img_data = img_data.cuda()
6765+
cvcuda_img = F.to_cvcuda_tensor(img_data)
6766+
assert cvcuda_img is not None
6767+
6768+
def test_1_channel_float32_tensor_to_cvcuda_tensor(self):
6769+
# Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
6770+
img_data = torch.rand(1, 4, 4)
6771+
img_data = img_data.cuda()
6772+
cvcuda_img = F.to_cvcuda_tensor(img_data)
6773+
assert cvcuda_img is not None
6774+
6775+
def test_3_channel_uint8_tensor_to_cvcuda_tensor(self):
6776+
# Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
6777+
img_data = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
6778+
img_data = img_data.cuda()
6779+
cvcuda_img = F.to_cvcuda_tensor(img_data)
6780+
assert cvcuda_img is not None
6781+
6782+
def test_3_channel_float32_tensor_to_cvcuda_tensor(self):
6783+
# Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
6784+
img_data = torch.rand(3, 4, 4)
6785+
img_data = img_data.cuda()
6786+
cvcuda_img = F.to_cvcuda_tensor(img_data)
6787+
assert cvcuda_img is not None
67766788

67776789
def test_unsupported_num_channels(self):
67786790
# Test 2-channel image (not supported)
6779-
img_data = torch.rand(2, 5, 5, device="cuda")
6791+
# Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
6792+
img_data = torch.rand(2, 5, 5)
6793+
img_data = img_data.cuda()
67806794
with pytest.raises(ValueError, match="Only 1 and 3 channel images are supported"):
6781-
F.to_nvcv_tensor(img_data)
6795+
F.to_cvcuda_tensor(img_data)
67826796

67836797
# Test 4-channel image (not supported)
6784-
img_data = torch.randint(0, 256, (4, 5, 5), dtype=torch.uint8, device="cuda")
6798+
img_data = torch.randint(0, 256, (4, 5, 5), dtype=torch.uint8)
6799+
img_data = img_data.cuda()
67856800
with pytest.raises(ValueError, match="Only 1 and 3 channel images are supported"):
6786-
F.to_nvcv_tensor(img_data)
6801+
F.to_cvcuda_tensor(img_data)
67876802

67886803
def test_invalid_input_type(self):
67896804
with pytest.raises(TypeError, match=r"pic should be `torch.Tensor`"):
6790-
F.to_nvcv_tensor("invalid_input")
6805+
F.to_cvcuda_tensor("invalid_input")
67916806

67926807
def test_invalid_dimensions(self):
67936808
# Test 1D array (too few dimensions)
6809+
# Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
67946810
with pytest.raises(ValueError, match=r"pic should be 3 or 4 dimensional"):
6795-
F.to_nvcv_tensor(torch.randint(0, 256, (4,), dtype=torch.uint8, device="cuda"))
6811+
img_data = torch.randint(0, 256, (4,), dtype=torch.uint8)
6812+
img_data = img_data.cuda()
6813+
F.to_cvcuda_tensor(img_data)
67966814

67976815
# Test 2D array (no longer supported)
67986816
with pytest.raises(ValueError, match=r"pic should be 3 or 4 dimensional"):
6799-
F.to_nvcv_tensor(torch.randint(0, 256, (4, 4), dtype=torch.uint8, device="cuda"))
6817+
img_data = torch.randint(0, 256, (4, 4), dtype=torch.uint8)
6818+
img_data = img_data.cuda()
6819+
F.to_cvcuda_tensor(img_data)
68006820

68016821
# Test 5D array (too many dimensions)
68026822
with pytest.raises(ValueError, match=r"pic should be 3 or 4 dimensional"):
6803-
F.to_nvcv_tensor(torch.randint(0, 256, (1, 1, 3, 4, 4), dtype=torch.uint8, device="cuda"))
6823+
img_data = torch.randint(0, 256, (1, 1, 3, 4, 4), dtype=torch.uint8)
6824+
img_data = img_data.cuda()
6825+
F.to_cvcuda_tensor(img_data)
68046826

6805-
def test_float64_tensor_to_nvcv_tensor(self):
6827+
def test_float64_tensor_to_cvcuda_tensor(self):
68066828
# Test single channel float64 (F64 format is supported)
6807-
img_data = torch.rand(1, 4, 4, dtype=torch.float64, device="cuda")
6808-
nvcv_img = F.to_nvcv_tensor(img_data)
6809-
assert nvcv_img is not None
6829+
# Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
6830+
img_data = torch.rand(1, 4, 4, dtype=torch.float64)
6831+
img_data = img_data.cuda()
6832+
cvcuda_img = F.to_cvcuda_tensor(img_data)
6833+
assert cvcuda_img is not None
68106834

68116835
def test_float64_rgb_not_supported(self):
68126836
# Test 3-channel float64 is NOT supported (no RGBf64 format in CV-CUDA)
6813-
img_data = torch.rand(3, 4, 4, dtype=torch.float64, device="cuda")
6837+
# Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
6838+
img_data = torch.rand(3, 4, 4, dtype=torch.float64)
6839+
img_data = img_data.cuda()
68146840
with pytest.raises(TypeError, match=r"Unsupported dtype"):
6815-
F.to_nvcv_tensor(img_data)
6841+
F.to_cvcuda_tensor(img_data)
68166842

68176843
@pytest.mark.parametrize("num_channels", [1, 3])
68186844
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32, torch.float64])
@@ -6822,15 +6848,17 @@ def test_round_trip(self, num_channels, dtype):
68226848
pytest.skip("float64 is not supported for 3-channel RGB images")
68236849

68246850
# Setup: Create a tensor in CHW format (PyTorch standard)
6851+
# Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
68256852
if dtype == torch.uint8:
6826-
original_tensor = torch.randint(0, 256, (num_channels, 4, 4), dtype=dtype, device="cuda")
6853+
original_tensor = torch.randint(0, 256, (num_channels, 4, 4), dtype=dtype)
68276854
else:
6828-
original_tensor = torch.rand(num_channels, 4, 4, dtype=dtype, device="cuda")
6855+
original_tensor = torch.rand(num_channels, 4, 4, dtype=dtype)
6856+
original_tensor = original_tensor.cuda()
68296857

6830-
# Execute: Convert to NVCV and back to tensor
6831-
# CHW -> (to_nvcv_tensor) -> NVCV NHWC -> (nvcv_to_tensor) -> NCHW
6832-
nvcv_tensor = F.to_nvcv_tensor(original_tensor)
6833-
result_tensor = F.nvcv_to_tensor(nvcv_tensor)
6858+
# Execute: Convert to CV-CUDA and back to tensor
6859+
# CHW -> (to_cvcuda_tensor) -> CV-CUDA NHWC -> (cvcuda_to_tensor) -> NCHW
6860+
cvcuda_tensor = F.to_cvcuda_tensor(original_tensor)
6861+
result_tensor = F.cvcuda_to_tensor(cvcuda_tensor)
68346862

68356863
# Remove batch dimension that was added during conversion since original was unbatched
68366864
result_tensor = result_tensor.squeeze(0)
@@ -6847,15 +6875,17 @@ def test_round_trip_batched(self, num_channels, dtype, batch_size):
68476875
pytest.skip("float64 is not supported for 3-channel RGB images")
68486876

68496877
# Setup: Create a batched tensor in NCHW format
6878+
# Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
68506879
if dtype == torch.uint8:
6851-
original_tensor = torch.randint(0, 256, (batch_size, num_channels, 4, 4), dtype=dtype, device="cuda")
6880+
original_tensor = torch.randint(0, 256, (batch_size, num_channels, 4, 4), dtype=dtype)
68526881
else:
6853-
original_tensor = torch.rand(batch_size, num_channels, 4, 4, dtype=dtype, device="cuda")
6882+
original_tensor = torch.rand(batch_size, num_channels, 4, 4, dtype=dtype)
6883+
original_tensor = original_tensor.cuda()
68546884

6855-
# Execute: Convert to NVCV and back to tensor
6856-
# NCHW -> (to_nvcv_tensor) -> NVCV NHWC -> (nvcv_to_tensor) -> NCHW
6857-
nvcv_tensor = F.to_nvcv_tensor(original_tensor)
6858-
result_tensor = F.nvcv_to_tensor(nvcv_tensor)
6885+
# Execute: Convert to CV-CUDA and back to tensor
6886+
# NCHW -> (to_cvcuda_tensor) -> CV-CUDA NHWC -> (cvcuda_to_tensor) -> NCHW
6887+
cvcuda_tensor = F.to_cvcuda_tensor(original_tensor)
6888+
result_tensor = F.cvcuda_to_tensor(cvcuda_tensor)
68596889

68606890
# Assert: The round-trip conversion preserves the original batched tensor exactly
68616891
torch.testing.assert_close(result_tensor, original_tensor, rtol=0, atol=0)
@@ -6865,25 +6895,25 @@ def test_round_trip_batched(self, num_channels, dtype, batch_size):
68656895

68666896
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
68676897
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="test requires CUDA")
6868-
class TestNVCVToTensor:
6898+
class TestCVDUDAToTensor:
68696899
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
68706900
@pytest.mark.parametrize(
68716901
"fn",
6872-
[F.nvcv_to_tensor, transform_cls_to_functional(transforms.NVCVToTensor)],
6902+
[F.cvcuda_to_tensor, transform_cls_to_functional(transforms.CVCUDAToTensor)],
68736903
)
68746904
def test_functional_and_transform(self, color_space, fn):
6875-
input = make_image_nvcv(color_space=color_space)
6905+
input = make_image_cvcuda(color_space=color_space)
68766906

68776907
output = fn(input)
68786908

68796909
assert isinstance(output, torch.Tensor)
68806910
# Convert input to tensor to compare sizes
6881-
input_tensor = F.nvcv_to_tensor(input)
6911+
input_tensor = F.cvcuda_to_tensor(input)
68826912
assert F.get_size(output) == F.get_size(input_tensor)
68836913

68846914
def test_functional_error(self):
6885-
with pytest.raises(TypeError, match="nvcv_img should be `nvcv.Tensor`"):
6886-
F.nvcv_to_tensor(object())
6915+
with pytest.raises(TypeError, match="cvcuda_img should be `cvcuda.Tensor`"):
6916+
F.cvcuda_to_tensor(object())
68876917

68886918

68896919
class TestLambda:

torchvision/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ def _is_tracing():
102102
def _is_cvcuda_available() -> bool:
103103
try:
104104
import cvcuda # type: ignore[import-not-found]
105-
import nvcv # type: ignore[import-not-found]
106105

107106
return True
108107
except ImportError:

torchvision/transforms/v2/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
ToDtype,
5656
)
5757
from ._temporal import UniformTemporalSubsample
58-
from ._type_conversion import NVCVToTensor, PILToTensor, ToImage, ToNVCVTensor, ToPILImage, ToPureTensor
58+
from ._type_conversion import CVCUDAToTensor, PILToTensor, ToCVCUDATensor, ToImage, ToPILImage, ToPureTensor
5959
from ._utils import check_type, get_bounding_boxes, get_keypoints, has_all, has_any, query_chw, query_size
6060

6161
from ._deprecated import ToTensor # usort: skip

torchvision/transforms/v2/_type_conversion.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,12 @@ def transform(self, inpt: Any, params: dict[str, Any]) -> torch.Tensor:
9292
return inpt.as_subclass(torch.Tensor)
9393

9494

95-
class ToNVCVTensor:
96-
"""Convert a torch.Tensor to nvcv.Tensor
95+
class ToCVCUDATensor:
96+
"""Convert a torch.Tensor to cvcuda.Tensor
9797
9898
This transform does not support torchscript.
9999
100-
Converts a torch.*Tensor of shape C x H x W to a nvcv.Tensor.
100+
Converts a torch.*Tensor of shape C x H x W to a cvcuda.Tensor.
101101
Only 1-channel and 3-channel images are supported.
102102
"""
103103

@@ -107,38 +107,40 @@ def __init__(self):
107107
def __call__(self, pic):
108108
"""
109109
Args:
110-
pic (torch.Tensor): Image to be converted to nvcv.Tensor.
110+
pic (torch.Tensor): Image to be converted to cvcuda.Tensor.
111111
112112
Returns:
113-
nvcv.Tensor: Image converted to nvcv.Tensor.
113+
cvcuda.Tensor: Image converted to cvcuda.Tensor.
114114
115115
"""
116-
return F.to_nvcv_tensor(pic)
116+
return F.to_cvcuda_tensor(pic)
117117

118118
def __repr__(self) -> str:
119119
return f"{self.__class__.__name__}()"
120120

121121

122-
class NVCVToTensor:
123-
"""Convert a `nvcv.Tensor` to a `torch.Tensor` of the same type - this does not scale values.
122+
class CVCUDAToTensor:
123+
"""Convert a `cvcuda.Tensor` to a `torch.Tensor` of the same type - this does not scale values.
124124
125125
This transform does not support torchscript.
126126
127-
Converts a `nvcv.Tensor` to a `torch.Tensor`. Supports both batched and unbatched inputs:
127+
Converts a `cvcuda.Tensor` to a `torch.Tensor`. Supports both batched and unbatched inputs:
128128
- Unbatched: (H, W, C) or (H, W) → (C, H, W) or (1, H, W)
129129
- Batched: (N, H, W, C) or (N, H, W) → (N, C, H, W) or (N, 1, H, W)
130130
131-
The conversion happens directly on GPU when the `nvcv.Tensor` is stored on GPU,
131+
The conversion happens directly on GPU when the `cvcuda.Tensor` is stored on GPU,
132132
avoiding unnecessary data transfers.
133133
134134
Example:
135-
>>> import nvcv
135+
>>> import cvcuda
136136
>>> import torchvision.transforms.v2 as T
137-
>>> # Create an NVCV Image (320x240 RGB)
138-
>>> nvcv_img = nvcv.Image(nvcv.Size2D(320, 240), nvcv.Format.RGB8)
139-
>>> tensor = T.NVCVToTensor()(nvcv_img)
137+
>>> # Create a CV-CUDA tensor (320x240 RGB)
138+
>>> # Note: In CV-CUDA 0.16.0+, Image/Tensor creation uses cvcuda module
139+
>>> img_tensor = torch.randint(0, 255, (1, 240, 320, 3), dtype=torch.uint8, device="cuda")
140+
>>> cvcuda_tensor = cvcuda.as_tensor(img_tensor, cvcuda.TensorLayout.NHWC)
141+
>>> tensor = T.CVCUDAToTensor()(cvcuda_tensor)
140142
>>> print(tensor.shape)
141-
torch.Size([3, 240, 320])
143+
torch.Size([1, 3, 240, 320])
142144
"""
143145

144146
def __init__(self) -> None:
@@ -147,12 +149,12 @@ def __init__(self) -> None:
147149
def __call__(self, pic):
148150
"""
149151
Args:
150-
pic (nvcv.Image): NVCV Image to be converted to tensor.
152+
pic (cvcuda.Tensor): CV-CUDA Tensor to be converted to tensor.
151153
152154
Returns:
153155
Tensor: Converted image in CHW format.
154156
"""
155-
return F.nvcv_to_tensor(pic)
157+
return F.cvcuda_to_tensor(pic)
156158

157159
def __repr__(self) -> str:
158160
return f"{self.__class__.__name__}()"

torchvision/transforms/v2/functional/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,6 @@
162162
to_dtype_video,
163163
)
164164
from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video
165-
from ._type_conversion import nvcv_to_tensor, pil_to_tensor, to_image, to_nvcv_tensor, to_pil_image
165+
from ._type_conversion import cvcuda_to_tensor, pil_to_tensor, to_cvcuda_tensor, to_image, to_pil_image
166166

167167
from ._deprecated import get_image_size, to_tensor # usort: skip

0 commit comments

Comments
 (0)