|
11 | 11 | import nvcv # type: ignore[import-not-found] |
12 | 12 |
|
13 | 13 |
|
| 14 | +def _import_cvcuda_modules(): |
| 15 | + """Import CV-CUDA modules with informative error message if not installed. |
| 16 | +
|
| 17 | + Returns: |
| 18 | + tuple: (cvcuda, nvcv) modules. |
| 19 | +
|
| 20 | + Raises: |
| 21 | + RuntimeError: If CV-CUDA is not installed. |
| 22 | + """ |
| 23 | + try: |
| 24 | + import cvcuda # type: ignore[import-not-found] |
| 25 | + import nvcv # type: ignore[import-not-found] |
| 26 | + |
| 27 | + return cvcuda, nvcv |
| 28 | + except ImportError as e: |
| 29 | + raise RuntimeError( |
| 30 | + "CV-CUDA is required but not installed. " |
| 31 | + "Please install it following the instructions at " |
| 32 | + "https://github.com/CVCUDA/CV-CUDA or via pip: " |
| 33 | + "`pip install cvcuda-cu12` (for CUDA 12) or `pip install cvcuda-cu11` (for CUDA 11)." |
| 34 | + ) from e |
| 35 | + |
| 36 | + |
14 | 37 | @torch.jit.unused |
15 | 38 | def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tensors.Image: |
16 | 39 | """See :class:`~torchvision.transforms.v2.ToImage` for details.""" |
@@ -45,7 +68,7 @@ def _infer_nvcv_format(img_tensor: torch.Tensor): |
45 | 68 | TypeError: If dtype is not supported. |
46 | 69 | ValueError: If number of channels is not 3. |
47 | 70 | """ |
48 | | - import nvcv # type: ignore[import-not-found] |
| 71 | + _, nvcv = _import_cvcuda_modules() |
49 | 72 |
|
50 | 73 | num_channels = img_tensor.shape[2] |
51 | 74 | dtype = img_tensor.dtype |
@@ -77,8 +100,7 @@ def to_nvcv_tensor(pic) -> "nvcv.Tensor": |
77 | 100 | Returns: |
78 | 101 | nvcv.Tensor: Image converted to nvcv.Tensor with NHWC layout. |
79 | 102 | """ |
80 | | - import cvcuda # type: ignore[import-not-found] |
81 | | - import nvcv # type: ignore[import-not-found] |
| 103 | + cvcuda, nvcv = _import_cvcuda_modules() |
82 | 104 |
|
83 | 105 | if not torch.jit.is_scripting() and not torch.jit.is_tracing(): |
84 | 106 | _log_api_usage_once(to_nvcv_tensor) |
@@ -119,7 +141,7 @@ def nvcv_to_tensor(nvcv_img: "nvcv.Tensor") -> torch.Tensor: |
119 | 141 | Returns: |
120 | 142 | torch.Tensor: Converted image in NCHW format (batched). |
121 | 143 | """ |
122 | | - import nvcv # type: ignore[import-not-found] |
| 144 | + _, nvcv = _import_cvcuda_modules() |
123 | 145 |
|
124 | 146 | if not torch.jit.is_scripting() and not torch.jit.is_tracing(): |
125 | 147 | _log_api_usage_once(nvcv_to_tensor) |
|
0 commit comments