Skip to content

Commit 144c394

Browse files
Daniel FloresDan-Flores
authored andcommitted
pass device around, basic function test
1 parent 26ed10e commit 144c394

File tree

5 files changed

+87
-9
lines changed

5 files changed

+87
-9
lines changed

src/torchcodec/_core/custom_ops.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3737
m.def(
3838
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
3939
m.def(
40-
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
40+
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str device=\"cpu\", str device_variant=\"ffmpeg\", int? crf=None) -> ()");
4141
m.def(
42-
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor");
42+
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str device=\"cpu\", str device_variant=\"ffmpeg\", int? crf=None) -> Tensor");
4343
m.def(
44-
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, int? crf=None) -> ()");
44+
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str device=\"cpu\", str device_variant=\"ffmpeg\",int? crf=None) -> ()");
4545
m.def(
4646
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
4747
m.def(
@@ -603,9 +603,15 @@ void encode_video_to_file(
603603
const at::Tensor& frames,
604604
int64_t frame_rate,
605605
std::string_view file_name,
606+
std::string_view device = "cpu",
607+
std::string_view device_variant = "ffmpeg",
606608
std::optional<int64_t> crf = std::nullopt) {
607609
VideoStreamOptions videoStreamOptions;
608610
videoStreamOptions.crf = crf;
611+
612+
validateDeviceInterface(std::string(device), std::string(device_variant));
613+
videoStreamOptions.device = torch::Device(std::string(device));
614+
videoStreamOptions.deviceVariant = device_variant;
609615
VideoEncoder(
610616
frames,
611617
validateInt64ToInt(frame_rate, "frame_rate"),
@@ -618,10 +624,16 @@ at::Tensor encode_video_to_tensor(
618624
const at::Tensor& frames,
619625
int64_t frame_rate,
620626
std::string_view format,
627+
std::string_view device = "cpu",
628+
std::string_view device_variant = "ffmpeg",
621629
std::optional<int64_t> crf = std::nullopt) {
622630
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
623631
VideoStreamOptions videoStreamOptions;
624632
videoStreamOptions.crf = crf;
633+
634+
validateDeviceInterface(std::string(device), std::string(device_variant));
635+
videoStreamOptions.device = torch::Device(std::string(device));
636+
videoStreamOptions.deviceVariant = device_variant;
625637
return VideoEncoder(
626638
frames,
627639
validateInt64ToInt(frame_rate, "frame_rate"),
@@ -636,6 +648,8 @@ void _encode_video_to_file_like(
636648
int64_t frame_rate,
637649
std::string_view format,
638650
int64_t file_like_context,
651+
std::string_view device = "cpu",
652+
std::string_view device_variant = "ffmpeg",
639653
std::optional<int64_t> crf = std::nullopt) {
640654
auto fileLikeContext =
641655
reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
@@ -646,6 +660,10 @@ void _encode_video_to_file_like(
646660
VideoStreamOptions videoStreamOptions;
647661
videoStreamOptions.crf = crf;
648662

663+
validateDeviceInterface(std::string(device), std::string(device_variant));
664+
videoStreamOptions.device = torch::Device(std::string(device));
665+
videoStreamOptions.deviceVariant = device_variant;
666+
649667
VideoEncoder encoder(
650668
frames,
651669
validateInt64ToInt(frame_rate, "frame_rate"),

src/torchcodec/_core/ops.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,8 @@ def encode_video_to_file_like(
212212
frame_rate: int,
213213
format: str,
214214
file_like: Union[io.RawIOBase, io.BufferedIOBase],
215+
device: str = "cpu",
216+
device_variant: str = "ffmpeg",
215217
crf: Optional[int] = None,
216218
) -> None:
217219
"""Encode video frames to a file-like object.
@@ -221,6 +223,8 @@ def encode_video_to_file_like(
221223
frame_rate: Frame rate in frames per second
222224
format: Video format (e.g., "mp4", "mov", "mkv")
223225
file_like: File-like object that supports write() and seek() methods
226+
device: Device to use for encoding (default: "cpu")
227+
device_variant:
224228
crf: Optional constant rate factor for encoding quality
225229
"""
226230
assert _pybind_ops is not None
@@ -230,6 +234,8 @@ def encode_video_to_file_like(
230234
frame_rate,
231235
format,
232236
_pybind_ops.create_file_like_context(file_like, True), # True means for writing
237+
device,
238+
device_variant,
233239
crf,
234240
)
235241

@@ -318,7 +324,9 @@ def encode_video_to_file_abstract(
318324
frames: torch.Tensor,
319325
frame_rate: int,
320326
filename: str,
321-
crf: Optional[int],
327+
device: str = "cpu",
328+
device_variant: str = "ffmpeg",
329+
crf: Optional[int] = None,
322330
) -> None:
323331
return
324332

@@ -328,7 +336,9 @@ def encode_video_to_tensor_abstract(
328336
frames: torch.Tensor,
329337
frame_rate: int,
330338
format: str,
331-
crf: Optional[int],
339+
device: str = "cpu",
340+
device_variant: str = "ffmpeg",
341+
crf: Optional[int] = None,
332342
) -> torch.Tensor:
333343
return torch.empty([], dtype=torch.long)
334344

@@ -339,6 +349,8 @@ def _encode_video_to_file_like_abstract(
339349
frame_rate: int,
340350
format: str,
341351
file_like_context: int,
352+
device: str = "cpu",
353+
device_variant: str = "ffmpeg",
342354
crf: Optional[int] = None,
343355
) -> None:
344356
return

src/torchcodec/encoders/_video_encoder.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from pathlib import Path
2-
from typing import Union
2+
from typing import Optional, Union
33

44
import torch
5-
from torch import Tensor
5+
from torch import device as torch_device, Tensor
66

77
from torchcodec import _core
8+
from torchcodec.decoders._decoder_utils import _get_cuda_backend
89

910

1011
class VideoEncoder:
@@ -16,6 +17,9 @@ class VideoEncoder:
1617
C is 3 channels (RGB), H is height, and W is width.
1718
Values must be uint8 in the range ``[0, 255]``.
1819
frame_rate (int): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate.
20+
device (str or torch.device, optional): The device to use for encoding. Default: "cpu".
21+
If you pass a CUDA device, frames will be encoded on GPU.
22+
Note: The "beta" CUDA backend is not supported for encoding.
1923
"""
2024

2125
def __init__(self, frames: Tensor, *, frame_rate: int):
@@ -29,8 +33,21 @@ def __init__(self, frames: Tensor, *, frame_rate: int):
2933
if frame_rate <= 0:
3034
raise ValueError(f"{frame_rate = } must be > 0.")
3135

36+
# Validate and store device
37+
if isinstance(device, torch_device):
38+
device = str(device)
39+
40+
# Check if beta variant is being used and reject it
41+
device_variant = _get_cuda_backend()
42+
if "cuda" in device.lower() and device_variant == "beta":
43+
raise ValueError(
44+
"The beta CUDA backend is not supported for video encoding. "
45+
"Please use device='cuda' without the beta backend context manager."
46+
)
47+
3248
self._frames = frames
3349
self._frame_rate = frame_rate
50+
self._device = device
3451

3552
def to_file(
3653
self,
@@ -47,6 +64,7 @@ def to_file(
4764
frames=self._frames,
4865
frame_rate=self._frame_rate,
4966
filename=str(dest),
67+
device=self._device,
5068
)
5169

5270
def to_tensor(
@@ -66,6 +84,7 @@ def to_tensor(
6684
frames=self._frames,
6785
frame_rate=self._frame_rate,
6886
format=format,
87+
device=self._device,
6988
)
7089

7190
def to_file_like(
@@ -89,4 +108,5 @@ def to_file_like(
89108
frame_rate=self._frame_rate,
90109
format=format,
91110
file_like=file_like,
111+
device=self._device,
92112
)

test/test_encoders.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,3 +676,33 @@ def encode_to_tensor(frames):
676676
torch.testing.assert_close(
677677
encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0
678678
)
679+
680+
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
681+
@pytest.mark.parametrize(
682+
"device", ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))
683+
)
684+
def test_device_video_encoder(self, method, device, tmp_path):
685+
# Test that encoding works on CPU and CUDA devices
686+
num_frames, channels, height, width = 5, 3, 64, 64
687+
frames = (torch.rand(num_frames, channels, height, width) * 255).to(torch.uint8)
688+
689+
encoder = VideoEncoder(frames, frame_rate=30, device=device)
690+
691+
if method == "to_file":
692+
dest = str(tmp_path / "output.mp4")
693+
encoder.to_file(dest=dest)
694+
# Verify file was created
695+
assert Path(dest).exists()
696+
elif method == "to_tensor":
697+
encoded = encoder.to_tensor(format="mp4")
698+
assert encoded.dtype == torch.uint8
699+
assert encoded.ndim == 1
700+
assert encoded.numel() > 0
701+
elif method == "to_file_like":
702+
file_like = io.BytesIO()
703+
encoder.to_file_like(file_like, format="mp4")
704+
encoded_bytes = file_like.getvalue()
705+
assert len(encoded_bytes) > 0
706+
else:
707+
raise ValueError(f"Unknown method: {method}")
708+
class VideoEncoder

test/test_ops.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1375,8 +1375,6 @@ def get_encoded_data(self):
13751375

13761376
def test_to_file_like_real_file(self, tmp_path):
13771377
"""Test to_file_like with a real file opened in binary write mode."""
1378-
if get_ffmpeg_major_version() == 6:
1379-
pytest.skip("Skipping round trip test for FFmpeg 6")
13801378
source_frames = self.decode(TEST_SRC_2_720P.path).data
13811379
file_path = tmp_path / "test_file_like.mp4"
13821380

0 commit comments

Comments
 (0)