1111import torch
1212from torchcodec .decoders import AudioDecoder
1313
14+ from torchcodec .decoders ._video_decoder import VideoDecoder
1415from torchcodec .encoders import AudioEncoder , VideoEncoder
1516
1617from .utils import (
18+ TEST_SRC_2_720P ,
1719 assert_tensor_close_on_at_least ,
1820 get_ffmpeg_major_version ,
1921 get_ffmpeg_minor_version ,
@@ -567,6 +569,16 @@ def write(self, data):
567569
568570
569571class TestVideoEncoder :
572+
def decode(self, source=None, *, start=0, stop=60):
    """Decode a range of frames from ``source`` with ``VideoDecoder``.

    Args:
        source: Passed straight to ``VideoDecoder``. Callers in this test
            class hand in a file path, an encoded ``torch.Tensor`` or raw
            encoded bytes.
        start: Forwarded as ``start`` to ``get_frames_in_range``.
        stop: Forwarded as ``stop`` to ``get_frames_in_range``. The default
            of 60 assumes the source has at least that many frames
            (TODO confirm for every asset used with this helper).

    Returns:
        Whatever ``VideoDecoder.get_frames_in_range`` returns. This is the
        decoder's frame-batch object rather than a bare tensor, so callers
        read the decoded pixels from its ``.data`` attribute.
    """
    decoder = VideoDecoder(source)
    return decoder.get_frames_in_range(start=start, stop=stop)
575+
def save_image(self, a, b, name):
    """Write images ``a`` and ``b`` side by side into ``<name>.png``.

    Debugging aid for eyeballing an original frame next to its
    encoded-then-decoded counterpart.

    Args:
        a: First image tensor (left cell of the grid).
        b: Second image tensor (right cell); must have the same shape as
            ``a`` so the two can be stacked.
        name: Output file stem. The PNG is written to ``f"{name}.png"``,
            i.e. relative to the current working directory.
    """
    # Imported lazily so torchvision is only required when an image is
    # actually written.
    from torchvision.io import write_png
    from torchvision.utils import make_grid

    # Move both images to the CPU *before* stacking: torch.stack refuses
    # tensors that live on different devices, and the grid has to end up on
    # the CPU for write_png anyway.
    pair = torch.stack([a.cpu(), b.cpu()])
    grid = make_grid(pair, nrow=2)
    write_png(grid, f"{name}.png")
581+
570582 @pytest .mark .parametrize ("method" , ("to_file" , "to_tensor" , "to_file_like" ))
571583 def test_bad_input_parameterized (self , tmp_path , method ):
572584 if method == "to_file" :
@@ -683,8 +695,11 @@ def encode_to_tensor(frames):
683695 )
684696 def test_device_video_encoder (self , method , device , tmp_path ):
685697 # Test that encoding works on CPU and CUDA devices
686- num_frames , channels , height , width = 5 , 3 , 64 , 64
687- frames = (torch .rand (num_frames , channels , height , width ) * 255 ).to (torch .uint8 )
698+ # num_frames, channels, height, width = 5, 3, 1024, 1024
699+ # frames = (torch.rand(num_frames, channels, height, width) * 255).to(torch.uint8)
700+
701+ asset = TEST_SRC_2_720P
702+ frames = self .decode (asset .path ).data
688703
689704 encoder = VideoEncoder (frames , frame_rate = 30 , device = device )
690705
@@ -693,15 +708,30 @@ def test_device_video_encoder(self, method, device, tmp_path):
693708 encoder .to_file (dest = dest )
694709 # Verify file was created
695710 assert Path (dest ).exists ()
711+ self .save_image (
712+ frames [0 ],
713+ self .decode (Path (dest )).data [0 ],
714+ name = f"{ device } _to_file" ,
715+ )
696716 elif method == "to_tensor" :
697717 encoded = encoder .to_tensor (format = "mp4" )
698718 assert encoded .dtype == torch .uint8
699719 assert encoded .ndim == 1
700720 assert encoded .numel () > 0
721+ self .save_image (
722+ frames [0 ],
723+ self .decode (encoded ).data [0 ],
724+ name = f"{ device } _to_tensor" ,
725+ )
701726 elif method == "to_file_like" :
702727 file_like = io .BytesIO ()
703728 encoder .to_file_like (file_like , format = "mp4" )
704729 encoded_bytes = file_like .getvalue ()
705730 assert len (encoded_bytes ) > 0
731+ self .save_image (
732+ frames [0 ],
733+ self .decode (encoded_bytes ).data [0 ],
734+ name = f"{ device } _to_file_like" ,
735+ )
706736 else :
707737 raise ValueError (f"Unknown method: { method } " )
0 commit comments