Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions src/torchcodec/encoders/_video_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def to_file(
dest: Union[str, Path],
*,
pixel_format: Optional[str] = None,
crf: Optional[int] = None,
) -> None:
"""Encode frames into a file.

Expand All @@ -46,27 +47,35 @@ def to_file(
container format.
pixel_format (str, optional): The pixel format for encoding (e.g.,
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
crf (int, optional): Constant Rate Factor for encoding quality. Lower values
mean better quality. Valid range depends on the encoder (commonly 0-51).
Defaults to None (which will use encoder's default).
"""
_core.encode_video_to_file(
frames=self._frames,
frame_rate=self._frame_rate,
filename=str(dest),
pixel_format=pixel_format,
crf=crf,
)

def to_tensor(
self,
format: str,
*,
pixel_format: Optional[str] = None,
crf: Optional[int] = None,
) -> Tensor:
"""Encode frames into raw bytes, as a 1D uint8 Tensor.

Args:
format (str): The container format of the encoded frames, e.g. "mp4", "mov",
"mkv", "avi", "webm", "flv", or "gif"
"mkv", "avi", "webm", "flv", etc.
pixel_format (str, optional): The pixel format to encode frames into (e.g.,
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
crf (int, optional): Constant Rate Factor for encoding quality. Lower values
mean better quality. Valid range depends on the encoder (commonly 0-51).
Defaults to None (which will use encoder's default).

Returns:
Tensor: The raw encoded bytes as 4D uint8 Tensor.
Expand All @@ -76,6 +85,7 @@ def to_tensor(
frame_rate=self._frame_rate,
format=format,
pixel_format=pixel_format,
crf=crf,
)

def to_file_like(
Expand All @@ -84,6 +94,7 @@ def to_file_like(
format: str,
*,
pixel_format: Optional[str] = None,
crf: Optional[int] = None,
) -> None:
"""Encode frames into a file-like object.

Expand All @@ -94,14 +105,18 @@ def to_file_like(
``write(data: bytes) -> int`` and ``seek(offset: int, whence:
int = 0) -> int``.
format (str): The container format of the encoded frames, e.g. "mp4", "mov",
"mkv", "avi", "webm", "flv", or "gif".
"mkv", "avi", "webm", "flv", etc.
pixel_format (str, optional): The pixel format for encoding (e.g.,
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
crf (int, optional): Constant Rate Factor for encoding quality. Lower values
mean better quality. Valid range depends on the encoder (commonly 0-51).
Defaults to None (which will use encoder's default).
"""
_core.encode_video_to_file_like(
frames=self._frames,
frame_rate=self._frame_rate,
format=format,
file_like=file_like,
pixel_format=pixel_format,
crf=crf,
)
237 changes: 236 additions & 1 deletion test/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import pytest
import torch
from torchcodec.decoders import AudioDecoder
from torchcodec.decoders import AudioDecoder, VideoDecoder

from torchcodec.encoders import AudioEncoder, VideoEncoder

Expand All @@ -20,7 +20,9 @@
in_fbcode,
IS_WINDOWS,
NASA_AUDIO_MP3,
psnr,
SINE_MONO_S32,
TEST_SRC_2_720P,
TestContainerFile,
)

Expand Down Expand Up @@ -567,6 +569,9 @@ def write(self, data):


class TestVideoEncoder:
def decode(self, source=None) -> torch.Tensor:
return VideoDecoder(source).get_frames_in_range(start=0, stop=60)

@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
def test_bad_input_parameterized(self, tmp_path, method):
if method == "to_file":
Expand Down Expand Up @@ -700,3 +705,233 @@ def encode_to_tensor(frames):
torch.testing.assert_close(
encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0
)

@pytest.mark.parametrize(
"format", ("mov", "mp4", "mkv", pytest.param("webm", marks=pytest.mark.slow))
)
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
def test_round_trip(self, tmp_path, format, method):
# Test that decode(encode(decode(frames))) == decode(frames)
ffmpeg_version = get_ffmpeg_major_version()
if format == "webm" and (
ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
):
pytest.skip("Codec for webm is not available in this FFmpeg installation.")
source_frames = self.decode(TEST_SRC_2_720P.path).data

# Frame rate is fixed with num frames decoded
encoder = VideoEncoder(frames=source_frames, frame_rate=30)

if method == "to_file":
encoded_path = str(tmp_path / f"encoder_output.{format}")
encoder.to_file(dest=encoded_path, pixel_format="yuv444p", crf=0)
round_trip_frames = self.decode(encoded_path).data
elif method == "to_tensor":
encoded_tensor = encoder.to_tensor(
format=format, pixel_format="yuv444p", crf=0
)
round_trip_frames = self.decode(encoded_tensor).data
elif method == "to_file_like":
file_like = io.BytesIO()
encoder.to_file_like(
file_like=file_like, format=format, pixel_format="yuv444p", crf=0
)
round_trip_frames = self.decode(file_like.getvalue()).data
else:
raise ValueError(f"Unknown method: {method}")

assert source_frames.shape == round_trip_frames.shape
assert source_frames.dtype == round_trip_frames.dtype

for s_frame, rt_frame in zip(source_frames, round_trip_frames):
assert psnr(s_frame, rt_frame) > 30
torch.testing.assert_close(s_frame, rt_frame, atol=2, rtol=0)

@pytest.mark.parametrize(
"format",
(
"mov",
"mp4",
"avi",
"mkv",
"flv",
"gif",
pytest.param("webm", marks=pytest.mark.slow),
),
)
@pytest.mark.parametrize("method", ("to_tensor", "to_file_like"))
def test_against_to_file(self, tmp_path, format, method):
# Test that to_file, to_tensor, and to_file_like produce the same results
ffmpeg_version = get_ffmpeg_major_version()
if format == "webm" and (
ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
):
pytest.skip("Codec for webm is not available in this FFmpeg installation.")

source_frames = self.decode(TEST_SRC_2_720P.path).data
encoder = VideoEncoder(frames=source_frames, frame_rate=30)

encoded_file = tmp_path / f"output.{format}"
encoder.to_file(dest=encoded_file, crf=0)

if method == "to_tensor":
encoded_output = encoder.to_tensor(format=format, crf=0)
else: # to_file_like
file_like = io.BytesIO()
encoder.to_file_like(file_like=file_like, format=format, crf=0)
encoded_output = file_like.getvalue()

torch.testing.assert_close(
self.decode(encoded_file).data,
self.decode(encoded_output).data,
atol=0,
rtol=0,
)

@pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available")
@pytest.mark.parametrize(
"format",
(
"mov",
"mp4",
"avi",
"mkv",
"flv",
pytest.param("webm", marks=pytest.mark.slow),
),
)
@pytest.mark.parametrize("pixel_format", ("yuv444p", "yuv420p"))
def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format, pixel_format):
ffmpeg_version = get_ffmpeg_major_version()
if format == "webm" and (
ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
):
pytest.skip("Codec for webm is not available in this FFmpeg installation.")
if format in ("avi", "flv") and pixel_format == "yuv444p":
pytest.skip(f"Default codec for {format} does not support {pixel_format}")

source_frames = self.decode(TEST_SRC_2_720P.path).data

# Encode with FFmpeg CLI
temp_raw_path = str(tmp_path / "temp_input.raw")
with open(temp_raw_path, "wb") as f:
f.write(source_frames.permute(0, 2, 3, 1).cpu().numpy().tobytes())

ffmpeg_encoded_path = str(tmp_path / f"ffmpeg_output.{format}")
frame_rate = 30
crf = 0
# Some codecs (ex. MPEG4) do not support CRF.
# Flags not supported by the selected codec will be ignored.
ffmpeg_cmd = [
"ffmpeg",
"-y",
"-f",
"rawvideo",
"-pix_fmt",
"rgb24", # Input format
"-s",
f"{source_frames.shape[3]}x{source_frames.shape[2]}",
"-r",
str(frame_rate),
"-i",
temp_raw_path,
"-pix_fmt",
pixel_format, # Output format
"-crf",
str(crf),
ffmpeg_encoded_path,
]
subprocess.run(ffmpeg_cmd, check=True)

# Encode with our video encoder
encoder_output_path = str(tmp_path / f"encoder_output.{format}")
encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate)
encoder.to_file(dest=encoder_output_path, pixel_format=pixel_format, crf=crf)

ffmpeg_frames = self.decode(ffmpeg_encoded_path).data
encoder_frames = self.decode(encoder_output_path).data

assert ffmpeg_frames.shape[0] == encoder_frames.shape[0]

# If FFmpeg selects a codec or pixel format that uses qscale (not crf),
# the VideoEncoder outputs *slightly* different frames.
# There may be additional subtle differences in the encoder.
percentage = 94 if ffmpeg_version == 6 or format == "avi" else 99

# Check that PSNR between both encoded versions is high
for ff_frame, enc_frame in zip(ffmpeg_frames, encoder_frames):
res = psnr(ff_frame, enc_frame)
assert res > 30
assert_tensor_close_on_at_least(
ff_frame, enc_frame, percentage=percentage, atol=2
)

def test_to_file_like_custom_file_object(self):
"""Test to_file_like with a custom file-like object that implements write and seek."""

class CustomFileObject:
def __init__(self):
self._file = io.BytesIO()

def write(self, data):
return self._file.write(data)

def seek(self, offset, whence=0):
return self._file.seek(offset, whence)

def get_encoded_data(self):
return self._file.getvalue()

source_frames = self.decode(TEST_SRC_2_720P.path).data
encoder = VideoEncoder(frames=source_frames, frame_rate=30)

file_like = CustomFileObject()
encoder.to_file_like(file_like, format="mp4", pixel_format="yuv444p", crf=0)
decoded_frames = self.decode(file_like.get_encoded_data())

torch.testing.assert_close(
decoded_frames.data,
source_frames,
atol=2,
rtol=0,
)

def test_to_file_like_real_file(self, tmp_path):
"""Test to_file_like with a real file opened in binary write mode."""
source_frames = self.decode(TEST_SRC_2_720P.path).data
encoder = VideoEncoder(frames=source_frames, frame_rate=30)

file_path = tmp_path / "test_file_like.mp4"

with open(file_path, "wb") as file_like:
encoder.to_file_like(file_like, format="mp4", pixel_format="yuv444p", crf=0)
decoded_frames = self.decode(str(file_path))

torch.testing.assert_close(
decoded_frames.data,
source_frames,
atol=2,
rtol=0,
)

def test_to_file_like_bad_methods(self):
source_frames = self.decode(TEST_SRC_2_720P.path).data
encoder = VideoEncoder(frames=source_frames, frame_rate=30)

class NoWriteMethod:
def seek(self, offset, whence=0):
return 0

with pytest.raises(
RuntimeError, match="File like object must implement a write method"
):
encoder.to_file_like(NoWriteMethod(), format="mp4")

class NoSeekMethod:
def write(self, data):
return len(data)

with pytest.raises(
RuntimeError, match="File like object must implement a seek method"
):
encoder.to_file_like(NoSeekMethod(), format="mp4")
Loading
Loading