1+ import json
2+ import os
13import re
24import subprocess
5+ from pathlib import Path
36
47import pytest
58import torch
1619)
1720
1821
@pytest.fixture
def with_ffmpeg_debug_logs():
    """Run the requesting test with ``TORCHCODEC_FFMPEG_LOG_LEVEL`` set to ``DEBUG``.

    The value found in the environment before the test is written back once
    the test is over. When the variable was not set at all, ``"QUIET"`` is
    written back instead.
    """
    env_key = "TORCHCODEC_FFMPEG_LOG_LEVEL"
    saved_level = os.environ.get(env_key, "QUIET")

    os.environ[env_key] = "DEBUG"
    yield
    # Teardown: put the remembered (or fallback) level back in place.
    os.environ[env_key] = saved_level
29+
30+
def _ffprobe_audio_frames(path):
    # Return the list of per-frame property dicts that `ffprobe` reports for
    # the first audio stream ("a:0") of the file at `path`.
    completed = subprocess.run(
        [
            "ffprobe",
            "-v",
            "error",
            "-hide_banner",
            "-select_streams",
            "a:0",
            "-show_frames",
            "-of",
            "json",
            str(path),
        ],
        check=True,
        capture_output=True,
        text=True,
    )
    return json.loads(completed.stdout)["frames"]


def validate_frames_properties(*, actual: Path, expected: Path):
    """Assert that two encoded audio files expose the same frame properties.

    `ffprobe` is called on both files and the properties it reports for each
    frame (pts, duration, etc.) are compared frame by frame.

    Args:
        actual: file containing the encoded audio data under test.
        expected: file containing the reference encoded audio data.

    Raises:
        AssertionError: if the number of frames differs, if a required
            property is absent from a reference frame, or if any property of
            a reference frame (other than ``pkt_pos``) is missing from or
            different in the matching actual frame.
        subprocess.CalledProcessError: if `ffprobe` exits with a non-zero
            status.
    """
    frames_actual = _ffprobe_audio_frames(actual)
    frames_expected = _ffprobe_audio_frames(expected)

    # Both are a list of dicts, each dict corresponds to a frame and each
    # key-value pair corresponds to a frame property like pts, nb_samples,
    # etc., similar to the AVFrame fields.
    for frames in (frames_actual, frames_expected):
        assert isinstance(frames, list)
        assert all(isinstance(d, dict) for d in frames)

    assert len(frames_actual) > 3  # arbitrary sanity check
    assert len(frames_actual) == len(frames_expected)

    # non-exhaustive list of the props we want to test for:
    required_props = (
        "pts",
        "pts_time",
        "sample_fmt",
        "nb_samples",
        "channels",
        "duration",
        "duration_time",
    )

    # The FFmpeg version cannot change while we iterate, so query it once.
    # NOTE(review): presumably older ffprobe versions do not report all of
    # the required props, hence the version gate - confirm.
    check_required_props = get_ffmpeg_major_version() >= 6

    for frame_index, (d_actual, d_expected) in enumerate(
        zip(frames_actual, frames_expected)
    ):
        if check_required_props:
            assert all(required_prop in d_expected for required_prop in required_props)

        for prop, expected_value in d_expected.items():
            if prop == "pkt_pos":
                # pkt_pos is the position of the packet *in bytes* in its
                # stream. We don't always match FFmpeg exactly on this,
                # typically on compressed formats like mp3. It's probably
                # because we are not writing the exact same headers, or
                # something like this. In any case, this doesn't seem to be
                # critical.
                continue
            # A property that is absent from the actual frame is reported
            # through the same descriptive message as a mismatching one,
            # rather than surfacing as a bare KeyError.
            assert (
                prop in d_actual and d_actual[prop] == expected_value
            ), f"\nComparing: {actual}\nagainst reference: {expected},\nthe {prop} property is different at frame {frame_index}:"
97+
98+
1999class TestAudioEncoder :
20100
def decode(self, source):
    """Decode ``source`` and return the result of ``get_all_samples()``.

    A ``TestContainerFile`` is first replaced by the string form of its
    ``path``; any other ``source`` is handed to ``AudioDecoder`` unchanged.

    The whole samples object is returned, not a bare tensor: callers in this
    file read ``.data`` for the decoded samples, and also ``.pts_seconds``,
    ``.duration_seconds`` and ``.sample_rate``. For that reason there is no
    ``torch.Tensor`` return annotation.
    """
    if isinstance(source, TestContainerFile):
        source = str(source.path)
    return AudioDecoder(source).get_all_samples()
25105
26106 def test_bad_input (self ):
27107 with pytest .raises (ValueError , match = "Expected samples to be a Tensor" ):
@@ -63,12 +143,12 @@ def test_bad_input_parametrized(self, method, tmp_path):
63143 else dict (format = "mp3" )
64144 )
65145
66- decoder = AudioEncoder (self .decode (NASA_AUDIO_MP3 ), sample_rate = 10 )
146+ decoder = AudioEncoder (self .decode (NASA_AUDIO_MP3 ). data , sample_rate = 10 )
67147 with pytest .raises (RuntimeError , match = "invalid sample rate=10" ):
68148 getattr (decoder , method )(** valid_params )
69149
70150 decoder = AudioEncoder (
71- self .decode (NASA_AUDIO_MP3 ), sample_rate = NASA_AUDIO_MP3 .sample_rate
151+ self .decode (NASA_AUDIO_MP3 ). data , sample_rate = NASA_AUDIO_MP3 .sample_rate
72152 )
73153 with pytest .raises (RuntimeError , match = "bit_rate=-1 must be >= 0" ):
74154 getattr (decoder , method )(** valid_params , bit_rate = - 1 )
@@ -81,7 +161,7 @@ def test_bad_input_parametrized(self, method, tmp_path):
81161 getattr (decoder , method )(** valid_params )
82162
83163 decoder = AudioEncoder (
84- self .decode (NASA_AUDIO_MP3 ), sample_rate = NASA_AUDIO_MP3 .sample_rate
164+ self .decode (NASA_AUDIO_MP3 ). data , sample_rate = NASA_AUDIO_MP3 .sample_rate
85165 )
86166 for num_channels in (0 , 3 ):
87167 with pytest .raises (
@@ -101,7 +181,7 @@ def test_round_trip(self, method, format, tmp_path):
101181 pytest .skip ("Swresample with FFmpeg 4 doesn't work on wav files" )
102182
103183 asset = NASA_AUDIO_MP3
104- source_samples = self .decode (asset )
184+ source_samples = self .decode (asset ). data
105185
106186 encoder = AudioEncoder (source_samples , sample_rate = asset .sample_rate )
107187
@@ -116,7 +196,7 @@ def test_round_trip(self, method, format, tmp_path):
116196
117197 rtol , atol = (0 , 1e-4 ) if format == "wav" else (None , None )
118198 torch .testing .assert_close (
119- self .decode (encoded_source ), source_samples , rtol = rtol , atol = atol
199+ self .decode (encoded_source ). data , source_samples , rtol = rtol , atol = atol
120200 )
121201
122202 @pytest .mark .skipif (in_fbcode (), reason = "TODO: enable ffmpeg CLI" )
@@ -125,7 +205,17 @@ def test_round_trip(self, method, format, tmp_path):
125205 @pytest .mark .parametrize ("num_channels" , (None , 1 , 2 ))
126206 @pytest .mark .parametrize ("format" , ("mp3" , "wav" , "flac" ))
127207 @pytest .mark .parametrize ("method" , ("to_file" , "to_tensor" ))
128- def test_against_cli (self , asset , bit_rate , num_channels , format , method , tmp_path ):
208+ def test_against_cli (
209+ self ,
210+ asset ,
211+ bit_rate ,
212+ num_channels ,
213+ format ,
214+ method ,
215+ tmp_path ,
216+ capfd ,
217+ with_ffmpeg_debug_logs ,
218+ ):
129219 # Encodes samples with our encoder and with the FFmpeg CLI, and checks
130220 # that both decoded outputs are equal
131221
@@ -144,14 +234,25 @@ def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_pa
144234 check = True ,
145235 )
146236
147- encoder = AudioEncoder (self .decode (asset ), sample_rate = asset .sample_rate )
237+ encoder = AudioEncoder (self .decode (asset ).data , sample_rate = asset .sample_rate )
238+
148239 params = dict (bit_rate = bit_rate , num_channels = num_channels )
149240 if method == "to_file" :
150241 encoded_by_us = tmp_path / f"output.{ format } "
151242 encoder .to_file (dest = str (encoded_by_us ), ** params )
152243 else :
153244 encoded_by_us = encoder .to_tensor (format = format , ** params )
154245
246+ captured = capfd .readouterr ()
247+ if format == "wav" :
248+ assert "Timestamps are unset in a packet" not in captured .err
249+ if format == "mp3" :
250+ assert "Queue input is backward in time" not in captured .err
251+ if format in ("flac" , "wav" ):
252+ assert "Encoder did not produce proper pts" not in captured .err
253+ if format in ("flac" , "mp3" ):
254+ assert "Application provided invalid" not in captured .err
255+
155256 if format == "wav" :
156257 rtol , atol = 0 , 1e-4
157258 elif format == "mp3" and asset is SINE_MONO_S32 and num_channels == 2 :
@@ -162,12 +263,22 @@ def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_pa
162263 rtol , atol = 0 , 1e-3
163264 else :
164265 rtol , atol = None , None
266+ samples_by_us = self .decode (encoded_by_us )
267+ samples_by_ffmpeg = self .decode (encoded_by_ffmpeg )
165268 torch .testing .assert_close (
166- self . decode ( encoded_by_ffmpeg ) ,
167- self . decode ( encoded_by_us ) ,
269+ samples_by_us . data ,
270+ samples_by_ffmpeg . data ,
168271 rtol = rtol ,
169272 atol = atol ,
170273 )
274+ assert samples_by_us .pts_seconds == samples_by_ffmpeg .pts_seconds
275+ assert samples_by_us .duration_seconds == samples_by_ffmpeg .duration_seconds
276+ assert samples_by_us .sample_rate == samples_by_ffmpeg .sample_rate
277+
278+ if method == "to_file" :
279+ validate_frames_properties (actual = encoded_by_us , expected = encoded_by_ffmpeg )
280+ else :
281+ assert method == "to_tensor" , "wrong test parametrization!"
171282
172283 @pytest .mark .parametrize ("asset" , (NASA_AUDIO_MP3 , SINE_MONO_S32 ))
173284 @pytest .mark .parametrize ("bit_rate" , (None , 0 , 44_100 , 999_999_999 ))
@@ -179,7 +290,7 @@ def test_to_tensor_against_to_file(
179290 if get_ffmpeg_major_version () == 4 and format == "wav" :
180291 pytest .skip ("Swresample with FFmpeg 4 doesn't work on wav files" )
181292
182- encoder = AudioEncoder (self .decode (asset ), sample_rate = asset .sample_rate )
293+ encoder = AudioEncoder (self .decode (asset ). data , sample_rate = asset .sample_rate )
183294
184295 params = dict (bit_rate = bit_rate , num_channels = num_channels )
185296 encoded_file = tmp_path / f"output.{ format } "
@@ -189,7 +300,7 @@ def test_to_tensor_against_to_file(
189300 )
190301
191302 torch .testing .assert_close (
192- self .decode (encoded_file ), self .decode (encoded_tensor )
303+ self .decode (encoded_file ). data , self .decode (encoded_tensor ). data
193304 )
194305
195306 def test_encode_to_tensor_long_output (self ):
@@ -205,7 +316,7 @@ def test_encode_to_tensor_long_output(self):
205316 INITIAL_TENSOR_SIZE = 10_000_000
206317 assert encoded_tensor .numel () > INITIAL_TENSOR_SIZE
207318
208- torch .testing .assert_close (self .decode (encoded_tensor ), samples )
319+ torch .testing .assert_close (self .decode (encoded_tensor ). data , samples )
209320
210321 def test_contiguity (self ):
211322 # Ensure that 2 waveforms with the same values are encoded in the same
@@ -262,4 +373,4 @@ def test_num_channels(
262373
263374 if num_channels_output is None :
264375 num_channels_output = num_channels_input
265- assert self .decode (encoded_source ).shape [0 ] == num_channels_output
376+ assert self .decode (encoded_source ).data . shape [0 ] == num_channels_output
0 commit comments