Skip to content

Commit 1e7dc34

Browse files
author
Daniel Flores
committed
testing
1 parent 1e06ea5 commit 1e7dc34

File tree

2 files changed

+120
-64
lines changed

2 files changed

+120
-64
lines changed

test/test_encoders.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch
1212
from torchcodec.decoders import AudioDecoder
1313

14-
from torchcodec.encoders import AudioEncoder
14+
from torchcodec.encoders import AudioEncoder, VideoEncoder
1515

1616
from .utils import (
1717
assert_tensor_close_on_at_least,
@@ -564,3 +564,117 @@ def write(self, data):
564564
RuntimeError, match="File like object must implement a seek method"
565565
):
566566
encoder.to_file_like(NoSeekMethod(), format="wav")
567+
568+
569+
class TestVideoEncoder:
570+
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
571+
def test_bad_input_parameterized(self, tmp_path, method):
572+
if method == "to_file":
573+
valid_params = dict(dest=str(tmp_path / "output.mp4"))
574+
elif method == "to_tensor":
575+
valid_params = dict(format="mp4")
576+
elif method == "to_file_like":
577+
valid_params = dict(file_like=io.BytesIO(), format="mp4")
578+
else:
579+
raise ValueError(f"Unknown method: {method}")
580+
581+
with pytest.raises(
582+
ValueError, match="Expected uint8 frames, got frames.dtype = torch.float32"
583+
):
584+
encoder = VideoEncoder(
585+
frames=torch.rand(5, 3, 64, 64),
586+
frame_rate=30,
587+
)
588+
getattr(encoder, method)(**valid_params)
589+
590+
with pytest.raises(
591+
ValueError, match=r"Expected 3D or 4D frames, got frames.shape = torch.Size"
592+
):
593+
encoder = VideoEncoder(
594+
frames=torch.zeros(10),
595+
frame_rate=30,
596+
)
597+
getattr(encoder, method)(**valid_params)
598+
599+
with pytest.raises(
600+
RuntimeError, match=r"frame must have 3 channels \(R, G, B\), got 2"
601+
):
602+
encoder = VideoEncoder(
603+
frames=torch.zeros((5, 2, 64, 64), dtype=torch.uint8),
604+
frame_rate=30,
605+
)
606+
getattr(encoder, method)(**valid_params)
607+
608+
def test_bad_input(self, tmp_path):
609+
encoder = VideoEncoder(
610+
frames=torch.zeros((5, 3, 64, 64), dtype=torch.uint8),
611+
frame_rate=30,
612+
)
613+
614+
with pytest.raises(
615+
RuntimeError,
616+
match=r"Couldn't allocate AVFormatContext. The destination file is ./file.bad_extension, check the desired extension\?",
617+
):
618+
encoder.to_file("./file.bad_extension")
619+
620+
with pytest.raises(
621+
RuntimeError,
622+
match=r"avio_open failed. The destination file is ./bad/path.mp3, make sure it's a valid path\?",
623+
):
624+
encoder.to_file("./bad/path.mp3")
625+
626+
with pytest.raises(
627+
RuntimeError,
628+
match=r"Couldn't allocate AVFormatContext. Check the desired format\? Got format=bad_format",
629+
):
630+
encoder.to_tensor(format="bad_format")
631+
632+
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
633+
def test_contiguity(self, method, tmp_path):
634+
# Ensure that 2 sets of video frames with the same pixel values are encoded
635+
# in the same way, regardless of their memory layout. Here we encode 2 equal
636+
# frame tensors, one is contiguous while the other is non-contiguous.
637+
638+
num_frames, channels, height, width = 5, 3, 64, 64
639+
contiguous_frames = (
640+
(torch.rand(num_frames, channels, height, width) * 255)
641+
.to(torch.uint8)
642+
.contiguous()
643+
)
644+
assert contiguous_frames.is_contiguous()
645+
646+
# Create non-contiguous frames by permuting, calling contiguous to update memory layout,
647+
# then permuting back to the initial order
648+
non_contiguous_frames = (
649+
contiguous_frames.permute(0, 3, 2, 1).contiguous().permute(0, 3, 2, 1)
650+
)
651+
assert non_contiguous_frames.stride() != contiguous_frames.stride()
652+
assert not non_contiguous_frames.is_contiguous()
653+
654+
torch.testing.assert_close(
655+
contiguous_frames, non_contiguous_frames, rtol=0, atol=0
656+
)
657+
658+
def encode_to_tensor(frames):
659+
if method == "to_file":
660+
dest = str(tmp_path / "output.mp4")
661+
VideoEncoder(frames, frame_rate=30).to_file(dest=dest)
662+
with open(dest, "rb") as f:
663+
return torch.frombuffer(f.read(), dtype=torch.uint8)
664+
elif method == "to_tensor":
665+
return VideoEncoder(frames, frame_rate=30).to_tensor(format="mp4")
666+
elif method == "to_file_like":
667+
file_like = io.BytesIO()
668+
VideoEncoder(frames, frame_rate=30).to_file_like(
669+
file_like, format="mp4"
670+
)
671+
return torch.frombuffer(file_like.getvalue(), dtype=torch.uint8)
672+
else:
673+
raise ValueError(f"Unknown method: {method}")
674+
675+
encoded_from_contiguous = encode_to_tensor(contiguous_frames)
676+
encoded_from_non_contiguous = encode_to_tensor(non_contiguous_frames)
677+
678+
torch.testing.assert_close(
679+
encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0
680+
)

test/test_ops.py

Lines changed: 5 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,68 +1152,6 @@ def test_bad_input(self, tmp_path):
11521152

11531153

11541154
class TestVideoEncoderOps:
1155-
# TODO-VideoEncoder: Test encoding against different memory layouts (ex. test_contiguity)
1156-
# TODO-VideoEncoder: Parametrize test after moving to test_encoders
1157-
def test_bad_input(self, tmp_path):
1158-
output_file = str(tmp_path / ".mp4")
1159-
1160-
with pytest.raises(
1161-
RuntimeError, match="frames must have uint8 dtype, got float"
1162-
):
1163-
encode_video_to_file(
1164-
frames=torch.rand((10, 3, 60, 60), dtype=torch.float),
1165-
frame_rate=10,
1166-
filename=output_file,
1167-
)
1168-
1169-
with pytest.raises(
1170-
RuntimeError, match=r"frames must have 4 dimensions \(N, C, H, W\), got 3"
1171-
):
1172-
encode_video_to_file(
1173-
frames=torch.randint(high=1, size=(3, 60, 60), dtype=torch.uint8),
1174-
frame_rate=10,
1175-
filename=output_file,
1176-
)
1177-
1178-
with pytest.raises(
1179-
RuntimeError, match=r"frame must have 3 channels \(R, G, B\), got 2"
1180-
):
1181-
encode_video_to_file(
1182-
frames=torch.randint(high=1, size=(10, 2, 60, 60), dtype=torch.uint8),
1183-
frame_rate=10,
1184-
filename=output_file,
1185-
)
1186-
1187-
with pytest.raises(
1188-
RuntimeError,
1189-
match=r"Couldn't allocate AVFormatContext. The destination file is ./file.bad_extension, check the desired extension\?",
1190-
):
1191-
encode_video_to_file(
1192-
frames=torch.randint(high=255, size=(10, 3, 60, 60), dtype=torch.uint8),
1193-
frame_rate=10,
1194-
filename="./file.bad_extension",
1195-
)
1196-
1197-
with pytest.raises(
1198-
RuntimeError,
1199-
match=r"avio_open failed. The destination file is ./bad/path.mp3, make sure it's a valid path\?",
1200-
):
1201-
encode_video_to_file(
1202-
frames=torch.randint(high=255, size=(10, 3, 60, 60), dtype=torch.uint8),
1203-
frame_rate=10,
1204-
filename="./bad/path.mp3",
1205-
)
1206-
1207-
with pytest.raises(
1208-
RuntimeError,
1209-
match=r"Couldn't allocate AVFormatContext. Check the desired format\? Got format=bad_format",
1210-
):
1211-
encode_video_to_tensor(
1212-
frames=torch.randint(high=255, size=(10, 3, 60, 60), dtype=torch.uint8),
1213-
frame_rate=10,
1214-
format="bad_format",
1215-
)
1216-
12171155
def decode(self, source=None) -> torch.Tensor:
12181156
return VideoDecoder(source).get_frames_in_range(start=0, stop=60)
12191157

@@ -1406,7 +1344,9 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
14061344
)
14071345

14081346
def test_to_file_like_custom_file_object(self):
1409-
"""Test with a custom file-like object that implements write and seek."""
1347+
"""Test to_file_like with a custom file-like object that implements write and seek."""
1348+
if get_ffmpeg_major_version() == 6:
1349+
pytest.skip("Skipping round trip test for FFmpeg 6")
14101350

14111351
class CustomFileObject:
14121352
def __init__(self):
@@ -1437,6 +1377,8 @@ def get_encoded_data(self):
14371377

14381378
def test_to_file_like_real_file(self, tmp_path):
14391379
"""Test to_file_like with a real file opened in binary write mode."""
1380+
if get_ffmpeg_major_version() == 6:
1381+
pytest.skip("Skipping round trip test for FFmpeg 6")
14401382
source_frames = self.decode(TEST_SRC_2_720P.path).data
14411383
file_path = tmp_path / "test_file_like.mp4"
14421384

0 commit comments

Comments
 (0)