diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp
index 9546924c1..239b4c828 100644
--- a/src/torchcodec/_core/Encoder.cpp
+++ b/src/torchcodec/_core/Encoder.cpp
@@ -4,6 +4,10 @@
 #include "Encoder.h"
 #include "torch/types.h"
 
+extern "C" {
+#include <libavutil/pixdesc.h>
+}
+
 namespace facebook::torchcodec {
 
 namespace {
@@ -534,6 +538,44 @@ torch::Tensor validateFrames(const torch::Tensor& frames) {
   return frames.contiguous();
 }
 
+// Resolves the FFmpeg pixel format named by targetPixelFormat and checks it
+// against the formats reported by getSupportedPixelFormats() for avCodec.
+// Returns the format on a match; otherwise fails via TORCH_CHECK with a
+// message that lists the supported formats.
+AVPixelFormat validatePixelFormat(
+    const AVCodec& avCodec,
+    const std::string& targetPixelFormat) {
+  AVPixelFormat pixelFormat = av_get_pix_fmt(targetPixelFormat.c_str());
+
+  // Validate that the encoder supports this pixel format
+  const AVPixelFormat* supportedFormats = getSupportedPixelFormats(avCodec);
+  if (supportedFormats != nullptr) {
+    for (int i = 0; supportedFormats[i] != AV_PIX_FMT_NONE; ++i) {
+      if (supportedFormats[i] == pixelFormat) {
+        return pixelFormat;
+      }
+    }
+  }
+
+  std::stringstream errorMsg;
+  // av_get_pix_fmt failed to find a pix_fmt
+  if (pixelFormat == AV_PIX_FMT_NONE) {
+    errorMsg << "Unknown pixel format: " << targetPixelFormat;
+  } else {
+    errorMsg << "Specified pixel format " << targetPixelFormat
+             << " is not supported by the " << avCodec.name << " encoder.";
+  }
+  // Build error message, similar to FFmpeg's error log
+  errorMsg << "\nSupported pixel formats for " << avCodec.name << ":";
+  // supportedFormats can be null (see the check above), so guard the
+  // dereference: we must still reach TORCH_CHECK instead of crashing.
+  if (supportedFormats != nullptr) {
+    for (int i = 0; supportedFormats[i] != AV_PIX_FMT_NONE; ++i) {
+      errorMsg << " " << av_get_pix_fmt_name(supportedFormats[i]);
+    }
+  }
+  TORCH_CHECK(false, errorMsg.str());
+}
 } // namespace
 
 VideoEncoder::~VideoEncoder() {
@@ -635,15 +677,19 @@ void VideoEncoder::initializeEncoder(
   outWidth_ = inWidth_;
   outHeight_ = inHeight_;
 
-  // TODO-VideoEncoder: Enable other pixel formats
-  // Let FFmpeg choose best pixel format to minimize loss
-  outPixelFormat_ = avcodec_find_best_pix_fmt_of_list(
-      getSupportedPixelFormats(*avCodec), // List of supported formats
-      AV_PIX_FMT_GBRP, // We reorder input to GBRP currently
-      0, // No alpha channel
-      nullptr // Discard conversion loss information
-  );
-  TORCH_CHECK(outPixelFormat_ != -1, "Failed to find best pix fmt")
+  if (videoStreamOptions.pixelFormat.has_value()) {
+    outPixelFormat_ =
+        validatePixelFormat(*avCodec, videoStreamOptions.pixelFormat.value());
+  } else {
+    const AVPixelFormat* formats = getSupportedPixelFormats(*avCodec);
+    // Use first listed pixel format as default (often yuv420p).
+    // This is similar to FFmpeg's logic:
+    // https://www.ffmpeg.org/doxygen/4.0/decode_8c_source.html#l01087
+    // If pixel formats are undefined for some reason, try yuv420p
+    outPixelFormat_ = (formats && formats[0] != AV_PIX_FMT_NONE)
+        ? formats[0]
+        : AV_PIX_FMT_YUV420P;
+  }
 
   // Configure codec parameters
   avCodecContext_->codec_id = avCodec->id;
diff --git a/src/torchcodec/_core/StreamOptions.h b/src/torchcodec/_core/StreamOptions.h
index e5ab256e1..b7647176c 100644
--- a/src/torchcodec/_core/StreamOptions.h
+++ b/src/torchcodec/_core/StreamOptions.h
@@ -48,6 +48,10 @@ struct VideoStreamOptions {
   // TODO-VideoEncoder: Consider adding other optional fields here
   // (bit rate, gop size, max b frames, preset)
   std::optional<int> crf;
+
+  // Optional pixel format for video encoding (e.g., "yuv420p", "yuv444p")
+  // If not specified, uses codec's default format.
+  std::optional<std::string> pixelFormat;
 };
 
 struct AudioStreamOptions {
diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp
index b78160f6c..b4320a24d 100644
--- a/src/torchcodec/_core/custom_ops.cpp
+++ b/src/torchcodec/_core/custom_ops.cpp
@@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) {
   m.def(
       "_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
   m.def(
-      "encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
+      "encode_video_to_file(Tensor frames, int frame_rate, str filename, str? pixel_format=None, int? crf=None) -> ()");
   m.def(
-      "encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor");
+      "encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? pixel_format=None, int? crf=None) -> Tensor");
   m.def(
-      "_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, int? crf=None) -> ()");
+      "_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? pixel_format=None, int? crf=None) -> ()");
   m.def(
       "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
   m.def(
@@ -603,8 +603,10 @@ void encode_video_to_file(
     const at::Tensor& frames,
     int64_t frame_rate,
     std::string_view file_name,
+    std::optional<std::string> pixel_format = std::nullopt,
     std::optional<int64_t> crf = std::nullopt) {
   VideoStreamOptions videoStreamOptions;
+  videoStreamOptions.pixelFormat = pixel_format;
   videoStreamOptions.crf = crf;
   VideoEncoder(
       frames,
@@ -618,9 +620,11 @@ at::Tensor encode_video_to_tensor(
     const at::Tensor& frames,
     int64_t frame_rate,
     std::string_view format,
+    std::optional<std::string> pixel_format = std::nullopt,
     std::optional<int64_t> crf = std::nullopt) {
   auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
   VideoStreamOptions videoStreamOptions;
+  videoStreamOptions.pixelFormat = pixel_format;
   videoStreamOptions.crf = crf;
   return VideoEncoder(
       frames,
@@ -636,6 +640,7 @@ void _encode_video_to_file_like(
     int64_t frame_rate,
     std::string_view format,
     int64_t file_like_context,
+    std::optional<std::string> pixel_format = std::nullopt,
     std::optional<int64_t> crf = std::nullopt) {
   auto fileLikeContext =
       reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
@@ -644,6 +649,7 @@ void _encode_video_to_file_like(
   std::unique_ptr<AVIOFileLikeContext> avioContextHolder(fileLikeContext);
 
   VideoStreamOptions videoStreamOptions;
+  videoStreamOptions.pixelFormat = pixel_format;
   videoStreamOptions.crf = crf;
 
   VideoEncoder encoder(
diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py
index 32995c964..2fb73ef14 100644
--- a/src/torchcodec/_core/ops.py
+++ b/src/torchcodec/_core/ops.py
@@ -213,6 +213,7 @@ def encode_video_to_file_like(
     format: str,
     file_like: Union[io.RawIOBase, io.BufferedIOBase],
     crf: Optional[int] = None,
+    pixel_format: Optional[str] = None,
 ) -> None:
     """Encode video frames to a file-like object.
 
@@ -222,6 +223,7 @@ def encode_video_to_file_like(
         format: Video format (e.g., "mp4", "mov", "mkv")
         file_like: File-like object that supports write() and seek() methods
         crf: Optional constant rate factor for encoding quality
+        pixel_format: Optional pixel format (e.g., "yuv420p", "yuv444p")
     """
     assert _pybind_ops is not None
 
@@ -230,6 +232,7 @@ def encode_video_to_file_like(
         frame_rate,
         format,
         _pybind_ops.create_file_like_context(file_like, True),  # True means for writing
+        pixel_format,
         crf,
     )
 
@@ -318,7 +321,8 @@ def encode_video_to_file_abstract(
     frames: torch.Tensor,
     frame_rate: int,
     filename: str,
-    crf: Optional[int],
+    pixel_format: Optional[str] = None,
+    crf: Optional[int] = None,
 ) -> None:
     return
 
@@ -328,7 +332,8 @@ def encode_video_to_tensor_abstract(
     frames: torch.Tensor,
     frame_rate: int,
     format: str,
-    crf: Optional[int],
+    pixel_format: Optional[str] = None,
+    crf: Optional[int] = None,
 ) -> torch.Tensor:
     return torch.empty([], dtype=torch.long)
 
@@ -340,6 +345,7 @@ def _encode_video_to_file_like_abstract(
     format: str,
     file_like_context: int,
+    pixel_format: Optional[str] = None,
     crf: Optional[int] = None,
 ) -> None:
     return
 
diff --git a/src/torchcodec/encoders/_video_encoder.py b/src/torchcodec/encoders/_video_encoder.py
index f6a725278..e0630d012 100644
--- a/src/torchcodec/encoders/_video_encoder.py
+++ b/src/torchcodec/encoders/_video_encoder.py
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Union
+from typing import Optional, Union
 
 import torch
 from torch import Tensor
@@ -35,6 +35,8 @@ def __init__(self, frames: Tensor, *, frame_rate: int):
     def to_file(
         self,
         dest: Union[str, Path],
+        *,
+        pixel_format: Optional[str] = None,
     ) -> None:
         """Encode frames into a file.
 
@@ -42,22 +44,29 @@ def to_file(
             dest (str or ``pathlib.Path``): The path to the output file, e.g.
                 ``video.mp4``. The extension of the file determines the video
                 container format.
+            pixel_format (str, optional): The pixel format for encoding (e.g.,
+                "yuv420p", "yuv444p"). If not specified, uses codec's default format.
         """
         _core.encode_video_to_file(
             frames=self._frames,
             frame_rate=self._frame_rate,
             filename=str(dest),
+            pixel_format=pixel_format,
         )
 
     def to_tensor(
         self,
         format: str,
+        *,
+        pixel_format: Optional[str] = None,
     ) -> Tensor:
         """Encode frames into raw bytes, as a 1D uint8 Tensor.
 
         Args:
             format (str): The container format of the encoded frames, e.g. "mp4", "mov",
                 "mkv", "avi", "webm", "flv", or "gif"
+            pixel_format (str, optional): The pixel format to encode frames into (e.g.,
+                "yuv420p", "yuv444p"). If not specified, uses codec's default format.
 
         Returns:
             Tensor: The raw encoded bytes as 4D uint8 Tensor.
@@ -66,12 +75,15 @@ def to_tensor(
             frames=self._frames,
             frame_rate=self._frame_rate,
             format=format,
+            pixel_format=pixel_format,
         )
 
     def to_file_like(
         self,
         file_like,
         format: str,
+        *,
+        pixel_format: Optional[str] = None,
     ) -> None:
         """Encode frames into a file-like object.
 
@@ -83,10 +95,13 @@ def to_file_like(
                 int = 0) -> int``.
             format (str): The container format of the encoded frames, e.g. "mp4", "mov",
                 "mkv", "avi", "webm", "flv", or "gif".
+            pixel_format (str, optional): The pixel format for encoding (e.g.,
+                "yuv420p", "yuv444p"). If not specified, uses codec's default format.
         """
         _core.encode_video_to_file_like(
             frames=self._frames,
             frame_rate=self._frame_rate,
             format=format,
             file_like=file_like,
+            pixel_format=pixel_format,
         )
diff --git a/test/test_encoders.py b/test/test_encoders.py
index b7223c88a..922b67bbb 100644
--- a/test/test_encoders.py
+++ b/test/test_encoders.py
@@ -629,6 +629,30 @@ def test_bad_input(self, tmp_path):
         ):
             encoder.to_tensor(format="bad_format")
 
+    @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
+    def test_pixel_format_errors(self, method, tmp_path):
+        frames = torch.zeros((5, 3, 64, 64), dtype=torch.uint8)
+        encoder = VideoEncoder(frames, frame_rate=30)
+
+        if method == "to_file":
+            valid_params = dict(dest=str(tmp_path / "output.mp4"))
+        elif method == "to_tensor":
+            valid_params = dict(format="mp4")
+        elif method == "to_file_like":
+            valid_params = dict(file_like=io.BytesIO(), format="mp4")
+
+        with pytest.raises(
+            RuntimeError,
+            match=r"Unknown pixel format: invalid_pix_fmt[\s\S]*Supported pixel formats.*yuv420p",
+        ):
+            getattr(encoder, method)(**valid_params, pixel_format="invalid_pix_fmt")
+
+        with pytest.raises(
+            RuntimeError,
+            match=r"Specified pixel format rgb24 is not supported[\s\S]*Supported pixel formats.*yuv420p",
+        ):
+            getattr(encoder, method)(**valid_params, pixel_format="rgb24")
+
     @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
     def test_contiguity(self, method, tmp_path):
         # Ensure that 2 sets of video frames with the same pixel values are encoded
diff --git a/test/test_ops.py b/test/test_ops.py
index e798a7a2b..bb6ce601b 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1162,21 +1162,14 @@ def decode(self, source=None) -> torch.Tensor:
     def test_video_encoder_round_trip(self, tmp_path, format, method):
         # Test that decode(encode(decode(frames))) == decode(frames)
         ffmpeg_version = get_ffmpeg_major_version()
-        # In FFmpeg6, the default codec's best pixel format is lossy for all container formats but webm.
-        # As a result, we skip the round trip test.
-        if ffmpeg_version == 6 and format != "webm":
-            pytest.skip(
-                f"FFmpeg6 defaults to lossy encoding for {format}, skipping round-trip test."
-            )
         if format == "webm" and (
             ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
         ):
             pytest.skip("Codec for webm is not available in this FFmpeg installation.")
 
         source_frames = self.decode(TEST_SRC_2_720P.path).data
-        params = dict(
-            frame_rate=30, crf=0
-        )  # Frame rate is fixed with num frames decoded
+        # Frame rate is fixed with num frames decoded
+        params = dict(frame_rate=30, pixel_format="yuv444p", crf=0)
         if method == "to_file":
             encoded_path = str(tmp_path / f"encoder_output.{format}")
             encode_video_to_file(
@@ -1212,7 +1205,7 @@ def test_video_encoder_round_trip(self, tmp_path, format, method):
             atol = 15
         else:
             assert_close = torch.testing.assert_close
-            atol = 2
+            atol = 3 if format == "webm" else 2
         for s_frame, rt_frame in zip(source_frames, round_trip_frames):
             assert psnr(s_frame, rt_frame) > 30
             assert_close(s_frame, rt_frame, atol=atol, rtol=0)
@@ -1274,16 +1267,18 @@ def test_against_to_file(self, tmp_path, format, method):
             "avi",
             "mkv",
             "flv",
-            "gif",
             pytest.param("webm", marks=pytest.mark.slow),
         ),
     )
-    def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
+    @pytest.mark.parametrize("pixel_format", ("yuv444p", "yuv420p"))
+    def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format, pixel_format):
         ffmpeg_version = get_ffmpeg_major_version()
         if format == "webm" and (
             ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
         ):
             pytest.skip("Codec for webm is not available in this FFmpeg installation.")
+        if format in ("avi", "flv") and pixel_format == "yuv444p":
+            pytest.skip(f"Default codec for {format} does not support {pixel_format}")
 
         source_frames = self.decode(TEST_SRC_2_720P.path).data
 
@@ -1303,13 +1298,15 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
             "-f",
             "rawvideo",
             "-pix_fmt",
-            "rgb24",
+            "rgb24",  # Input format
             "-s",
             f"{source_frames.shape[3]}x{source_frames.shape[2]}",
             "-r",
             str(frame_rate),
             "-i",
             temp_raw_path,
+            "-pix_fmt",
+            pixel_format,  # Output format
             "-crf",
             str(crf),
             ffmpeg_encoded_path,
@@ -1322,6 +1319,7 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
             frames=source_frames,
             frame_rate=frame_rate,
             filename=encoder_output_path,
+            pixel_format=pixel_format,
             crf=crf,
         )
 
@@ -1362,7 +1360,12 @@ def get_encoded_data(self):
         source_frames = self.decode(TEST_SRC_2_720P.path).data
         file_like = CustomFileObject()
         encode_video_to_file_like(
-            source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like
+            source_frames,
+            frame_rate=30,
+            pixel_format="yuv444p",
+            crf=0,
+            format="mp4",
+            file_like=file_like,
         )
         decoded_samples = self.decode(file_like.get_encoded_data())
 
@@ -1380,7 +1383,12 @@ def test_to_file_like_real_file(self, tmp_path):
 
         with open(file_path, "wb") as file_like:
             encode_video_to_file_like(
-                source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like
+                source_frames,
+                frame_rate=30,
+                pixel_format="yuv444p",
+                crf=0,
+                format="mp4",
+                file_like=file_like,
             )
         decoded_samples = self.decode(str(file_path))
 