Skip to content

Commit 0ed9367

Browse files
restore support for gray scale images
1 parent df3a91b commit 0ed9367

File tree

2 files changed

+83
-36
lines changed

2 files changed

+83
-36
lines changed

test/test_transforms_v2.py

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6743,6 +6743,27 @@ def test_functional_error(self):
67436743
class TestToNVCVTensor:
67446744
"""Tests for to_nvcv_tensor function following patterns from TestToPil"""
67456745

6746+
def test_1_channel_uint8_tensor_to_nvcv_tensor(self):
6747+
img_data = torch.randint(0, 256, (1, 4, 4), dtype=torch.uint8, device="cuda")
6748+
nvcv_img = F.to_nvcv_tensor(img_data)
6749+
# Check that the conversion succeeded and format is correct
6750+
assert nvcv_img is not None
6751+
6752+
def test_1_channel_int16_tensor_to_nvcv_tensor(self):
6753+
img_data = torch.randint(0, 256, (1, 4, 4), dtype=torch.int16, device="cuda")
6754+
nvcv_img = F.to_nvcv_tensor(img_data)
6755+
assert nvcv_img is not None
6756+
6757+
def test_1_channel_int32_tensor_to_nvcv_tensor(self):
6758+
img_data = torch.randint(0, 256, (1, 4, 4), dtype=torch.int32, device="cuda")
6759+
nvcv_img = F.to_nvcv_tensor(img_data)
6760+
assert nvcv_img is not None
6761+
6762+
def test_1_channel_float32_tensor_to_nvcv_tensor(self):
6763+
img_data = torch.rand(1, 4, 4, device="cuda")
6764+
nvcv_img = F.to_nvcv_tensor(img_data)
6765+
assert nvcv_img is not None
6766+
67466767
def test_3_channel_uint8_tensor_to_nvcv_tensor(self):
67476768
img_data = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8, device="cuda")
67486769
nvcv_img = F.to_nvcv_tensor(img_data)
@@ -6754,19 +6775,14 @@ def test_3_channel_float32_tensor_to_nvcv_tensor(self):
67546775
assert nvcv_img is not None
67556776

67566777
def test_unsupported_num_channels(self):
6757-
# Test 1-channel image (not supported)
6758-
img_data = torch.rand(1, 5, 5, device="cuda")
6759-
with pytest.raises(ValueError, match="Only 3-channel RGB images are supported"):
6760-
F.to_nvcv_tensor(img_data)
6761-
67626778
# Test 2-channel image (not supported)
67636779
img_data = torch.rand(2, 5, 5, device="cuda")
6764-
with pytest.raises(ValueError, match="Only 3-channel RGB images are supported"):
6780+
with pytest.raises(ValueError, match="Only 1 and 3 channel images are supported"):
67656781
F.to_nvcv_tensor(img_data)
67666782

67676783
# Test 4-channel image (not supported)
67686784
img_data = torch.randint(0, 256, (4, 5, 5), dtype=torch.uint8, device="cuda")
6769-
with pytest.raises(ValueError, match="Only 3-channel RGB images are supported"):
6785+
with pytest.raises(ValueError, match="Only 1 and 3 channel images are supported"):
67706786
F.to_nvcv_tensor(img_data)
67716787

67726788
def test_invalid_input_type(self):
@@ -6786,19 +6802,30 @@ def test_invalid_dimensions(self):
67866802
with pytest.raises(ValueError, match=r"pic should be 3 or 4 dimensional"):
67876803
F.to_nvcv_tensor(torch.randint(0, 256, (1, 1, 3, 4, 4), dtype=torch.uint8, device="cuda"))
67886804

6805+
def test_float64_tensor_to_nvcv_tensor(self):
6806+
# Test single channel float64 (F64 format is supported)
6807+
img_data = torch.rand(1, 4, 4, dtype=torch.float64, device="cuda")
6808+
nvcv_img = F.to_nvcv_tensor(img_data)
6809+
assert nvcv_img is not None
6810+
67896811
def test_float64_rgb_not_supported(self):
67906812
# Test 3-channel float64 is NOT supported (no RGBf64 format in CV-CUDA)
67916813
img_data = torch.rand(3, 4, 4, dtype=torch.float64, device="cuda")
67926814
with pytest.raises(TypeError, match=r"Unsupported dtype"):
67936815
F.to_nvcv_tensor(img_data)
67946816

6795-
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
6796-
def test_round_trip(self, dtype):
6797-
# Setup: Create a 3-channel tensor in CHW format (PyTorch standard)
6817+
@pytest.mark.parametrize("num_channels", [1, 3])
6818+
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32, torch.float64])
6819+
def test_round_trip(self, num_channels, dtype):
6820+
# Skip float64 for 3-channel (not supported by CV-CUDA)
6821+
if num_channels == 3 and dtype == torch.float64:
6822+
pytest.skip("float64 is not supported for 3-channel RGB images")
6823+
6824+
# Setup: Create a tensor in CHW format (PyTorch standard)
67986825
if dtype == torch.uint8:
6799-
original_tensor = torch.randint(0, 256, (3, 4, 4), dtype=dtype, device="cuda")
6826+
original_tensor = torch.randint(0, 256, (num_channels, 4, 4), dtype=dtype, device="cuda")
68006827
else:
6801-
original_tensor = torch.rand(3, 4, 4, dtype=dtype, device="cuda")
6828+
original_tensor = torch.rand(num_channels, 4, 4, dtype=dtype, device="cuda")
68026829

68036830
# Execute: Convert to NVCV and back to tensor
68046831
# CHW -> (to_nvcv_tensor) -> NVCV NHWC -> (nvcv_to_tensor) -> NCHW
@@ -6811,14 +6838,19 @@ def test_round_trip(self, dtype):
68116838
# Assert: The round-trip conversion preserves the original tensor exactly
68126839
torch.testing.assert_close(result_tensor, original_tensor, rtol=0, atol=0)
68136840

6814-
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
6841+
@pytest.mark.parametrize("num_channels", [1, 3])
6842+
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32, torch.float64])
68156843
@pytest.mark.parametrize("batch_size", [1, 2, 4])
6816-
def test_round_trip_batched(self, dtype, batch_size):
6817-
# Setup: Create a batched 3-channel tensor in NCHW format
6844+
def test_round_trip_batched(self, num_channels, dtype, batch_size):
6845+
# Skip float64 for 3-channel (not supported by CV-CUDA)
6846+
if num_channels == 3 and dtype == torch.float64:
6847+
pytest.skip("float64 is not supported for 3-channel RGB images")
6848+
6849+
# Setup: Create a batched tensor in NCHW format
68186850
if dtype == torch.uint8:
6819-
original_tensor = torch.randint(0, 256, (batch_size, 3, 4, 4), dtype=dtype, device="cuda")
6851+
original_tensor = torch.randint(0, 256, (batch_size, num_channels, 4, 4), dtype=dtype, device="cuda")
68206852
else:
6821-
original_tensor = torch.rand(batch_size, 3, 4, 4, dtype=dtype, device="cuda")
6853+
original_tensor = torch.rand(batch_size, num_channels, 4, 4, dtype=dtype, device="cuda")
68226854

68236855
# Execute: Convert to NVCV and back to tensor
68246856
# NCHW -> (to_nvcv_tensor) -> NVCV NHWC -> (nvcv_to_tensor) -> NCHW
@@ -6834,7 +6866,7 @@ def test_round_trip_batched(self, dtype, batch_size):
68346866
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
68356867
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="test requires CUDA")
68366868
class TestNVCVToTensor:
6837-
@pytest.mark.parametrize("color_space", ["RGB"])
6869+
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
68386870
@pytest.mark.parametrize(
68396871
"fn",
68406872
[F.nvcv_to_tensor, transform_cls_to_functional(transforms.NVCVToTensor)],

torchvision/transforms/v2/functional/_type_conversion.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -59,31 +59,46 @@ def _infer_nvcv_format(img_tensor: torch.Tensor):
5959
6060
Args:
6161
img_tensor: Tensor with shape (H, W, C) where C is number of channels.
62-
Only 3-channel RGB images are supported.
6362
6463
Returns:
65-
tuple: (nvcv_format, processed_tensor)
64+
tuple: (nvcv_format, processed_tensor) where processed_tensor may have reduced dimensions
65+
for single channel images.
6666
6767
Raises:
68-
TypeError: If dtype is not supported.
69-
ValueError: If number of channels is not 3.
68+
TypeError: If dtype is not supported for the given number of channels.
69+
ValueError: If number of channels is not 1 or 3.
7070
"""
7171
_, nvcv = _import_cvcuda_modules()
7272

7373
num_channels = img_tensor.shape[2]
7474
dtype = img_tensor.dtype
7575

76-
# Validate number of channels upfront
77-
if num_channels != 3:
78-
raise ValueError(f"Only 3-channel RGB images are supported. Got {num_channels} channels.")
79-
80-
# Handle 3 channel RGB images
81-
if dtype == torch.uint8:
82-
return nvcv.Format.RGB8, img_tensor
83-
elif dtype == torch.float32:
84-
return nvcv.Format.RGBf32, img_tensor
85-
else:
86-
raise TypeError(f"Unsupported dtype {dtype} for RGB images. Only uint8 and float32 are supported.")
76+
# Handle single channel images
77+
if num_channels == 1:
78+
if dtype == torch.uint8:
79+
return nvcv.Format.U8, img_tensor
80+
elif dtype == torch.int16:
81+
return nvcv.Format.S16, img_tensor
82+
elif dtype == torch.int32:
83+
return nvcv.Format.S32, img_tensor
84+
elif dtype == torch.float32:
85+
return nvcv.Format.F32, img_tensor
86+
elif dtype == torch.float64:
87+
return nvcv.Format.F64, img_tensor
88+
else:
89+
raise TypeError(f"Unsupported dtype {dtype} for single channel image")
90+
91+
# Handle 3 channel images (defaults to RGB)
92+
elif num_channels == 3:
93+
if dtype == torch.uint8:
94+
return nvcv.Format.RGB8, img_tensor
95+
elif dtype == torch.float32:
96+
return nvcv.Format.RGBf32, img_tensor
97+
else:
98+
# Note: CV-CUDA does not support float64 for RGB images (only F64 for single-channel)
99+
raise TypeError(f"Unsupported dtype {dtype} for 3-channel image")
100+
101+
raise ValueError(f"Only 1 and 3 channel images are supported. Got {num_channels} channels.")
87102

88103

89104
@torch.jit.unused
@@ -95,7 +110,7 @@ def to_nvcv_tensor(pic) -> "nvcv.Tensor":
95110
Args:
96111
pic (torch.Tensor): Image to be converted to nvcv.Tensor.
97112
Tensor can be in CHW format (unbatched) or NCHW format (batched).
98-
Only 3-channel RGB images are supported.
113+
Only 1-channel and 3-channel images are supported.
99114
100115
Returns:
101116
nvcv.Tensor: Image converted to nvcv.Tensor with NHWC layout.
@@ -122,11 +137,11 @@ def to_nvcv_tensor(pic) -> "nvcv.Tensor":
122137
# Convert NCHW -> NHWC
123138
img_tensor = img_tensor.permute(0, 2, 3, 1)
124139

125-
# Infer format from the first image - this validates we have 3 channels
140+
# Infer format from the first image
126141
sample_img = img_tensor[0]
127142
_infer_nvcv_format(sample_img)
128143

129-
# Convert to NVCV tensor with NHWC layout (always multi-channel RGB)
144+
# Convert to NVCV tensor with NHWC layout
130145
return cvcuda.as_tensor(img_tensor.cuda().contiguous(), nvcv.TensorLayout.NHWC)
131146

132147

0 commit comments

Comments
 (0)