Skip to content

Commit e520c9f

Browse files
add support for 1 channel float64
1 parent 6cb29af commit e520c9f

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

test/test_transforms_v2.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6807,15 +6807,25 @@ def test_invalid_dimensions(self):
68076807
with pytest.raises(ValueError, match=r"pic should be 3 or 4 dimensional"):
68086808
F.to_nvcv_tensor(torch.randint(0, 256, (1, 1, 3, 4, 4), dtype=torch.uint8, device="cuda"))
68096809

6810-
def test_unsupported_dtype_for_channels(self):
6811-
# Float64 is not supported
6810+
def test_float64_tensor_to_nvcv_tensor(self):
6811+
# Test single channel float64 (F64 format is supported)
6812+
img_data = torch.rand(1, 4, 4, dtype=torch.float64, device="cuda")
6813+
nvcv_img = F.to_nvcv_tensor(img_data)
6814+
assert nvcv_img is not None
6815+
6816+
def test_float64_rgb_not_supported(self):
6817+
# Test 3-channel float64 is NOT supported (no RGBf64 format in CV-CUDA)
68126818
img_data = torch.rand(3, 4, 4, dtype=torch.float64, device="cuda")
68136819
with pytest.raises(TypeError, match=r"Unsupported dtype"):
68146820
F.to_nvcv_tensor(img_data)
68156821

68166822
@pytest.mark.parametrize("num_channels", [1, 3])
6817-
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
6823+
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32, torch.float64])
68186824
def test_round_trip(self, num_channels, dtype):
6825+
# Skip float64 for 3-channel (not supported by CV-CUDA)
6826+
if num_channels == 3 and dtype == torch.float64:
6827+
pytest.skip("float64 is not supported for 3-channel RGB images")
6828+
68196829
# Setup: Create a tensor in CHW format (PyTorch standard)
68206830
if dtype == torch.uint8:
68216831
original_tensor = torch.randint(0, 256, (num_channels, 4, 4), dtype=dtype, device="cuda")
@@ -6832,9 +6842,13 @@ def test_round_trip(self, num_channels, dtype):
68326842
assert torch.allclose(result_tensor, original_tensor, rtol=1e-5, atol=1e-7)
68336843

68346844
@pytest.mark.parametrize("num_channels", [1, 3])
6835-
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
6845+
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32, torch.float64])
68366846
@pytest.mark.parametrize("batch_size", [1, 2, 4])
68376847
def test_round_trip_batched(self, num_channels, dtype, batch_size):
6848+
# Skip float64 for 3-channel (not supported by CV-CUDA)
6849+
if num_channels == 3 and dtype == torch.float64:
6850+
pytest.skip("float64 is not supported for 3-channel RGB images")
6851+
68386852
# Setup: Create a batched tensor in NCHW format
68396853
if dtype == torch.uint8:
68406854
original_tensor = torch.randint(0, 256, (batch_size, num_channels, 4, 4), dtype=dtype, device="cuda")

torchvision/transforms/v2/functional/_type_conversion.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def _infer_nvcv_format(img_tensor: torch.Tensor):
6161
return nvcv.Format.S32, img_tensor
6262
elif dtype == torch.float32:
6363
return nvcv.Format.F32, img_tensor
64+
elif dtype == torch.float64:
65+
return nvcv.Format.F64, img_tensor
6466
else:
6567
raise TypeError(f"Unsupported dtype {dtype} for single channel image")
6668

@@ -71,6 +73,7 @@ def _infer_nvcv_format(img_tensor: torch.Tensor):
7173
elif dtype == torch.float32:
7274
return nvcv.Format.RGBf32, img_tensor
7375
else:
76+
# Note: CV-CUDA does not support float64 for RGB images (only F64 for single-channel)
7477
raise TypeError(f"Unsupported dtype {dtype} for 3-channel image")
7578

7679
raise ValueError(f"Only 1 and 3 channel images are supported. Got {num_channels} channels.")

0 commit comments

Comments
 (0)