Skip to content

Commit df3a91b

Browse files
add helper function _import_cvcuda_modules for better logging
1 parent 7371b2e commit df3a91b

File tree

1 file changed

+26
-4
lines changed

1 file changed

+26
-4
lines changed

torchvision/transforms/v2/functional/_type_conversion.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,29 @@
1111
import nvcv # type: ignore[import-not-found]
1212

1313

14+
def _import_cvcuda_modules():
15+
"""Import CV-CUDA modules with informative error message if not installed.
16+
17+
Returns:
18+
tuple: (cvcuda, nvcv) modules.
19+
20+
Raises:
21+
RuntimeError: If CV-CUDA is not installed.
22+
"""
23+
try:
24+
import cvcuda # type: ignore[import-not-found]
25+
import nvcv # type: ignore[import-not-found]
26+
27+
return cvcuda, nvcv
28+
except ImportError as e:
29+
raise RuntimeError(
30+
"CV-CUDA is required but not installed. "
31+
"Please install it following the instructions at "
32+
"https://github.com/CVCUDA/CV-CUDA or via pip: "
33+
"`pip install cvcuda-cu12` (for CUDA 12) or `pip install cvcuda-cu11` (for CUDA 11)."
34+
) from e
35+
36+
1437
@torch.jit.unused
1538
def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tensors.Image:
1639
"""See :class:`~torchvision.transforms.v2.ToImage` for details."""
@@ -45,7 +68,7 @@ def _infer_nvcv_format(img_tensor: torch.Tensor):
4568
TypeError: If dtype is not supported.
4669
ValueError: If number of channels is not 3.
4770
"""
48-
import nvcv # type: ignore[import-not-found]
71+
_, nvcv = _import_cvcuda_modules()
4972

5073
num_channels = img_tensor.shape[2]
5174
dtype = img_tensor.dtype
@@ -77,8 +100,7 @@ def to_nvcv_tensor(pic) -> "nvcv.Tensor":
77100
Returns:
78101
nvcv.Tensor: Image converted to nvcv.Tensor with NHWC layout.
79102
"""
80-
import cvcuda # type: ignore[import-not-found]
81-
import nvcv # type: ignore[import-not-found]
103+
cvcuda, nvcv = _import_cvcuda_modules()
82104

83105
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
84106
_log_api_usage_once(to_nvcv_tensor)
@@ -119,7 +141,7 @@ def nvcv_to_tensor(nvcv_img: "nvcv.Tensor") -> torch.Tensor:
119141
Returns:
120142
torch.Tensor: Converted image in NCHW format (batched).
121143
"""
122-
import nvcv # type: ignore[import-not-found]
144+
_, nvcv = _import_cvcuda_modules()
123145

124146
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
125147
_log_api_usage_once(nvcv_to_tensor)

0 commit comments

Comments
 (0)