diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 2d1033074..bcc53a6a8 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -18,11 +18,6 @@ class CpuDeviceInterface : public DeviceInterface { virtual ~CpuDeviceInterface() {} - std::optional findCodec( - [[maybe_unused]] const AVCodecID& codecId) override { - return std::nullopt; - } - virtual void initialize( const AVStream* avStream, const UniqueDecodingAVFormatContext& avFormatCtx, diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 0e20c5e8d..4011c7340 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -329,11 +329,40 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( avFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor); } +namespace { +// Helper function to check if a codec supports CUDA hardware acceleration +bool codecSupportsCudaHardware(const AVCodec* codec) { + const AVCodecHWConfig* config = nullptr; + for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr; ++j) { + if (config->device_type == AV_HWDEVICE_TYPE_CUDA) { + return true; + } + } + return false; +} +} // namespace + // inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9 // we have to do this because of an FFmpeg bug where hardware decoding is not // appropriately set, so we just go off and find the matching codec for the CUDA // device -std::optional CudaDeviceInterface::findCodec( + +std::optional CudaDeviceInterface::findEncoder( + const AVCodecID& codecId) { + void* i = nullptr; + const AVCodec* codec = nullptr; + while ((codec = av_codec_iterate(&i)) != nullptr) { + if (codec->id != codecId || !av_codec_is_encoder(codec)) { + continue; + } + if (codecSupportsCudaHardware(codec)) { + return codec; + } + } + return std::nullopt; +} + +std::optional 
CudaDeviceInterface::findDecoder( const AVCodecID& codecId) { void* i = nullptr; const AVCodec* codec = nullptr; @@ -342,12 +371,8 @@ std::optional CudaDeviceInterface::findCodec( continue; } - const AVCodecHWConfig* config = nullptr; - for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr; - ++j) { - if (config->device_type == AV_HWDEVICE_TYPE_CUDA) { - return codec; - } + if (codecSupportsCudaHardware(codec)) { + return codec; } } diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index c892bd49b..9c0c2fdb9 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -18,7 +18,8 @@ class CudaDeviceInterface : public DeviceInterface { virtual ~CudaDeviceInterface(); - std::optional findCodec(const AVCodecID& codecId) override; + std::optional findEncoder(const AVCodecID& codecId) override; + std::optional findDecoder(const AVCodecID& codecId) override; void initialize( const AVStream* avStream, diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index 319fe01a8..3ef956056 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -46,7 +46,12 @@ class DeviceInterface { return device_; }; - virtual std::optional findCodec( + virtual std::optional findEncoder( + [[maybe_unused]] const AVCodecID& codecId) { + return std::nullopt; + }; + + virtual std::optional findDecoder( [[maybe_unused]] const AVCodecID& codecId) { return std::nullopt; }; diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 9546924c1..c9c983054 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -615,10 +615,23 @@ VideoEncoder::VideoEncoder( void VideoEncoder::initializeEncoder( const VideoStreamOptions& videoStreamOptions) { + deviceInterface_ = createDeviceInterface( + videoStreamOptions.device, videoStreamOptions.deviceVariant); + 
TORCH_CHECK( + deviceInterface_ != nullptr, + "Failed to create device interface. This should never happen, please report."); + const AVCodec* avCodec = avcodec_find_encoder(avFormatContext_->oformat->video_codec); TORCH_CHECK(avCodec != nullptr, "Video codec not found"); + // Try to find a hardware-accelerated encoder if not using CPU + if (videoStreamOptions.device.type() != torch::kCPU) { + avCodec = + deviceInterface_->findEncoder(avFormatContext_->oformat->video_codec) + .value_or(avCodec); + } + AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec); TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context."); avCodecContext_.reset(avCodecContext); @@ -662,12 +675,20 @@ void VideoEncoder::initializeEncoder( // Apply videoStreamOptions AVDictionary* options = nullptr; if (videoStreamOptions.crf.has_value()) { + // nvenc encoders use qp, others use crf (for C++ tests) + const char* qualityParam = + (strstr(avCodec->name, "nvenc") == nullptr) ? "crf" : "qp"; av_dict_set( &options, - "crf", + qualityParam, std::to_string(videoStreamOptions.crf.value()).c_str(), 0); } + + // Register the hardware device context with the codec + // context before calling avcodec_open2(). 
+ deviceInterface_->registerHardwareDeviceWithCodec(avCodecContext_.get()); + int status = avcodec_open2(avCodecContext_.get(), avCodec, &options); av_dict_free(&options); diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index c1055281a..05ed8855d 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -1,6 +1,7 @@ #pragma once #include #include "AVIOContextHolder.h" +#include "DeviceInterface.h" #include "FFMPEGCommon.h" #include "StreamOptions.h" @@ -177,6 +178,7 @@ class VideoEncoder { AVPixelFormat outPixelFormat_ = AV_PIX_FMT_NONE; std::unique_ptr avioContextHolder_; + std::unique_ptr deviceInterface_; bool encodeWasCalled_ = false; }; diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index 524ada777..0b59308ef 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -40,7 +40,7 @@ AVPacket* ReferenceAVPacket::operator->() { AVCodecOnlyUseForCallingAVFindBestStream makeAVCodecOnlyUseForCallingAVFindBestStream(const AVCodec* codec) { -#if LIBAVCODEC_VERSION_INT < AV_VERSION_INT(59, 18, 100) +#if LIBAVCODEC_VERSION_INT < AV_VERSION_INT(59, 18, 100) // FFmpeg < 5.0.3 return const_cast(codec); #else return codec; diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 72cd7afac..46021621e 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -435,7 +435,7 @@ void SingleStreamDecoder::addStream( // addStream() which is supposed to be generic if (mediaType == AVMEDIA_TYPE_VIDEO) { avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( - deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id) + deviceInterface_->findDecoder(streamInfo.stream->codecpar->codec_id) .value_or(avCodec)); } diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index b78160f6c..07cfb4dfd 100644 --- 
a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()"); m.def( - "encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()"); + "encode_video_to_file(Tensor frames, int frame_rate, str filename, str device=\"cpu\", int? crf=None) -> ()"); m.def( - "encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor"); + "encode_video_to_tensor(Tensor frames, int frame_rate, str format, str device=\"cpu\", int? crf=None) -> Tensor"); m.def( - "_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, int? crf=None) -> ()"); + "_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str device=\"cpu\",int? crf=None) -> ()"); m.def( "create_from_tensor(Tensor video_tensor, str? 
seek_mode=None) -> Tensor"); m.def( @@ -603,9 +603,12 @@ void encode_video_to_file( const at::Tensor& frames, int64_t frame_rate, std::string_view file_name, + std::string_view device = "cpu", std::optional crf = std::nullopt) { VideoStreamOptions videoStreamOptions; videoStreamOptions.crf = crf; + + videoStreamOptions.device = torch::Device(std::string(device)); VideoEncoder( frames, validateInt64ToInt(frame_rate, "frame_rate"), @@ -618,10 +621,13 @@ at::Tensor encode_video_to_tensor( const at::Tensor& frames, int64_t frame_rate, std::string_view format, + std::string_view device = "cpu", std::optional crf = std::nullopt) { auto avioContextHolder = std::make_unique(); VideoStreamOptions videoStreamOptions; videoStreamOptions.crf = crf; + + videoStreamOptions.device = torch::Device(std::string(device)); return VideoEncoder( frames, validateInt64ToInt(frame_rate, "frame_rate"), @@ -636,6 +642,7 @@ void _encode_video_to_file_like( int64_t frame_rate, std::string_view format, int64_t file_like_context, + std::string_view device = "cpu", std::optional crf = std::nullopt) { auto fileLikeContext = reinterpret_cast(file_like_context); @@ -645,6 +652,7 @@ void _encode_video_to_file_like( VideoStreamOptions videoStreamOptions; videoStreamOptions.crf = crf; + videoStreamOptions.device = torch::Device(std::string(device)); VideoEncoder encoder( frames, diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 32995c964..bb6f919a5 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -212,6 +212,7 @@ def encode_video_to_file_like( frame_rate: int, format: str, file_like: Union[io.RawIOBase, io.BufferedIOBase], + device: str = "cpu", crf: Optional[int] = None, ) -> None: """Encode video frames to a file-like object. 
@@ -221,6 +222,7 @@ def encode_video_to_file_like( frame_rate: Frame rate in frames per second format: Video format (e.g., "mp4", "mov", "mkv") file_like: File-like object that supports write() and seek() methods + device: Device to use for encoding (default: "cpu") crf: Optional constant rate factor for encoding quality """ assert _pybind_ops is not None @@ -230,6 +232,7 @@ def encode_video_to_file_like( frame_rate, format, _pybind_ops.create_file_like_context(file_like, True), # True means for writing + device, crf, ) @@ -318,7 +321,8 @@ def encode_video_to_file_abstract( frames: torch.Tensor, frame_rate: int, filename: str, - crf: Optional[int], + device: str = "cpu", + crf: Optional[int] = None, ) -> None: return @@ -328,7 +332,8 @@ def encode_video_to_tensor_abstract( frames: torch.Tensor, frame_rate: int, format: str, - crf: Optional[int], + device: str = "cpu", + crf: Optional[int] = None, ) -> torch.Tensor: return torch.empty([], dtype=torch.long) @@ -339,6 +344,7 @@ def _encode_video_to_file_like_abstract( frame_rate: int, format: str, file_like_context: int, + device: str = "cpu", crf: Optional[int] = None, ) -> None: return diff --git a/src/torchcodec/encoders/_video_encoder.py b/src/torchcodec/encoders/_video_encoder.py index f6a725278..1957e08e8 100644 --- a/src/torchcodec/encoders/_video_encoder.py +++ b/src/torchcodec/encoders/_video_encoder.py @@ -1,8 +1,8 @@ from pathlib import Path -from typing import Union +from typing import Optional, Union import torch -from torch import Tensor +from torch import device as torch_device, Tensor from torchcodec import _core @@ -16,9 +16,18 @@ class VideoEncoder: C is 3 channels (RGB), H is height, and W is width. Values must be uint8 in the range ``[0, 255]``. frame_rate (int): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate. + device (str or torch.device, optional): The device to use for encoding. Default: "cpu". 
+ If you pass a CUDA device, frames will be encoded on GPU. + Note: The "beta" CUDA backend is not supported for encoding. """ - def __init__(self, frames: Tensor, *, frame_rate: int): + def __init__( + self, + frames: Tensor, + *, + frame_rate: int, + device: Optional[Union[str, torch_device]] = "cpu", + ): torch._C._log_api_usage_once("torchcodec.encoders.VideoEncoder") if not isinstance(frames, Tensor): raise ValueError(f"Expected frames to be a Tensor, got {type(frames) = }.") @@ -29,8 +38,13 @@ def __init__(self, frames: Tensor, *, frame_rate: int): if frame_rate <= 0: raise ValueError(f"{frame_rate = } must be > 0.") + # Validate and store device + if isinstance(device, torch_device): + device = str(device) + self._frames = frames self._frame_rate = frame_rate + self._device = device def to_file( self, @@ -47,6 +61,7 @@ def to_file( frames=self._frames, frame_rate=self._frame_rate, filename=str(dest), + device=self._device, ) def to_tensor( @@ -66,6 +81,7 @@ def to_tensor( frames=self._frames, frame_rate=self._frame_rate, format=format, + device=self._device, ) def to_file_like( @@ -89,4 +105,5 @@ def to_file_like( frame_rate=self._frame_rate, format=format, file_like=file_like, + device=self._device, ) diff --git a/test/test_encoders.py b/test/test_encoders.py index b7223c88a..193a72f18 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -11,6 +11,7 @@ import torch from torchcodec.decoders import AudioDecoder +from torchcodec.decoders._video_decoder import VideoDecoder from torchcodec.encoders import AudioEncoder, VideoEncoder from .utils import ( @@ -567,6 +568,10 @@ def write(self, data): class TestVideoEncoder: + + def decode(self, source=None) -> torch.Tensor: + return VideoDecoder(source).get_frames_in_range(start=0, stop=60) + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) def test_bad_input_parameterized(self, tmp_path, method): if method == "to_file": @@ -630,12 +635,15 @@ def test_bad_input(self, 
tmp_path): encoder.to_tensor(format="bad_format") @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) - def test_contiguity(self, method, tmp_path): + @pytest.mark.parametrize( + "device", ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) + ) + def test_contiguity(self, method, tmp_path, device): # Ensure that 2 sets of video frames with the same pixel values are encoded # in the same way, regardless of their memory layout. Here we encode 2 equal # frame tensors, one is contiguous while the other is non-contiguous. - num_frames, channels, height, width = 5, 3, 64, 64 + num_frames, channels, height, width = 5, 3, 256, 256 contiguous_frames = torch.randint( 0, 256, size=(num_frames, channels, height, width), dtype=torch.uint8 ).contiguous() @@ -656,14 +664,16 @@ def test_contiguity(self, method, tmp_path): def encode_to_tensor(frames): if method == "to_file": dest = str(tmp_path / "output.mp4") - VideoEncoder(frames, frame_rate=30).to_file(dest=dest) + VideoEncoder(frames, frame_rate=30, device=device).to_file(dest=dest) with open(dest, "rb") as f: - return torch.frombuffer(f.read(), dtype=torch.uint8) + return torch.frombuffer(f.read(), dtype=torch.uint8).clone() elif method == "to_tensor": - return VideoEncoder(frames, frame_rate=30).to_tensor(format="mp4") + return VideoEncoder(frames, frame_rate=30, device=device).to_tensor( + format="mp4" + ) elif method == "to_file_like": file_like = io.BytesIO() - VideoEncoder(frames, frame_rate=30).to_file_like( + VideoEncoder(frames, frame_rate=30, device=device).to_file_like( file_like, format="mp4" ) return torch.frombuffer(file_like.getvalue(), dtype=torch.uint8) diff --git a/test/test_ops.py b/test/test_ops.py index e798a7a2b..39e8f4291 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1159,12 +1159,15 @@ def decode(self, source=None) -> torch.Tensor: "format", ("mov", "mp4", "mkv", pytest.param("webm", marks=pytest.mark.slow)) ) @pytest.mark.parametrize("method", ("to_file", 
"to_tensor", "to_file_like")) - def test_video_encoder_round_trip(self, tmp_path, format, method): + @pytest.mark.parametrize( + "device", ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) + ) + def test_video_encoder_round_trip(self, tmp_path, format, method, device): # Test that decode(encode(decode(frames))) == decode(frames) ffmpeg_version = get_ffmpeg_major_version() # In FFmpeg6, the default codec's best pixel format is lossy for all container formats but webm. # As a result, we skip the round trip test. - if ffmpeg_version == 6 and format != "webm": + if ffmpeg_version == 6 and format != "webm" and device == "cpu": pytest.skip( f"FFmpeg6 defaults to lossy encoding for {format}, skipping round-trip test." ) @@ -1172,11 +1175,20 @@ def test_video_encoder_round_trip(self, tmp_path, format, method): ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7)) ): pytest.skip("Codec for webm is not available in this FFmpeg installation.") + if device == "cuda": + if format not in ("mp4", "mov", "mkv"): + pytest.skip( + f"No NVENC encoder available for format {format}, skipping test." + ) + if ffmpeg_version == 4: + pytest.skip( + "CUDA encoding on FFmpeg 4 results in lower quality, skipping round trip test." + ) + source_frames = self.decode(TEST_SRC_2_720P.path).data - params = dict( - frame_rate=30, crf=0 - ) # Frame rate is fixed with num frames decoded + # Frame rate is fixed with num frames decoded + params = dict(frame_rate=30, crf=0, device=device) if method == "to_file": encoded_path = str(tmp_path / f"encoder_output.{format}") encode_video_to_file( @@ -1205,17 +1217,15 @@ def test_video_encoder_round_trip(self, tmp_path, format, method): assert source_frames.shape == round_trip_frames.shape assert source_frames.dtype == round_trip_frames.dtype - # If FFmpeg selects a codec or pixel format that does lossy encoding, assert 99% of pixels - # are within a higher tolerance. 
- if ffmpeg_version == 6: - assert_close = partial(assert_tensor_close_on_at_least, percentage=99) - atol = 15 + # If encoding on GPU, assert 99% of pixels are within a strict tolerance. + if device == "cuda": + assert_close = partial( + assert_tensor_close_on_at_least, atol=3, percentage=99 + ) else: - assert_close = torch.testing.assert_close - atol = 2 + assert_close = partial(torch.testing.assert_close, atol=2, rtol=0) for s_frame, rt_frame in zip(source_frames, round_trip_frames): - assert psnr(s_frame, rt_frame) > 30 - assert_close(s_frame, rt_frame, atol=atol, rtol=0) + assert_close(s_frame, rt_frame) @pytest.mark.parametrize( "format", @@ -1230,7 +1240,10 @@ def test_video_encoder_round_trip(self, tmp_path, format, method): ), ) @pytest.mark.parametrize("method", ("to_tensor", "to_file_like")) - def test_against_to_file(self, tmp_path, format, method): + @pytest.mark.parametrize( + "device", ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) + ) + def test_against_to_file(self, tmp_path, format, method, device): # Test that to_file, to_tensor, and to_file_like produce the same results ffmpeg_version = get_ffmpeg_major_version() if format == "webm" and ( @@ -1239,7 +1252,7 @@ def test_against_to_file(self, tmp_path, format, method): pytest.skip("Codec for webm is not available in this FFmpeg installation.") source_frames = self.decode(TEST_SRC_2_720P.path).data - params = dict(frame_rate=30, crf=0) + params = dict(frame_rate=30, crf=0, device=device) encoded_file = tmp_path / f"output.{format}" encode_video_to_file(frames=source_frames, filename=str(encoded_file), **params) @@ -1278,13 +1291,27 @@ def test_against_to_file(self, tmp_path, format, method): pytest.param("webm", marks=pytest.mark.slow), ), ) - def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format): + @pytest.mark.parametrize( + "device", ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) + ) + def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format, device): 
ffmpeg_version = get_ffmpeg_major_version() if format == "webm" and ( ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7)) ): pytest.skip("Codec for webm is not available in this FFmpeg installation.") + # Pass flag to FFmpeg CLI to NVENC encoder when device is CUDA and format has a compatible NVENC codec + if device == "cuda": + if format not in ("mp4", "mov", "mkv"): + pytest.skip( + f"No NVENC encoder available for format {format}, skipping test." + ) + if ffmpeg_version == 4: + pytest.skip( + "CUDA encoding on FFmpeg 4 results in lower quality, skipping round trip test." + ) + source_frames = self.decode(TEST_SRC_2_720P.path).data # Encode with FFmpeg CLI @@ -1295,7 +1322,7 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format): ffmpeg_encoded_path = str(tmp_path / f"ffmpeg_output.{format}") frame_rate = 30 crf = 0 - # Some codecs (ex. MPEG4) do not support CRF. + # Some codecs (ex. MPEG4) and CUDA backend codecs do not support CRF. # Flags not supported by the selected codec will be ignored. 
ffmpeg_cmd = [ "ffmpeg", @@ -1310,10 +1337,14 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format): str(frame_rate), "-i", temp_raw_path, - "-crf", - str(crf), - ffmpeg_encoded_path, ] + + if device == "cuda": + ffmpeg_cmd.extend(["-c:v", "h264_nvenc"]) + quality_param = "qp" if device == "cuda" else "crf" + + ffmpeg_cmd.extend([f"-{quality_param}", str(crf)]) + ffmpeg_cmd.extend([ffmpeg_encoded_path]) subprocess.run(ffmpeg_cmd, check=True) # Encode with our video encoder @@ -1322,6 +1353,7 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format): frames=source_frames, frame_rate=frame_rate, filename=encoder_output_path, + device=device, crf=crf, ) @@ -1337,13 +1369,15 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format): # Check that PSNR between both encoded versions is high for ff_frame, enc_frame in zip(ffmpeg_frames, encoder_frames): - res = psnr(ff_frame, enc_frame) - assert res > 30 + assert psnr(ff_frame, enc_frame) > 30 assert_tensor_close_on_at_least( ff_frame, enc_frame, percentage=percentage, atol=2 ) - def test_to_file_like_custom_file_object(self): + @pytest.mark.parametrize( + "device", ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) + ) + def test_to_file_like_custom_file_object(self, device): """Test to_file_like with a custom file-like object that implements write and seek.""" class CustomFileObject: @@ -1362,34 +1396,47 @@ def get_encoded_data(self): source_frames = self.decode(TEST_SRC_2_720P.path).data file_like = CustomFileObject() encode_video_to_file_like( - source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like + source_frames, + frame_rate=30, + crf=0, + format="mp4", + file_like=file_like, + device=device, ) decoded_samples = self.decode(file_like.get_encoded_data()) - torch.testing.assert_close( - decoded_samples.data, - source_frames, - atol=2, - rtol=0, - ) + if device == "cuda": + assert_close = assert_frames_equal + else: + assert_close = 
partial(torch.testing.assert_close, atol=2, rtol=0) + + assert_close(decoded_samples.data, source_frames) - def test_to_file_like_real_file(self, tmp_path): + @pytest.mark.parametrize( + "device", ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) + ) + def test_to_file_like_real_file(self, tmp_path, device): """Test to_file_like with a real file opened in binary write mode.""" source_frames = self.decode(TEST_SRC_2_720P.path).data file_path = tmp_path / "test_file_like.mp4" with open(file_path, "wb") as file_like: encode_video_to_file_like( - source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like + source_frames, + frame_rate=30, + crf=0, + format="mp4", + file_like=file_like, + device=device, ) decoded_samples = self.decode(str(file_path)) - torch.testing.assert_close( - decoded_samples.data, - source_frames, - atol=2, - rtol=0, - ) + if device == "cuda": + assert_close = assert_frames_equal + else: + assert_close = partial(torch.testing.assert_close, atol=2, rtol=0) + + assert_close(decoded_samples.data, source_frames) def test_to_file_like_bad_methods(self): source_frames = self.decode(TEST_SRC_2_720P.path).data