@AntoineSimoulin
Member

@AntoineSimoulin AntoineSimoulin commented Nov 6, 2025

Summary

This PR provides the first building blocks for CV-CUDA integration in torchvision. We add the functionals to_cvcuda_tensor and cvcuda_to_tensor to transform from torch.Tensor to cvcuda.Tensor and back. We also implement the corresponding class transforms ToCVCUDATensor and CVCUDAToTensor. These transforms require CV-CUDA to be installed.

How to use

import torch
import torchvision.transforms.v2.functional as F

# Create a random ``torch.Tensor`` image (must be 3-channel RGB or 1-channel grayscale, with a batch dimension)
img_tensor = torch.randint(0, 256, (1, 3, 320, 240), dtype=torch.uint8)

# Convert to ``cvcuda.Tensor`` (will be uploaded to CUDA)
cvcuda_tensor = F.to_cvcuda_tensor(img_tensor)

# Convert back to ``torch.Tensor``
img_tensor = F.cvcuda_to_tensor(cvcuda_tensor)

Note

  • cvcuda.Tensor objects are automatically converted to NHWC layout (since most CV-CUDA transforms only support this layout)
  • Only 3-channel RGB images and 1-channel grayscale are supported for now
  • Input tensors will be uploaded to the CUDA device when converting to CV-CUDA tensors
  • CV-CUDA must be installed: pip install cvcuda-cu12 (CUDA 12) or pip install cvcuda-cu11 (CUDA 11)

Run unit tests

pytest test/test_transforms_v2.py -k "cvcuda"
...
338 passed, 9774 deselected in 1.65s

Differential Revision: D85862362

@pytorch-bot

pytorch-bot bot commented Nov 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/9259

Note: Links to docs will display an error until the docs builds have been completed.

❌ 6 New Failures

As of commit ccf1a36 with merge base acccf86 (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the cla signed label Nov 6, 2025
@meta-codesync

meta-codesync bot commented Nov 6, 2025

@AntoineSimoulin has exported this pull request. If you are a Meta employee, you can view the originating Diff in D85862362.

@AntoineSimoulin
Member Author

NOTES:

  • Should we upload tensors .to(cuda) directly inside the to_nvcv_tensor function?
  • Should we keep CVCUDA specific functions inside isolated "_cvcuda" files or within existing files (e.g. "_geometry")?
  • NVCV tensors are automatically converted to NHWC format, contrary to the torchvision convention, which relies on NCHW. If we call to_nvcv_tensor and then nvcv_to_tensor, we might get an output tensor whose shape differs from the input.
  • Are the image formats from _infer_nvcv_format exhaustive and aligned with what is done for PIL?
  • What is the difference between nvcv.Image and nvcv.Tensor?
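
On the layout question, here is a minimal sketch in plain PyTorch (no CV-CUDA involved) of how the shape changes between NCHW and NHWC and how an explicit inverse permutation restores it. It only illustrates the layout concern; it is not the conversion code from this PR.

```python
import torch

# A batched image in torchvision's usual NCHW layout: (N, C, H, W).
nchw = torch.randint(0, 256, (1, 3, 320, 240), dtype=torch.uint8)

# NCHW -> NHWC: move the channel axis to the end. permute() returns a view,
# so contiguous() is called to get a densely packed NHWC buffer.
nhwc = nchw.permute(0, 2, 3, 1).contiguous()

# NHWC -> NCHW: the inverse permutation restores the original layout.
restored = nhwc.permute(0, 3, 1, 2).contiguous()

print(tuple(nchw.shape), tuple(nhwc.shape), tuple(restored.shape))
```

If the round trip skips the inverse permutation, the caller gets back a tensor of shape (N, H, W, C) instead of (N, C, H, W), which is the mismatch described above.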

@meta-codesync

meta-codesync bot commented Nov 7, 2025

@AntoineSimoulin has imported this pull request. If you are a Meta employee, you can view this in D85862362.

AntoineSimoulin added a commit to AntoineSimoulin/vision that referenced this pull request Nov 10, 2025
Summary:
Users have to explicitly opt-in for those transforms. Here we provide the first building block for this interface. We add the functionals `to_nvcv_image` and `nvcv_to_tensor` to transform `torch.Tensor` to `nvcv.Tensor`. We also implement the corresponding class transforms `ToNVCVImage` and `NVCVToTensor`.

## How to use

```python
from PIL import Image
import torchvision.transforms.v2.functional as F

orig_img = Image.open("leaning_tower.jpg")
img_tensor = F.pil_to_tensor(orig_img)
nvcv_tensor = F.to_nvcv_tensor(img_tensor.cuda())
img_tensor = F.nvcv_to_tensor(nvcv_tensor)
```

> [!NOTE]
> NVCV tensors are automatically converted to NHWC format, contrary to the torchvision convention, which relies on NCHW format.

## Run unit tests

```bash
pytest test/test_cvcuda.py
...
37 passed in 0.15s
```


Test Plan:
```python
from torchvision import _is_cvcuda_available

_is_cvcuda_available()
```

## Run tests

```bash
buck test fbcode//mode/opt fbcode//pytorch/vision/test:torchvision_cvcuda
...
Tests finished: Pass 38. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Differential Revision: D85862362

Pulled By: AntoineSimoulin
Member

@NicolasHug NicolasHug left a comment

Thanks a lot for working on this @AntoineSimoulin ! This looks great. I made some suggestions about potentially simplifying the features we need to support right now, and about potentially preserving the existing file structure. Let's chat more.

@justincdavis

This looks like a great base implementation, thanks for getting this started @AntoineSimoulin !

My immediate comment would be that we should only be using the cvcuda import. Since cvcuda has been published on PyPI, all of the nvcv objects are aliased inside the cvcuda module. In the 0.16.0 release (to be released end of week) we have removed the nvcv Python module completely to simplify the library. All versions currently published on PyPI will work using only the cvcuda Python module. This will also reduce any naming confusion between NVCV and CV-CUDA tensors that may come up for users.

F.to_cvcuda_tensor(img_data)

@pytest.mark.parametrize("num_channels", [1, 3])
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32, torch.float64])

Should the int16/int32 dtypes also be tested in the round trip tests?

Member Author

Just to be sure regarding RGB, CVCUDA only supports torch.uint8 (cvcuda.Format.RGB8) and torch.float32 (cvcuda.Format.RGBf32)?

Ah, great question! The formats are actually relevant for the cvcuda.Image datatype and tensors do not have this format info. CV-CUDA tensors should be able to read most datatypes from torch directly. I saw you posted an initial note asking about the difference between image/tensor, and I will post an explanation on this.

For tensors, most operators should support either unsigned int, signed int, floats, or some combination of those.

Would it be helpful at this stage to have a conclusive list of the datatypes/channels CV-CUDA operators can utilize?

Member Author

@AntoineSimoulin AntoineSimoulin Nov 13, 2025

Gotcha! I think at this point we should remove the _validate_cvcuda_dtype function since cvcuda should support all data types supported by torchvision (torch.uint8, torch.int16, torch.int32, torch.float32, torch.float64). Maybe just reject torch.float16?

Member Author

For the following PRs, yes I think it would be useful to have a list of the datatypes/channels CV-CUDA operators can utilize!

Rejecting torch.float16 sounds like a good solution, we are tracking support for fp16 but no timeline as of yet.
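
A minimal sketch of what such a check could look like, assuming the dtype list mentioned earlier in this thread. The helper name `check_cvcuda_dtype` and the error messages are hypothetical and are not taken from the PR.

```python
import torch

# Assumption: the dtypes listed in this discussion as supported by torchvision.
_SUPPORTED_DTYPES = {
    torch.uint8,
    torch.int16,
    torch.int32,
    torch.float32,
    torch.float64,
}


def check_cvcuda_dtype(tensor: torch.Tensor) -> None:
    """Hypothetical helper: raise TypeError for dtypes we do not want to convert."""
    if tensor.dtype == torch.float16:
        # Called out separately so the message can explain the reason.
        raise TypeError("torch.float16 is not supported for CV-CUDA conversion yet")
    if tensor.dtype not in _SUPPORTED_DTYPES:
        raise TypeError(f"Unsupported dtype for CV-CUDA conversion: {tensor.dtype}")


# Passes silently for a supported dtype.
check_cvcuda_dtype(torch.zeros(1, 3, 4, 4, dtype=torch.uint8))
```

Handling float16 as its own branch keeps the error message specific, which may be easier to relax later if fp16 support lands.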

AntoineSimoulin added a commit to AntoineSimoulin/vision that referenced this pull request Nov 12, 2025
Summary:
Summary
-------

This PR provides the first building blocks for CV-CUDA integration in torchvision. We add the functionals `to_cvcuda_tensor` and `cvcuda_to_tensor` to transform from `torch.Tensor` to `cvcuda.Tensor` and back. We also implement the corresponding class transforms `ToCVCUDATensor` and `CVCUDAToTensor`.

**Key features:**

*   **3-channel RGB support only**: Simplified API focusing on the most common use case (RGB images)
*   **Supported data types**: `torch.uint8` (RGB8 format) and `torch.float32` (RGBf32 format)
*   **Lossless round-trip conversions**: Exact data preservation when converting PyTorch ↔ CV-CUDA
*   **Informative error messages**: Helpful installation instructions when CV-CUDA is not available
*   **Batch-aware**: Handles both unbatched (CHW) and batched (NCHW) tensors

Users must explicitly opt-in for these transforms, which require CV-CUDA to be installed.

How to use
----------

```python
from PIL import Image
import torchvision.transforms.v2.functional as F 

# Load and convert image to PyTorch tensor
orig_img = Image.open("leaning_tower.jpg")
img_tensor = F.pil_to_tensor(orig_img)  

# Convert to CV-CUDA tensor (must be 3-channel RGB on CUDA)
cvcuda_tensor = F.to_cvcuda_tensor(img_tensor.cuda())

# Convert back to PyTorch tensor
img_tensor = F.cvcuda_to_tensor(cvcuda_tensor)
```

> [!NOTE]
> 
> *   NVCV tensors are automatically converted to NHWC layout, contrary to torchvision's NCHW default
> *   Only 3-channel RGB images and 1-channel grayscale are supported for now
> *   Input tensors will be uploaded to CUDA device when converting to CV-CUDA tensors
> *   CV-CUDA must be installed: `pip install cvcuda-cu12` (CUDA 12) or `pip install cvcuda-cu11` (CUDA 11)

Run unit tests
--------------

```bash
pytest test/test_transforms_v2.py -k "cvcuda"
...
35 passed, 4 skipped, 9774 deselected in 1.12s
```


Test Plan:
```python
from torchvision import _is_cvcuda_available

_is_cvcuda_available()
```

## Run tests

```bash
buck test fbcode//mode/opt fbcode//pytorch/vision/test:torchvision_transforms_v2
...
Tests finished: Pass 38. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Differential Revision: D85862362

Pulled By: AntoineSimoulin
@justincdavis

The primary purpose for cvcuda.Image is that it can attach image metadata via the cvcuda.Format type. This allows the cvcuda.Image to represent formats such as NV12, YUV, etc. NV12 for example will have 2 planes with different resolutions, something which a cvcuda.Tensor is not easily equipped to handle. The cvcuda.Format informs operators/external libraries about how to interpret the data.

A cvcuda.Tensor, on the other hand, is simply an N-dimensional array with no special metadata or access patterns associated with it. You can represent RGB images via cvcuda.Tensor, but it cannot handle more complex formats such as NV12/YUV with multi-planar allocations and alternative access patterns.

For the purposes of the CV-CUDA backend I believe we can exclusively use the cvcuda.Tensor.
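
To make the multi-planar point concrete, here is a small sketch in plain Python (no CV-CUDA involved) of the plane shapes an NV12 image needs. It assumes the standard NV12 layout: a full-resolution Y plane plus one interleaved UV plane subsampled by 2 in each dimension. The helper name `nv12_plane_shapes` is hypothetical.

```python
from typing import List, Tuple


def nv12_plane_shapes(height: int, width: int) -> List[Tuple[int, ...]]:
    """Return the shapes of the two NV12 planes for an even-sized image.

    Plane 0 is full-resolution luma (Y); plane 1 holds interleaved
    chroma samples (U, V) at half resolution in both dimensions.
    """
    if height % 2 or width % 2:
        raise ValueError("this sketch assumes even height and width")
    luma = (height, width)
    chroma = (height // 2, width // 2, 2)
    return [luma, chroma]


# An interleaved RGB image fits in one dense (H, W, 3) array...
rgb_shape = (320, 240, 3)
# ...whereas NV12 needs two planes with different spatial sizes.
planes = nv12_plane_shapes(320, 240)
print(rgb_shape, planes)
```

Because the two planes differ in spatial size, they cannot be stacked into a single dense N-dimensional array without extra metadata describing the planes, which is the role the format information plays for images.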

@AntoineSimoulin AntoineSimoulin changed the title CVCUDA backend design Introducing CVCUDA backend Nov 13, 2025
@AntoineSimoulin AntoineSimoulin changed the title Introducing CVCUDA backend Introducing CVCUDA Backend Nov 13, 2025
img_data = img_data.cuda()
F.to_cvcuda_tensor(img_data)

@pytest.mark.parametrize("num_channels", [1, 3])

@justincdavis justincdavis Nov 13, 2025

CV-CUDA can support 2/4-channel tensors; I missed this previously when discussing the validate format function. Same comment for test_round_trip_batched.

Member Author

This is a good point! Right now torchvision focuses on 1- or 3-channel tensors, so I propose we stay aligned with what is explicitly supported, even though CV-CUDA supports more formats.

Args:
pic (torch.Tensor): Image to be converted to cvcuda.Tensor.
Tensor can be in CHW format (unbatched) or NCHW format (batched).
Only 1-channel and 3-channel images are supported.

@justincdavis justincdavis Nov 13, 2025

Should update this comment to include 2/4 channel images since we no longer filter those out with the validation function.

Member Author

@justincdavis thanks for the comments. See my answer above!

@AntoineSimoulin
Member Author

@justincdavis for nvcv.Tensors, can we use isinstance(output, cvcuda.Tensor)? It seems it will raise an error.

@justincdavis

@AntoineSimoulin I have not encountered this previously; as far as I am aware, isinstance should work without issue. Do you have an MRE of where it is failing?

@AntoineSimoulin
Member Author

Summary of the changes:

  • Cleaned up docstrings for ToCVCUDATensor and CVCUDAToTensor
  • Moved _import_cvcuda to "torchvision/transforms/v2/functional/_utils" since it is used in multiple files
  • Updated to_cvcuda_tensor and cvcuda_to_tensor to only support batched NHWC layout as input (no single image with HWC layout) to simplify the logic for this initial commit
  • Cleaned up docstrings for to_cvcuda_tensor and cvcuda_to_tensor
  • Improved test parametrization in "test/transforms_v2"
  • Added kernel for get_size_image_cvcuda to be used in tests
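
For readers unfamiliar with the optional-import pattern behind a helper like _import_cvcuda, here is a minimal sketch using only the standard library. The function name `import_optional` and the message wording are hypothetical; the actual helper in this PR may differ.

```python
import importlib
from types import ModuleType


def import_optional(name: str, install_hint: str) -> ModuleType:
    """Import an optional dependency, adding an installation hint on failure."""
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        raise ImportError(
            f"{name} is required for this transform but is not installed. {install_hint}"
        ) from exc


# The import is attempted only when a CV-CUDA code path is actually used,
# so users without CV-CUDA can still import the rest of the package.
try:
    cvcuda = import_optional(
        "cvcuda",
        "Install it with `pip install cvcuda-cu12` (CUDA 12) or `pip install cvcuda-cu11` (CUDA 11).",
    )
except ImportError as err:
    print(err)
```

Keeping one shared helper means every CV-CUDA entry point fails with the same installation message.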

@AntoineSimoulin
Member Author

@justincdavis thanks for the confirmation, just wanted to double check and indeed I confirmed isinstance works without issue.

@justincdavis

Other than the final comment I made, this looks great to me! Thank you @AntoineSimoulin !

Member

@NicolasHug NicolasHug left a comment

Thanks for the great work @AntoineSimoulin
and thank you so much @justincdavis for the reviews!
