@@ -100,51 +100,32 @@ def to_nvcv_tensor(pic) -> "nvcv.Tensor":
100100 if not isinstance (pic , torch .Tensor ):
101101 raise TypeError (f"pic should be `torch.Tensor`. Got { type (pic )} ." )
102102
103- # Handle different tensor formats and track if input was batched (NCHW) or unbatched (CHW/HW)
104- if pic .ndim == 4 :
105- # Batched tensor in NCHW format, permute to NHWC
106- img_tensor = pic .permute (0 , 2 , 3 , 1 )
107- input_was_batched = True
108- elif pic .ndim == 3 :
109- # Unbatched tensor in CHW format, permute to HWC
110- img_tensor = pic .permute (1 , 2 , 0 )
111- input_was_batched = False
112- else :
113- # 2D or other formats (unbatched single-channel)
103+ # Validate dimensions - only support 3D (CHW) or 4D (NCHW)
104+ if pic .ndim == 3 :
105+ # Add fake batch dimension to make it 4D
106+ img_tensor = pic .unsqueeze (0 )
107+ elif pic .ndim == 4 :
114108 img_tensor = pic
115- input_was_batched = False
116-
117- # Ensure image has channel dimension for unbatched case
118- if img_tensor .ndim == 2 :
119- img_tensor = img_tensor .unsqueeze (2 ) # H W -> H W C
109+ else :
110+ raise ValueError (f"pic should be 3 or 4 dimensional. Got { pic .ndim } dimensions." )
120111
121- # Validate dimensions
122- if img_tensor . ndim not in ( 3 , 4 ):
123- raise ValueError ( f"pic should be 2/3/4 dimensional. Got { img_tensor . ndim } dimensions." )
112+ # At this point, img_tensor is always 4D in NCHW format
113+ # Convert NCHW -> NHWC
114+ img_tensor = img_tensor . permute ( 0 , 2 , 3 , 1 )
124115
125- # For batched inputs, use the first image to infer format
126- sample_img = img_tensor [0 ] if img_tensor . ndim == 4 else img_tensor
116+ # Infer format from the first image
117+ sample_img = img_tensor [0 ]
127118 _ , sample_img = _infer_nvcv_format (sample_img )
128119
129- # If format inference modified the tensor (e.g., removed channel dimension for single channel)
120+ # If format inference removed channel dimension ( single channel case )
130121 # apply the same transformation to all images
131- if sample_img .ndim == 2 and img_tensor . ndim == 4 :
122+ if sample_img .ndim == 2 :
132123 # Batched single channel case: remove channel dimension
133124 img_tensor = img_tensor .squeeze (- 1 )
134- elif sample_img .ndim == 2 and img_tensor .ndim == 3 :
135- # Unbatched single channel case: replace with 2D tensor
136- img_tensor = sample_img
137-
138- # Add batch dimension if not present (NVCV expects batched tensors)
139- if not input_was_batched :
140- img_tensor = img_tensor .unsqueeze (0 ) # Add batch dimension at index 0
141-
142- # Determine layout based on final tensor shape
143- # After all transformations, tensor is either NHW (single-channel) or NHWC (multi-channel)
144- if img_tensor .ndim == 3 :
145- layout = nvcv .TensorLayout .NHW # Batched single-channel
146- else : # img_tensor.ndim == 4
147- layout = nvcv .TensorLayout .NHWC # Batched multi-channel
125+ layout = nvcv .TensorLayout .NHW
126+ else :
127+ # Batched multi-channel
128+ layout = nvcv .TensorLayout .NHWC
148129
149130 # Convert to NVCV tensor with the appropriate layout
150131 return cvcuda .as_tensor (img_tensor .cuda ().contiguous (), layout )
@@ -156,10 +137,10 @@ def nvcv_to_tensor(nvcv_img: "nvcv.Tensor") -> torch.Tensor:
156137
157138 Args:
158139 nvcv_img (nvcv.Tensor): nvcv.Tensor to be converted to PyTorch tensor.
159- Expected to be in NHWC or NHW layout (for batched images) or HWC or HW layout (for unbatched ).
140+ Expected to be in NHWC or NHW layout (batched images only ).
160141
161142 Returns:
162- torch.Tensor: Converted image in CHW format (unbatched) or NCHW format (batched).
143+ torch.Tensor: Converted image in NCHW format (batched).
163144 """
164145 import nvcv # type: ignore[import-not-found]
165146
@@ -174,31 +155,17 @@ def nvcv_to_tensor(nvcv_img: "nvcv.Tensor") -> torch.Tensor:
174155 # NVCV tensors expose __cuda_array_interface__ which PyTorch can consume directly
175156 cuda_tensor = torch .as_tensor (nvcv_img .cuda (), device = "cuda" )
176157
177- # Handle different dimensionalities
178- # NVCV stores images in NHWC (batched multi-channel), NHW (batched single-channel),
179- # HWC (unbatched multi-channel), or HW (unbatched single-channel) format
158+ # Only support 4D (NHWC) or 3D (NHW) batched tensors
159+ # NVCV stores images in NHWC (batched multi-channel) or NHW (batched single-channel) format
180160 if cuda_tensor .ndim == 4 :
181161 # Batched multi-channel image in NHWC format
182162 # Convert NHWC -> NCHW
183163 img = cuda_tensor .permute (0 , 3 , 1 , 2 ).contiguous ()
184164 elif cuda_tensor .ndim == 3 :
185- # Could be either:
186- # 1. Unbatched multi-channel (HWC) - last dim is 1 or 3
187- # 2. Batched single-channel (NHW) - last dim is width
188- # We distinguish by checking if last dimension is 1 or 3 (our supported channel counts)
189- if cuda_tensor .shape [2 ] in (1 , 3 ):
190- # Unbatched multi-channel image in HWC format
191- # Convert HWC -> CHW
192- img = cuda_tensor .permute (2 , 0 , 1 ).contiguous ()
193- else :
194- # Batched single-channel image in NHW format
195- # Convert NHW -> NCHW by adding channel dimension
196- img = cuda_tensor .unsqueeze (1 ).contiguous ()
197- elif cuda_tensor .ndim == 2 :
198- # Unbatched single-channel image in HW format
199- # Convert HW -> CHW by adding channel dimension
200- img = cuda_tensor .unsqueeze (0 ).contiguous ()
165+ # Batched single-channel image in NHW format
166+ # Convert NHW -> NCHW by adding channel dimension
167+ img = cuda_tensor .unsqueeze (1 ).contiguous ()
201168 else :
202- raise ValueError (f"Image should be 2/3/ 4 dimensional. Got { cuda_tensor .ndim } dimensions." )
169+ raise ValueError (f"Image should be 3 or 4 dimensional. Got { cuda_tensor .ndim } dimensions." )
203170
204171 return img
0 commit comments