Skip to content

Commit a0f594e

Browse files
committed
remove device_variant stuff, parametrize test w cuda device
1 parent e307707 commit a0f594e

File tree

5 files changed

+81
-72
lines changed

5 files changed

+81
-72
lines changed

src/torchcodec/_core/custom_ops.cpp

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3737
m.def(
3838
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
3939
m.def(
40-
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str device=\"cpu\", str device_variant=\"ffmpeg\", int? crf=None) -> ()");
40+
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str device=\"cpu\", int? crf=None) -> ()");
4141
m.def(
42-
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str device=\"cpu\", str device_variant=\"ffmpeg\", int? crf=None) -> Tensor");
42+
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str device=\"cpu\", int? crf=None) -> Tensor");
4343
m.def(
44-
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str device=\"cpu\", str device_variant=\"ffmpeg\",int? crf=None) -> ()");
44+
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str device=\"cpu\",int? crf=None) -> ()");
4545
m.def(
4646
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
4747
m.def(
@@ -604,14 +604,12 @@ void encode_video_to_file(
604604
int64_t frame_rate,
605605
std::string_view file_name,
606606
std::string_view device = "cpu",
607-
std::string_view device_variant = "ffmpeg",
608607
std::optional<int64_t> crf = std::nullopt) {
609608
VideoStreamOptions videoStreamOptions;
610609
videoStreamOptions.crf = crf;
611610

612-
validateDeviceInterface(std::string(device), std::string(device_variant));
613611
videoStreamOptions.device = torch::Device(std::string(device));
614-
videoStreamOptions.deviceVariant = device_variant;
612+
videoStreamOptions.deviceVariant = "ffmpeg";
615613
VideoEncoder(
616614
frames,
617615
validateInt64ToInt(frame_rate, "frame_rate"),
@@ -625,15 +623,13 @@ at::Tensor encode_video_to_tensor(
625623
int64_t frame_rate,
626624
std::string_view format,
627625
std::string_view device = "cpu",
628-
std::string_view device_variant = "ffmpeg",
629626
std::optional<int64_t> crf = std::nullopt) {
630627
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
631628
VideoStreamOptions videoStreamOptions;
632629
videoStreamOptions.crf = crf;
633630

634-
validateDeviceInterface(std::string(device), std::string(device_variant));
635631
videoStreamOptions.device = torch::Device(std::string(device));
636-
videoStreamOptions.deviceVariant = device_variant;
632+
videoStreamOptions.deviceVariant = "ffmpeg";
637633
return VideoEncoder(
638634
frames,
639635
validateInt64ToInt(frame_rate, "frame_rate"),
@@ -649,7 +645,6 @@ void _encode_video_to_file_like(
649645
std::string_view format,
650646
int64_t file_like_context,
651647
std::string_view device = "cpu",
652-
std::string_view device_variant = "ffmpeg",
653648
std::optional<int64_t> crf = std::nullopt) {
654649
auto fileLikeContext =
655650
reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
@@ -660,9 +655,8 @@ void _encode_video_to_file_like(
660655
VideoStreamOptions videoStreamOptions;
661656
videoStreamOptions.crf = crf;
662657

663-
validateDeviceInterface(std::string(device), std::string(device_variant));
664658
videoStreamOptions.device = torch::Device(std::string(device));
665-
videoStreamOptions.deviceVariant = device_variant;
659+
videoStreamOptions.deviceVariant = "ffmpeg";
666660

667661
VideoEncoder encoder(
668662
frames,

src/torchcodec/_core/ops.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,6 @@ def encode_video_to_file_like(
213213
format: str,
214214
file_like: Union[io.RawIOBase, io.BufferedIOBase],
215215
device: str = "cpu",
216-
device_variant: str = "ffmpeg",
217216
crf: Optional[int] = None,
218217
) -> None:
219218
"""Encode video frames to a file-like object.
@@ -224,7 +223,6 @@ def encode_video_to_file_like(
224223
format: Video format (e.g., "mp4", "mov", "mkv")
225224
file_like: File-like object that supports write() and seek() methods
226225
device: Device to use for encoding (default: "cpu")
227-
device_variant:
228226
crf: Optional constant rate factor for encoding quality
229227
"""
230228
assert _pybind_ops is not None
@@ -235,7 +233,6 @@ def encode_video_to_file_like(
235233
format,
236234
_pybind_ops.create_file_like_context(file_like, True), # True means for writing
237235
device,
238-
device_variant,
239236
crf,
240237
)
241238

@@ -325,7 +322,6 @@ def encode_video_to_file_abstract(
325322
frame_rate: int,
326323
filename: str,
327324
device: str = "cpu",
328-
device_variant: str = "ffmpeg",
329325
crf: Optional[int] = None,
330326
) -> None:
331327
return
@@ -337,7 +333,6 @@ def encode_video_to_tensor_abstract(
337333
frame_rate: int,
338334
format: str,
339335
device: str = "cpu",
340-
device_variant: str = "ffmpeg",
341336
crf: Optional[int] = None,
342337
) -> torch.Tensor:
343338
return torch.empty([], dtype=torch.long)
@@ -350,7 +345,6 @@ def _encode_video_to_file_like_abstract(
350345
format: str,
351346
file_like_context: int,
352347
device: str = "cpu",
353-
device_variant: str = "ffmpeg",
354348
crf: Optional[int] = None,
355349
) -> None:
356350
return

src/torchcodec/encoders/_video_encoder.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from torch import device as torch_device, Tensor
66

77
from torchcodec import _core
8-
from torchcodec.decoders._decoder_utils import _get_cuda_backend
98

109

1110
class VideoEncoder:
@@ -43,14 +42,6 @@ def __init__(
4342
if isinstance(device, torch_device):
4443
device = str(device)
4544

46-
# Check if beta variant is being used and reject it
47-
device_variant = _get_cuda_backend()
48-
if "cuda" in device.lower() and device_variant == "beta":
49-
raise ValueError(
50-
"The beta CUDA backend is not supported for video encoding. "
51-
"Please use device='cuda' without the beta backend context manager."
52-
)
53-
5445
self._frames = frames
5546
self._frame_rate = frame_rate
5647
self._device = device

test/test_encoders.py

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -573,12 +573,6 @@ class TestVideoEncoder:
573573
def decode(self, source=None) -> torch.Tensor:
574574
return VideoDecoder(source).get_frames_in_range(start=0, stop=60)
575575

576-
def save_image(self, a, b, name):
577-
from torchvision.io import write_png
578-
from torchvision.utils import make_grid
579-
image = make_grid(torch.stack([a, b]), nrow=2).cpu()
580-
write_png(image, f"{name}.png")
581-
582576
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
583577
def test_bad_input_parameterized(self, tmp_path, method):
584578
if method == "to_file":
@@ -642,15 +636,20 @@ def test_bad_input(self, tmp_path):
642636
encoder.to_tensor(format="bad_format")
643637

644638
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
645-
def test_contiguity(self, method, tmp_path):
639+
@pytest.mark.parametrize(
640+
"device", ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))
641+
)
642+
def test_contiguity(self, method, tmp_path, device):
646643
# Ensure that 2 sets of video frames with the same pixel values are encoded
647644
# in the same way, regardless of their memory layout. Here we encode 2 equal
648645
# frame tensors, one is contiguous while the other is non-contiguous.
649646

650-
num_frames, channels, height, width = 5, 3, 64, 64
651-
contiguous_frames = torch.randint(
652-
0, 256, size=(num_frames, channels, height, width), dtype=torch.uint8
653-
).contiguous()
647+
num_frames, channels, height, width = 5, 3, 256, 256
648+
contiguous_frames = (
649+
(torch.rand(num_frames, channels, height, width) * 255)
650+
.to(torch.uint8)
651+
.contiguous()
652+
)
654653
assert contiguous_frames.is_contiguous()
655654

656655
# Permute NCHW to NHWC, then update the memory layout, then permute back
@@ -668,14 +667,14 @@ def test_contiguity(self, method, tmp_path):
668667
def encode_to_tensor(frames):
669668
if method == "to_file":
670669
dest = str(tmp_path / "output.mp4")
671-
VideoEncoder(frames, frame_rate=30).to_file(dest=dest)
670+
VideoEncoder(frames, frame_rate=30, device=device).to_file(dest=dest)
672671
with open(dest, "rb") as f:
673672
return torch.frombuffer(f.read(), dtype=torch.uint8).clone()
674673
elif method == "to_tensor":
675-
return VideoEncoder(frames, frame_rate=30).to_tensor(format="mp4")
674+
return VideoEncoder(frames, frame_rate=30, device=device).to_tensor(format="mp4")
676675
elif method == "to_file_like":
677676
file_like = io.BytesIO()
678-
VideoEncoder(frames, frame_rate=30).to_file_like(
677+
VideoEncoder(frames, frame_rate=30, device=device).to_file_like(
679678
file_like, format="mp4"
680679
)
681680
return torch.frombuffer(file_like.getvalue(), dtype=torch.uint8)
@@ -708,30 +707,16 @@ def test_device_video_encoder(self, method, device, tmp_path):
708707
encoder.to_file(dest=dest)
709708
# Verify file was created
710709
assert Path(dest).exists()
711-
self.save_image(
712-
frames[0],
713-
self.decode(Path(dest)).data[0],
714-
name=f"{device}_to_file",
715-
)
716710
elif method == "to_tensor":
717711
encoded = encoder.to_tensor(format="mp4")
718712
assert encoded.dtype == torch.uint8
719713
assert encoded.ndim == 1
720714
assert encoded.numel() > 0
721-
self.save_image(
722-
frames[0],
723-
self.decode(encoded).data[0],
724-
name=f"{device}_to_tensor",
725-
)
726715
elif method == "to_file_like":
727716
file_like = io.BytesIO()
728717
encoder.to_file_like(file_like, format="mp4")
729718
encoded_bytes = file_like.getvalue()
730719
assert len(encoded_bytes) > 0
731-
self.save_image(
732-
frames[0],
733-
self.decode(encoded_bytes).data[0],
734-
name=f"{device}_to_file_like",
735-
)
736720
else:
737721
raise ValueError(f"Unknown method: {method}")
722+

test/test_ops.py

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,7 +1159,10 @@ def decode(self, source=None) -> torch.Tensor:
11591159
"format", ("mov", "mp4", "mkv", pytest.param("webm", marks=pytest.mark.slow))
11601160
)
11611161
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
1162-
def test_video_encoder_round_trip(self, tmp_path, format, method):
1162+
@pytest.mark.parametrize(
1163+
"device", ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))
1164+
)
1165+
def test_video_encoder_round_trip(self, tmp_path, format, method, device):
11631166
# Test that decode(encode(decode(frames))) == decode(frames)
11641167
ffmpeg_version = get_ffmpeg_major_version()
11651168
# In FFmpeg6, the default codec's best pixel format is lossy for all container formats but webm.
@@ -1174,9 +1177,10 @@ def test_video_encoder_round_trip(self, tmp_path, format, method):
11741177
pytest.skip("Codec for webm is not available in this FFmpeg installation.")
11751178
source_frames = self.decode(TEST_SRC_2_720P.path).data
11761179

1180+
# Frame rate is fixed with num frames decoded
11771181
params = dict(
1178-
frame_rate=30, crf=0
1179-
) # Frame rate is fixed with num frames decoded
1182+
frame_rate=30, crf=0, device=device
1183+
)
11801184
if method == "to_file":
11811185
encoded_path = str(tmp_path / f"encoder_output.{format}")
11821186
encode_video_to_file(
@@ -1207,9 +1211,10 @@ def test_video_encoder_round_trip(self, tmp_path, format, method):
12071211

12081212
# If FFmpeg selects a codec or pixel format that does lossy encoding, assert 99% of pixels
12091213
# are within a higher tolerance.
1210-
if ffmpeg_version == 6:
1211-
assert_close = partial(assert_tensor_close_on_at_least, percentage=99)
1214+
if ffmpeg_version == 6 or device == "cuda":
12121215
atol = 15
1216+
percentage = 98
1217+
assert_close = partial(assert_tensor_close_on_at_least, percentage=percentage)
12131218
else:
12141219
assert_close = torch.testing.assert_close
12151220
atol = 2
@@ -1230,7 +1235,10 @@ def test_video_encoder_round_trip(self, tmp_path, format, method):
12301235
),
12311236
)
12321237
@pytest.mark.parametrize("method", ("to_tensor", "to_file_like"))
1233-
def test_against_to_file(self, tmp_path, format, method):
1238+
@pytest.mark.parametrize(
1239+
"device", ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))
1240+
)
1241+
def test_against_to_file(self, tmp_path, format, method, device):
12341242
# Test that to_file, to_tensor, and to_file_like produce the same results
12351243
ffmpeg_version = get_ffmpeg_major_version()
12361244
if format == "webm" and (
@@ -1239,7 +1247,7 @@ def test_against_to_file(self, tmp_path, format, method):
12391247
pytest.skip("Codec for webm is not available in this FFmpeg installation.")
12401248

12411249
source_frames = self.decode(TEST_SRC_2_720P.path).data
1242-
params = dict(frame_rate=30, crf=0)
1250+
params = dict(frame_rate=30, crf=0, device=device)
12431251

12441252
encoded_file = tmp_path / f"output.{format}"
12451253
encode_video_to_file(frames=source_frames, filename=str(encoded_file), **params)
@@ -1313,10 +1321,22 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format, device):
13131321
str(frame_rate),
13141322
"-i",
13151323
temp_raw_path,
1324+
]
1325+
1326+
# Use NVENC encoder when device is CUDA and format has an NVENC codec
1327+
if device == "cuda":
1328+
if format in ("mp4", "mov", "mkv"):
1329+
ffmpeg_cmd.extend(["-c:v", "h264_nvenc"])
1330+
elif format == "webm":
1331+
ffmpeg_cmd.extend(["-c:v", "vp9_nvenc"]) # Use NVENC for VP9
1332+
# TODO-VideoEncoder: formats "flv", "avi" should also use respective NVENC codecs,
1333+
# but do not auto select them.
1334+
1335+
ffmpeg_cmd.extend([
13161336
"-crf",
13171337
str(crf),
13181338
ffmpeg_encoded_path,
1319-
]
1339+
])
13201340
subprocess.run(ffmpeg_cmd, check=True)
13211341

13221342
# Encode with our video encoder
@@ -1347,7 +1367,10 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format, device):
13471367
ff_frame, enc_frame, percentage=percentage, atol=2
13481368
)
13491369

1350-
def test_to_file_like_custom_file_object(self):
1370+
@pytest.mark.parametrize(
1371+
"device", ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))
1372+
)
1373+
def test_to_file_like_custom_file_object(self, device):
13511374
"""Test to_file_like with a custom file-like object that implements write and seek."""
13521375

13531376
class CustomFileObject:
@@ -1366,32 +1389,54 @@ def get_encoded_data(self):
13661389
source_frames = self.decode(TEST_SRC_2_720P.path).data
13671390
file_like = CustomFileObject()
13681391
encode_video_to_file_like(
1369-
source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like
1392+
source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like, device=device
13701393
)
13711394
decoded_samples = self.decode(file_like.get_encoded_data())
13721395

1373-
torch.testing.assert_close(
1396+
ffmpeg_version = get_ffmpeg_major_version()
1397+
if device == "cuda":
1398+
atol = 15
1399+
percentage = 98
1400+
assert_close = partial(assert_tensor_close_on_at_least, percentage=percentage)
1401+
else:
1402+
assert_close = torch.testing.assert_close
1403+
atol = 2
1404+
1405+
assert_close(
13741406
decoded_samples.data,
13751407
source_frames,
1376-
atol=2,
1408+
atol=atol,
13771409
rtol=0,
13781410
)
13791411

1380-
def test_to_file_like_real_file(self, tmp_path):
1412+
@pytest.mark.parametrize(
1413+
"device", ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))
1414+
)
1415+
def test_to_file_like_real_file(self, tmp_path, device):
13811416
"""Test to_file_like with a real file opened in binary write mode."""
13821417
source_frames = self.decode(TEST_SRC_2_720P.path).data
13831418
file_path = tmp_path / "test_file_like.mp4"
13841419

13851420
with open(file_path, "wb") as file_like:
13861421
encode_video_to_file_like(
1387-
source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like
1422+
source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like, device=device
13881423
)
13891424
decoded_samples = self.decode(str(file_path))
13901425

1391-
torch.testing.assert_close(
1426+
# Use adaptive tolerance based on device and FFmpeg version, consistent with test_video_encoder_round_trip
1427+
ffmpeg_version = get_ffmpeg_major_version()
1428+
if device == "cuda":
1429+
atol = 15
1430+
percentage = 98
1431+
assert_close = partial(assert_tensor_close_on_at_least, percentage=percentage)
1432+
else:
1433+
assert_close = torch.testing.assert_close
1434+
atol = 2
1435+
1436+
assert_close(
13921437
decoded_samples.data,
13931438
source_frames,
1394-
atol=2,
1439+
atol=atol,
13951440
rtol=0,
13961441
)
13971442

0 commit comments

Comments
 (0)