Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions src/torchcodec/_core/Encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
#include "Encoder.h"
#include "torch/types.h"

extern "C" {
#include <libavutil/pixdesc.h>
}

namespace facebook::torchcodec {

namespace {
Expand Down Expand Up @@ -635,15 +639,20 @@ void VideoEncoder::initializeEncoder(
outWidth_ = inWidth_;
outHeight_ = inHeight_;

// TODO-VideoEncoder: Enable other pixel formats
// Let FFmpeg choose best pixel format to minimize loss
outPixelFormat_ = avcodec_find_best_pix_fmt_of_list(
getSupportedPixelFormats(*avCodec), // List of supported formats
AV_PIX_FMT_GBRP, // We reorder input to GBRP currently
0, // No alpha channel
nullptr // Discard conversion loss information
);
TORCH_CHECK(outPixelFormat_ != -1, "Failed to find best pix fmt")
if (videoStreamOptions.pixelFormat.has_value()) {
outPixelFormat_ =
av_get_pix_fmt(videoStreamOptions.pixelFormat.value().c_str());
TORCH_CHECK(
outPixelFormat_ != AV_PIX_FMT_NONE,
"Unknown pixel format: ",
videoStreamOptions.pixelFormat.value());
} else {
const AVPixelFormat* formats = getSupportedPixelFormats(*avCodec);
// If pixel formats are undefined for some reason, try yuv420p
outPixelFormat_ = (formats && formats[0] != AV_PIX_FMT_NONE)
? formats[0]
: AV_PIX_FMT_YUV420P;
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getSupportedPixelFormats is not guaranteed to return any formats. If the user does not specify a format and we find none, I think we should try to use the broadly supported yuv420p, rather than error out.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, this makes sense and that's similar to what we do for the audio encoder when we can't validate:

// Can't really validate anything in this case, best we can do is hope that
// FLTP is supported by the encoder. If not, FFmpeg will raise.
return AV_SAMPLE_FMT_FLTP;


// Configure codec parameters
avCodecContext_->codec_id = avCodec->id;
Expand Down
4 changes: 4 additions & 0 deletions src/torchcodec/_core/StreamOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ struct VideoStreamOptions {
// TODO-VideoEncoder: Consider adding other optional fields here
// (bit rate, gop size, max b frames, preset)
std::optional<int> crf;

// Optional pixel format for video encoding (e.g., "yuv420p", "yuv444p")
// If not specified, uses codec's default format.
std::optional<std::string> pixelFormat;
};

struct AudioStreamOptions {
Expand Down
12 changes: 9 additions & 3 deletions src/torchcodec/_core/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def(
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
m.def(
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str? pixel_format=None, int? crf=None) -> ()");
m.def(
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor");
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? pixel_format=None, int? crf=None) -> Tensor");
m.def(
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, int? crf=None) -> ()");
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? pixel_format=None, int? crf=None) -> ()");
m.def(
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
m.def(
Expand Down Expand Up @@ -603,8 +603,10 @@ void encode_video_to_file(
const at::Tensor& frames,
int64_t frame_rate,
std::string_view file_name,
std::optional<std::string> pixel_format = std::nullopt,
std::optional<int64_t> crf = std::nullopt) {
VideoStreamOptions videoStreamOptions;
videoStreamOptions.pixelFormat = pixel_format;
videoStreamOptions.crf = crf;
VideoEncoder(
frames,
Expand All @@ -618,9 +620,11 @@ at::Tensor encode_video_to_tensor(
const at::Tensor& frames,
int64_t frame_rate,
std::string_view format,
std::optional<std::string> pixel_format = std::nullopt,
std::optional<int64_t> crf = std::nullopt) {
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
VideoStreamOptions videoStreamOptions;
videoStreamOptions.pixelFormat = pixel_format;
videoStreamOptions.crf = crf;
return VideoEncoder(
frames,
Expand All @@ -636,6 +640,7 @@ void _encode_video_to_file_like(
int64_t frame_rate,
std::string_view format,
int64_t file_like_context,
std::optional<std::string> pixel_format = std::nullopt,
std::optional<int64_t> crf = std::nullopt) {
auto fileLikeContext =
reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
Expand All @@ -644,6 +649,7 @@ void _encode_video_to_file_like(
std::unique_ptr<AVIOFileLikeContext> avioContextHolder(fileLikeContext);

VideoStreamOptions videoStreamOptions;
videoStreamOptions.pixelFormat = pixel_format;
videoStreamOptions.crf = crf;

VideoEncoder encoder(
Expand Down
6 changes: 6 additions & 0 deletions src/torchcodec/_core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def encode_video_to_file_like(
format: str,
file_like: Union[io.RawIOBase, io.BufferedIOBase],
crf: Optional[int] = None,
pixel_format: Optional[str] = None,
) -> None:
"""Encode video frames to a file-like object.

Expand All @@ -222,6 +223,7 @@ def encode_video_to_file_like(
format: Video format (e.g., "mp4", "mov", "mkv")
file_like: File-like object that supports write() and seek() methods
crf: Optional constant rate factor for encoding quality
pixel_format: Optional pixel format (e.g., "yuv420p", "yuv444p")
"""
assert _pybind_ops is not None

Expand All @@ -230,6 +232,7 @@ def encode_video_to_file_like(
frame_rate,
format,
_pybind_ops.create_file_like_context(file_like, True), # True means for writing
pixel_format,
crf,
)

Expand Down Expand Up @@ -319,6 +322,7 @@ def encode_video_to_file_abstract(
frame_rate: int,
filename: str,
crf: Optional[int],
pixel_format: Optional[str],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing = None?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps - I'm not sure if they are needed for the @register_fake annotated functions. I'll add them in case.

) -> None:
return

Expand All @@ -329,6 +333,7 @@ def encode_video_to_tensor_abstract(
frame_rate: int,
format: str,
crf: Optional[int],
pixel_format: Optional[str],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing = None?

) -> torch.Tensor:
return torch.empty([], dtype=torch.long)

Expand All @@ -340,6 +345,7 @@ def _encode_video_to_file_like_abstract(
format: str,
file_like_context: int,
crf: Optional[int] = None,
pixel_format: Optional[str] = None,
) -> None:
return

Expand Down
17 changes: 16 additions & 1 deletion src/torchcodec/encoders/_video_encoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Union
from typing import Optional, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -35,29 +35,38 @@ def __init__(self, frames: Tensor, *, frame_rate: int):
def to_file(
self,
dest: Union[str, Path],
*,
pixel_format: Optional[str] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good job on making this a keyword only params 👍

) -> None:
"""Encode frames into a file.

Args:
dest (str or ``pathlib.Path``): The path to the output file, e.g.
``video.mp4``. The extension of the file determines the video
container format.
pixel_format (str, optional): The pixel format for encoding (e.g.,
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
"""
_core.encode_video_to_file(
frames=self._frames,
frame_rate=self._frame_rate,
filename=str(dest),
pixel_format=pixel_format,
)

def to_tensor(
self,
format: str,
*,
pixel_format: Optional[str] = None,
) -> Tensor:
"""Encode frames into raw bytes, as a 1D uint8 Tensor.

Args:
format (str): The container format of the encoded frames, e.g. "mp4", "mov",
"mkv", "avi", "webm", "flv", or "gif"
pixel_format (str, optional): The pixel format to encode frames into (e.g.,
"yuv420p", "yuv444p"). If not specified, uses codec's default format.

Returns:
Tensor: The raw encoded bytes as 4D uint8 Tensor.
Expand All @@ -66,12 +75,15 @@ def to_tensor(
frames=self._frames,
frame_rate=self._frame_rate,
format=format,
pixel_format=pixel_format,
)

def to_file_like(
self,
file_like,
format: str,
*,
pixel_format: Optional[str] = None,
) -> None:
"""Encode frames into a file-like object.

Expand All @@ -83,10 +95,13 @@ def to_file_like(
int = 0) -> int``.
format (str): The container format of the encoded frames, e.g. "mp4", "mov",
"mkv", "avi", "webm", "flv", or "gif".
pixel_format (str, optional): The pixel format for encoding (e.g.,
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
"""
_core.encode_video_to_file_like(
frames=self._frames,
frame_rate=self._frame_rate,
format=format,
file_like=file_like,
pixel_format=pixel_format,
)
38 changes: 23 additions & 15 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,21 +1162,14 @@ def decode(self, source=None) -> torch.Tensor:
def test_video_encoder_round_trip(self, tmp_path, format, method):
# Test that decode(encode(decode(frames))) == decode(frames)
ffmpeg_version = get_ffmpeg_major_version()
# In FFmpeg6, the default codec's best pixel format is lossy for all container formats but webm.
# As a result, we skip the round trip test.
if ffmpeg_version == 6 and format != "webm":
pytest.skip(
f"FFmpeg6 defaults to lossy encoding for {format}, skipping round-trip test."
)
if format == "webm" and (
ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
):
pytest.skip("Codec for webm is not available in this FFmpeg installation.")
source_frames = self.decode(TEST_SRC_2_720P.path).data

params = dict(
frame_rate=30, crf=0
) # Frame rate is fixed with num frames decoded
# Frame rate is fixed with num frames decoded
params = dict(frame_rate=30, pixel_format="yuv444p", crf=0)
if method == "to_file":
encoded_path = str(tmp_path / f"encoder_output.{format}")
encode_video_to_file(
Expand Down Expand Up @@ -1212,7 +1205,7 @@ def test_video_encoder_round_trip(self, tmp_path, format, method):
atol = 15
else:
assert_close = torch.testing.assert_close
atol = 2
atol = 3 if format == "webm" else 2
for s_frame, rt_frame in zip(source_frames, round_trip_frames):
assert psnr(s_frame, rt_frame) > 30
assert_close(s_frame, rt_frame, atol=atol, rtol=0)
Expand Down Expand Up @@ -1274,16 +1267,18 @@ def test_against_to_file(self, tmp_path, format, method):
"avi",
"mkv",
"flv",
"gif",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gif only supports rgb pixel formats, this test is now focused on the more common yuv formats.

pytest.param("webm", marks=pytest.mark.slow),
),
)
def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
@pytest.mark.parametrize("pixel_format", ("yuv444p", "yuv420p"))
def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format, pixel_format):
ffmpeg_version = get_ffmpeg_major_version()
if format == "webm" and (
ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
):
pytest.skip("Codec for webm is not available in this FFmpeg installation.")
if format in ("avi", "flv") and pixel_format == "yuv444p":
pytest.skip(f"Default codec for {format} does not support {pixel_format}")

source_frames = self.decode(TEST_SRC_2_720P.path).data

Expand All @@ -1303,13 +1298,15 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
"-f",
"rawvideo",
"-pix_fmt",
"rgb24",
"rgb24", # Input format
"-s",
f"{source_frames.shape[3]}x{source_frames.shape[2]}",
"-r",
str(frame_rate),
"-i",
temp_raw_path,
"-pix_fmt",
pixel_format, # Output format
"-crf",
str(crf),
ffmpeg_encoded_path,
Expand All @@ -1322,6 +1319,7 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
frames=source_frames,
frame_rate=frame_rate,
filename=encoder_output_path,
pixel_format=pixel_format,
crf=crf,
)

Expand Down Expand Up @@ -1362,7 +1360,12 @@ def get_encoded_data(self):
source_frames = self.decode(TEST_SRC_2_720P.path).data
file_like = CustomFileObject()
encode_video_to_file_like(
source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like
source_frames,
frame_rate=30,
pixel_format="yuv444p",
crf=0,
format="mp4",
file_like=file_like,
)
decoded_samples = self.decode(file_like.get_encoded_data())

Expand All @@ -1380,7 +1383,12 @@ def test_to_file_like_real_file(self, tmp_path):

with open(file_path, "wb") as file_like:
encode_video_to_file_like(
source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like
source_frames,
frame_rate=30,
pixel_format="yuv444p",
crf=0,
format="mp4",
file_like=file_like,
)
decoded_samples = self.decode(str(file_path))

Expand Down
Loading