Skip to content

Commit 54584dd

Browse files
Restrict _infer_nvcv_format support to 3-channel images
1 parent e520c9f commit 54584dd

File tree

2 files changed

+40
-103
lines changed

2 files changed

+40
-103
lines changed

test/test_transforms_v2.py

Lines changed: 20 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -6743,27 +6743,6 @@ def test_functional_error(self):
67436743
class TestToNVCVTensor:
67446744
"""Tests for to_nvcv_tensor function following patterns from TestToPil"""
67456745

6746-
def test_1_channel_uint8_tensor_to_nvcv_tensor(self):
6747-
img_data = torch.randint(0, 256, (1, 4, 4), dtype=torch.uint8, device="cuda")
6748-
nvcv_img = F.to_nvcv_tensor(img_data)
6749-
# Check that the conversion succeeded and format is correct
6750-
assert nvcv_img is not None
6751-
6752-
def test_1_channel_int16_tensor_to_nvcv_tensor(self):
6753-
img_data = torch.randint(0, 256, (1, 4, 4), dtype=torch.int16, device="cuda")
6754-
nvcv_img = F.to_nvcv_tensor(img_data)
6755-
assert nvcv_img is not None
6756-
6757-
def test_1_channel_int32_tensor_to_nvcv_tensor(self):
6758-
img_data = torch.randint(0, 256, (1, 4, 4), dtype=torch.int32, device="cuda")
6759-
nvcv_img = F.to_nvcv_tensor(img_data)
6760-
assert nvcv_img is not None
6761-
6762-
def test_1_channel_float32_tensor_to_nvcv_tensor(self):
6763-
img_data = torch.rand(1, 4, 4, device="cuda")
6764-
nvcv_img = F.to_nvcv_tensor(img_data)
6765-
assert nvcv_img is not None
6766-
67676746
def test_3_channel_uint8_tensor_to_nvcv_tensor(self):
67686747
img_data = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8, device="cuda")
67696748
nvcv_img = F.to_nvcv_tensor(img_data)
@@ -6775,19 +6754,19 @@ def test_3_channel_float32_tensor_to_nvcv_tensor(self):
67756754
assert nvcv_img is not None
67766755

67776756
def test_unsupported_num_channels(self):
6778-
# Test 2-channel image (CHW format: 2 channels x 5 height x 5 width)
6779-
img_data = torch.rand(2, 5, 5, device="cuda")
6780-
with pytest.raises(ValueError, match="Only 1 and 3 channel images are supported"):
6757+
# Test 1-channel image (not supported)
6758+
img_data = torch.rand(1, 5, 5, device="cuda")
6759+
with pytest.raises(ValueError, match="Only 3-channel RGB images are supported"):
67816760
F.to_nvcv_tensor(img_data)
67826761

6783-
# Test 4-channel image (CHW format: 4 channels x 5 height x 5 width)
6784-
img_data = torch.randint(0, 256, (4, 5, 5), dtype=torch.uint8, device="cuda")
6785-
with pytest.raises(ValueError, match="Only 1 and 3 channel images are supported"):
6762+
# Test 2-channel image (not supported)
6763+
img_data = torch.rand(2, 5, 5, device="cuda")
6764+
with pytest.raises(ValueError, match="Only 3-channel RGB images are supported"):
67866765
F.to_nvcv_tensor(img_data)
67876766

6788-
# Test 5-channel image (CHW format: 5 channels x 5 height x 5 width)
6789-
img_data = torch.randint(0, 256, (5, 5, 5), dtype=torch.uint8, device="cuda")
6790-
with pytest.raises(ValueError, match="Only 1 and 3 channel images are supported"):
6767+
# Test 4-channel image (not supported)
6768+
img_data = torch.randint(0, 256, (4, 5, 5), dtype=torch.uint8, device="cuda")
6769+
with pytest.raises(ValueError, match="Only 3-channel RGB images are supported"):
67916770
F.to_nvcv_tensor(img_data)
67926771

67936772
def test_invalid_input_type(self):
@@ -6807,30 +6786,19 @@ def test_invalid_dimensions(self):
68076786
with pytest.raises(ValueError, match=r"pic should be 3 or 4 dimensional"):
68086787
F.to_nvcv_tensor(torch.randint(0, 256, (1, 1, 3, 4, 4), dtype=torch.uint8, device="cuda"))
68096788

6810-
def test_float64_tensor_to_nvcv_tensor(self):
6811-
# Test single channel float64 (F64 format is supported)
6812-
img_data = torch.rand(1, 4, 4, dtype=torch.float64, device="cuda")
6813-
nvcv_img = F.to_nvcv_tensor(img_data)
6814-
assert nvcv_img is not None
6815-
68166789
def test_float64_rgb_not_supported(self):
68176790
# Test 3-channel float64 is NOT supported (no RGBf64 format in CV-CUDA)
68186791
img_data = torch.rand(3, 4, 4, dtype=torch.float64, device="cuda")
68196792
with pytest.raises(TypeError, match=r"Unsupported dtype"):
68206793
F.to_nvcv_tensor(img_data)
68216794

6822-
@pytest.mark.parametrize("num_channels", [1, 3])
6823-
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32, torch.float64])
6824-
def test_round_trip(self, num_channels, dtype):
6825-
# Skip float64 for 3-channel (not supported by CV-CUDA)
6826-
if num_channels == 3 and dtype == torch.float64:
6827-
pytest.skip("float64 is not supported for 3-channel RGB images")
6828-
6829-
# Setup: Create a tensor in CHW format (PyTorch standard)
6795+
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
6796+
def test_round_trip(self, dtype):
6797+
# Setup: Create a 3-channel tensor in CHW format (PyTorch standard)
68306798
if dtype == torch.uint8:
6831-
original_tensor = torch.randint(0, 256, (num_channels, 4, 4), dtype=dtype, device="cuda")
6799+
original_tensor = torch.randint(0, 256, (3, 4, 4), dtype=dtype, device="cuda")
68326800
else:
6833-
original_tensor = torch.rand(num_channels, 4, 4, dtype=dtype, device="cuda")
6801+
original_tensor = torch.rand(3, 4, 4, dtype=dtype, device="cuda")
68346802

68356803
# Execute: Convert to NVCV and back to tensor
68366804
# CHW -> (to_nvcv_tensor) -> NVCV NHWC -> (nvcv_to_tensor) -> CHW
@@ -6841,19 +6809,14 @@ def test_round_trip(self, num_channels, dtype):
68416809
# Use allclose for robust comparison that handles floating-point precision
68426810
assert torch.allclose(result_tensor, original_tensor, rtol=1e-5, atol=1e-7)
68436811

6844-
@pytest.mark.parametrize("num_channels", [1, 3])
6845-
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32, torch.float64])
6812+
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
68466813
@pytest.mark.parametrize("batch_size", [1, 2, 4])
6847-
def test_round_trip_batched(self, num_channels, dtype, batch_size):
6848-
# Skip float64 for 3-channel (not supported by CV-CUDA)
6849-
if num_channels == 3 and dtype == torch.float64:
6850-
pytest.skip("float64 is not supported for 3-channel RGB images")
6851-
6852-
# Setup: Create a batched tensor in NCHW format
6814+
def test_round_trip_batched(self, dtype, batch_size):
6815+
# Setup: Create a batched 3-channel tensor in NCHW format
68536816
if dtype == torch.uint8:
6854-
original_tensor = torch.randint(0, 256, (batch_size, num_channels, 4, 4), dtype=dtype, device="cuda")
6817+
original_tensor = torch.randint(0, 256, (batch_size, 3, 4, 4), dtype=dtype, device="cuda")
68556818
else:
6856-
original_tensor = torch.rand(batch_size, num_channels, 4, 4, dtype=dtype, device="cuda")
6819+
original_tensor = torch.rand(batch_size, 3, 4, 4, dtype=dtype, device="cuda")
68576820

68586821
# Execute: Convert to NVCV and back to tensor
68596822
# NCHW -> (to_nvcv_tensor) -> NVCV NHWC -> (nvcv_to_tensor) -> NCHW
@@ -6870,7 +6833,7 @@ def test_round_trip_batched(self, num_channels, dtype, batch_size):
68706833
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
68716834
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="test requires CUDA")
68726835
class TestNVCVToTensor:
6873-
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
6836+
@pytest.mark.parametrize("color_space", ["RGB"])
68746837
@pytest.mark.parametrize(
68756838
"fn",
68766839
[F.nvcv_to_tensor, transform_cls_to_functional(transforms.NVCVToTensor)],

torchvision/transforms/v2/functional/_type_conversion.py

Lines changed: 20 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -36,47 +36,31 @@ def _infer_nvcv_format(img_tensor: torch.Tensor):
3636
3737
Args:
3838
img_tensor: Tensor with shape (H, W, C) where C is number of channels.
39+
Only 3-channel RGB images are supported.
3940
4041
Returns:
41-
tuple: (nvcv_format, processed_tensor) where processed_tensor may have reduced dimensions
42-
for single channel images.
42+
tuple: (nvcv_format, processed_tensor)
4343
4444
Raises:
45-
TypeError: If dtype is not supported for the given number of channels.
46-
ValueError: If number of channels is not 1 or 3.
45+
TypeError: If dtype is not supported.
46+
ValueError: If number of channels is not 3.
4747
"""
4848
import nvcv # type: ignore[import-not-found]
4949

5050
num_channels = img_tensor.shape[2]
5151
dtype = img_tensor.dtype
5252

53-
# Handle single channel images
54-
if num_channels == 1:
55-
img_tensor = img_tensor[:, :, 0]
56-
if dtype == torch.uint8:
57-
return nvcv.Format.U8, img_tensor
58-
elif dtype == torch.int16:
59-
return nvcv.Format.S16, img_tensor
60-
elif dtype == torch.int32:
61-
return nvcv.Format.S32, img_tensor
62-
elif dtype == torch.float32:
63-
return nvcv.Format.F32, img_tensor
64-
elif dtype == torch.float64:
65-
return nvcv.Format.F64, img_tensor
66-
else:
67-
raise TypeError(f"Unsupported dtype {dtype} for single channel image")
68-
69-
# Handle 3 channel images (defaults to RGB)
70-
elif num_channels == 3:
71-
if dtype == torch.uint8:
72-
return nvcv.Format.RGB8, img_tensor
73-
elif dtype == torch.float32:
74-
return nvcv.Format.RGBf32, img_tensor
75-
else:
76-
# Note: CV-CUDA does not support float64 for RGB images (only F64 for single-channel)
77-
raise TypeError(f"Unsupported dtype {dtype} for 3-channel image")
78-
79-
raise ValueError(f"Only 1 and 3 channel images are supported. Got {num_channels} channels.")
53+
# Validate number of channels upfront
54+
if num_channels != 3:
55+
raise ValueError(f"Only 3-channel RGB images are supported. Got {num_channels} channels.")
56+
57+
# Handle 3 channel RGB images
58+
if dtype == torch.uint8:
59+
return nvcv.Format.RGB8, img_tensor
60+
elif dtype == torch.float32:
61+
return nvcv.Format.RGBf32, img_tensor
62+
else:
63+
raise TypeError(f"Unsupported dtype {dtype} for RGB images. Only uint8 and float32 are supported.")
8064

8165

8266
@torch.jit.unused
@@ -88,7 +72,7 @@ def to_nvcv_tensor(pic) -> "nvcv.Tensor":
8872
Args:
8973
pic (torch.Tensor): Image to be converted to nvcv.Tensor.
9074
Tensor can be in CHW format (unbatched) or NCHW format (batched).
91-
Only 1-channel and 3-channel images are supported.
75+
Only 3-channel RGB images are supported.
9276
9377
Returns:
9478
nvcv.Tensor: Image converted to nvcv.Tensor with NHWC layout.
@@ -116,22 +100,12 @@ def to_nvcv_tensor(pic) -> "nvcv.Tensor":
116100
# Convert NCHW -> NHWC
117101
img_tensor = img_tensor.permute(0, 2, 3, 1)
118102

119-
# Infer format from the first image
103+
# Infer format from the first image - this validates we have 3 channels
120104
sample_img = img_tensor[0]
121-
_, sample_img = _infer_nvcv_format(sample_img)
122-
123-
# If format inference removed channel dimension (single channel case)
124-
# apply the same transformation to all images
125-
if sample_img.ndim == 2:
126-
# Batched single channel case: remove channel dimension
127-
img_tensor = img_tensor.squeeze(-1)
128-
layout = nvcv.TensorLayout.NHW
129-
else:
130-
# Batched multi-channel
131-
layout = nvcv.TensorLayout.NHWC
105+
_infer_nvcv_format(sample_img)
132106

133-
# Convert to NVCV tensor with the appropriate layout
134-
return cvcuda.as_tensor(img_tensor.cuda().contiguous(), layout)
107+
# Convert to NVCV tensor with NHWC layout (always multi-channel RGB)
108+
return cvcuda.as_tensor(img_tensor.cuda().contiguous(), nvcv.TensorLayout.NHWC)
135109

136110

137111
@torch.jit.unused

0 commit comments

Comments
 (0)