Skip to content

Commit 6cb29af

Browse files
restrict support to 3 channels images
1 parent 2f4d875 commit 6cb29af

File tree

2 files changed

+32
-71
lines changed

2 files changed

+32
-71
lines changed

test/test_transforms_v2.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6774,16 +6774,6 @@ def test_3_channel_float32_tensor_to_nvcv_tensor(self):
67746774
nvcv_img = F.to_nvcv_tensor(img_data)
67756775
assert nvcv_img is not None
67766776

6777-
def test_2d_uint8_tensor_to_nvcv_tensor(self):
6778-
img_data = torch.randint(0, 256, (4, 4), dtype=torch.uint8, device="cuda")
6779-
nvcv_img = F.to_nvcv_tensor(img_data)
6780-
assert nvcv_img is not None
6781-
6782-
def test_2d_float32_tensor_to_nvcv_tensor(self):
6783-
img_data = torch.rand(4, 4, device="cuda")
6784-
nvcv_img = F.to_nvcv_tensor(img_data)
6785-
assert nvcv_img is not None
6786-
67876777
def test_unsupported_num_channels(self):
67886778
# Test 2-channel image (CHW format: 2 channels x 5 height x 5 width)
67896779
img_data = torch.rand(2, 5, 5, device="cuda")
@@ -6806,11 +6796,15 @@ def test_invalid_input_type(self):
68066796

68076797
def test_invalid_dimensions(self):
68086798
# Test 1D array (too few dimensions)
6809-
with pytest.raises(ValueError, match=r"pic should be 2/3/4 dimensional"):
6799+
with pytest.raises(ValueError, match=r"pic should be 3 or 4 dimensional"):
68106800
F.to_nvcv_tensor(torch.randint(0, 256, (4,), dtype=torch.uint8, device="cuda"))
68116801

6802+
# Test 2D array (no longer supported)
6803+
with pytest.raises(ValueError, match=r"pic should be 3 or 4 dimensional"):
6804+
F.to_nvcv_tensor(torch.randint(0, 256, (4, 4), dtype=torch.uint8, device="cuda"))
6805+
68126806
# Test 5D array (too many dimensions)
6813-
with pytest.raises(ValueError, match=r"pic should be 2/3/4 dimensional"):
6807+
with pytest.raises(ValueError, match=r"pic should be 3 or 4 dimensional"):
68146808
F.to_nvcv_tensor(torch.randint(0, 256, (1, 1, 3, 4, 4), dtype=torch.uint8, device="cuda"))
68156809

68166810
def test_unsupported_dtype_for_channels(self):

torchvision/transforms/v2/functional/_type_conversion.py

Lines changed: 26 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -100,51 +100,32 @@ def to_nvcv_tensor(pic) -> "nvcv.Tensor":
100100
if not isinstance(pic, torch.Tensor):
101101
raise TypeError(f"pic should be `torch.Tensor`. Got {type(pic)}.")
102102

103-
# Handle different tensor formats and track if input was batched (NCHW) or unbatched (CHW/HW)
104-
if pic.ndim == 4:
105-
# Batched tensor in NCHW format, permute to NHWC
106-
img_tensor = pic.permute(0, 2, 3, 1)
107-
input_was_batched = True
108-
elif pic.ndim == 3:
109-
# Unbatched tensor in CHW format, permute to HWC
110-
img_tensor = pic.permute(1, 2, 0)
111-
input_was_batched = False
112-
else:
113-
# 2D or other formats (unbatched single-channel)
103+
# Validate dimensions - only support 3D (CHW) or 4D (NCHW)
104+
if pic.ndim == 3:
105+
# Add fake batch dimension to make it 4D
106+
img_tensor = pic.unsqueeze(0)
107+
elif pic.ndim == 4:
114108
img_tensor = pic
115-
input_was_batched = False
116-
117-
# Ensure image has channel dimension for unbatched case
118-
if img_tensor.ndim == 2:
119-
img_tensor = img_tensor.unsqueeze(2) # H W -> H W C
109+
else:
110+
raise ValueError(f"pic should be 3 or 4 dimensional. Got {pic.ndim} dimensions.")
120111

121-
# Validate dimensions
122-
if img_tensor.ndim not in (3, 4):
123-
raise ValueError(f"pic should be 2/3/4 dimensional. Got {img_tensor.ndim} dimensions.")
112+
# At this point, img_tensor is always 4D in NCHW format
113+
# Convert NCHW -> NHWC
114+
img_tensor = img_tensor.permute(0, 2, 3, 1)
124115

125-
# For batched inputs, use the first image to infer format
126-
sample_img = img_tensor[0] if img_tensor.ndim == 4 else img_tensor
116+
# Infer format from the first image
117+
sample_img = img_tensor[0]
127118
_, sample_img = _infer_nvcv_format(sample_img)
128119

129-
# If format inference modified the tensor (e.g., removed channel dimension for single channel)
120+
# If format inference removed channel dimension (single channel case)
130121
# apply the same transformation to all images
131-
if sample_img.ndim == 2 and img_tensor.ndim == 4:
122+
if sample_img.ndim == 2:
132123
# Batched single channel case: remove channel dimension
133124
img_tensor = img_tensor.squeeze(-1)
134-
elif sample_img.ndim == 2 and img_tensor.ndim == 3:
135-
# Unbatched single channel case: replace with 2D tensor
136-
img_tensor = sample_img
137-
138-
# Add batch dimension if not present (NVCV expects batched tensors)
139-
if not input_was_batched:
140-
img_tensor = img_tensor.unsqueeze(0) # Add batch dimension at index 0
141-
142-
# Determine layout based on final tensor shape
143-
# After all transformations, tensor is either NHW (single-channel) or NHWC (multi-channel)
144-
if img_tensor.ndim == 3:
145-
layout = nvcv.TensorLayout.NHW # Batched single-channel
146-
else: # img_tensor.ndim == 4
147-
layout = nvcv.TensorLayout.NHWC # Batched multi-channel
125+
layout = nvcv.TensorLayout.NHW
126+
else:
127+
# Batched multi-channel
128+
layout = nvcv.TensorLayout.NHWC
148129

149130
# Convert to NVCV tensor with the appropriate layout
150131
return cvcuda.as_tensor(img_tensor.cuda().contiguous(), layout)
@@ -156,10 +137,10 @@ def nvcv_to_tensor(nvcv_img: "nvcv.Tensor") -> torch.Tensor:
156137
157138
Args:
158139
nvcv_img (nvcv.Tensor): nvcv.Tensor to be converted to PyTorch tensor.
159-
Expected to be in NHWC or NHW layout (for batched images) or HWC or HW layout (for unbatched).
140+
Expected to be in NHWC or NHW layout (batched images only).
160141
161142
Returns:
162-
torch.Tensor: Converted image in CHW format (unbatched) or NCHW format (batched).
143+
torch.Tensor: Converted image in NCHW format (batched).
163144
"""
164145
import nvcv # type: ignore[import-not-found]
165146

@@ -174,31 +155,17 @@ def nvcv_to_tensor(nvcv_img: "nvcv.Tensor") -> torch.Tensor:
174155
# NVCV tensors expose __cuda_array_interface__ which PyTorch can consume directly
175156
cuda_tensor = torch.as_tensor(nvcv_img.cuda(), device="cuda")
176157

177-
# Handle different dimensionalities
178-
# NVCV stores images in NHWC (batched multi-channel), NHW (batched single-channel),
179-
# HWC (unbatched multi-channel), or HW (unbatched single-channel) format
158+
# Only support 4D (NHWC) or 3D (NHW) batched tensors
159+
# NVCV stores images in NHWC (batched multi-channel) or NHW (batched single-channel) format
180160
if cuda_tensor.ndim == 4:
181161
# Batched multi-channel image in NHWC format
182162
# Convert NHWC -> NCHW
183163
img = cuda_tensor.permute(0, 3, 1, 2).contiguous()
184164
elif cuda_tensor.ndim == 3:
185-
# Could be either:
186-
# 1. Unbatched multi-channel (HWC) - last dim is 1 or 3
187-
# 2. Batched single-channel (NHW) - last dim is width
188-
# We distinguish by checking if last dimension is 1 or 3 (our supported channel counts)
189-
if cuda_tensor.shape[2] in (1, 3):
190-
# Unbatched multi-channel image in HWC format
191-
# Convert HWC -> CHW
192-
img = cuda_tensor.permute(2, 0, 1).contiguous()
193-
else:
194-
# Batched single-channel image in NHW format
195-
# Convert NHW -> NCHW by adding channel dimension
196-
img = cuda_tensor.unsqueeze(1).contiguous()
197-
elif cuda_tensor.ndim == 2:
198-
# Unbatched single-channel image in HW format
199-
# Convert HW -> CHW by adding channel dimension
200-
img = cuda_tensor.unsqueeze(0).contiguous()
165+
# Batched single-channel image in NHW format
166+
# Convert NHW -> NCHW by adding channel dimension
167+
img = cuda_tensor.unsqueeze(1).contiguous()
201168
else:
202-
raise ValueError(f"Image should be 2/3/4 dimensional. Got {cuda_tensor.ndim} dimensions.")
169+
raise ValueError(f"Image should be 3 or 4 dimensional. Got {cuda_tensor.ndim} dimensions.")
203170

204171
return img

0 commit comments

Comments
 (0)