@@ -955,3 +955,126 @@ def test_metadata(self, asset):
955955 )
956956 assert decoder .metadata .sample_rate == asset .sample_rate
957957 assert decoder .metadata .num_channels == asset .num_channels
958+
959+ @pytest .mark .parametrize ("asset" , (NASA_AUDIO , NASA_AUDIO_MP3 ))
960+ def test_error (self , asset ):
961+ decoder = AudioDecoder (asset .path )
962+
963+ with pytest .raises (ValueError , match = "Invalid start seconds" ):
964+ decoder .get_samples_played_in_range (start_seconds = - 1300 )
965+
966+ with pytest .raises (ValueError , match = "Invalid start seconds" ):
967+ decoder .get_samples_played_in_range (start_seconds = 9999 )
968+
969+ with pytest .raises (ValueError , match = "Invalid start seconds" ):
970+ decoder .get_samples_played_in_range (start_seconds = 3 , stop_seconds = 2 )
971+
972+ @pytest .mark .parametrize ("asset" , (NASA_AUDIO , NASA_AUDIO_MP3 ))
973+ @pytest .mark .parametrize ("stop_seconds" , (None , "duration" , 99999999 ))
974+ def test_get_all_samples (self , asset , stop_seconds ):
975+ decoder = AudioDecoder (asset .path )
976+
977+ if stop_seconds == "duration" :
978+ stop_seconds = asset .duration_seconds
979+
980+ samples = decoder .get_samples_played_in_range (
981+ start_seconds = 0 , stop_seconds = stop_seconds
982+ )
983+
984+ reference_frames = asset .get_frame_data_by_range (
985+ start = 0 , stop = asset .get_frame_index (pts_seconds = asset .duration_seconds ) + 1
986+ )
987+
988+ torch .testing .assert_close (samples .data , reference_frames )
989+ assert samples .sample_rate == asset .sample_rate
990+
991+ # TODO there's a bug with NASA_AUDIO_MP3: https://github.com/pytorch/torchcodec/issues/553
992+ expected_pts = (
993+ 0.072
994+ if asset is NASA_AUDIO_MP3
995+ else asset .get_frame_info (idx = 0 ).pts_seconds
996+ )
997+ assert samples .pts_seconds == expected_pts
998+
999+ @pytest .mark .parametrize ("asset" , (NASA_AUDIO , NASA_AUDIO_MP3 ))
1000+ def test_at_frame_boundaries (self , asset ):
1001+ decoder = AudioDecoder (asset .path )
1002+
1003+ start_frame_index , stop_frame_index = 10 , 40
1004+ start_seconds = asset .get_frame_info (start_frame_index ).pts_seconds
1005+ stop_seconds = asset .get_frame_info (stop_frame_index ).pts_seconds
1006+
1007+ samples = decoder .get_samples_played_in_range (
1008+ start_seconds = start_seconds , stop_seconds = stop_seconds
1009+ )
1010+
1011+ reference_frames = asset .get_frame_data_by_range (
1012+ start = start_frame_index , stop = stop_frame_index
1013+ )
1014+
1015+ assert samples .pts_seconds == start_seconds
1016+ num_samples = samples .data .shape [1 ]
1017+ assert (
1018+ num_samples
1019+ == reference_frames .shape [1 ]
1020+ == (stop_seconds - start_seconds ) * decoder .metadata .sample_rate
1021+ )
1022+ torch .testing .assert_close (samples .data , reference_frames )
1023+ assert samples .sample_rate == asset .sample_rate
1024+
1025+ @pytest .mark .parametrize ("asset" , (NASA_AUDIO , NASA_AUDIO_MP3 ))
1026+ def test_not_at_frame_boundaries (self , asset ):
1027+ decoder = AudioDecoder (asset .path )
1028+
1029+ start_frame_index , stop_frame_index = 10 , 40
1030+ start_frame_info = asset .get_frame_info (start_frame_index )
1031+ stop_frame_info = asset .get_frame_info (stop_frame_index )
1032+ start_seconds = start_frame_info .pts_seconds + (
1033+ start_frame_info .duration_seconds / 2
1034+ )
1035+ stop_seconds = stop_frame_info .pts_seconds + (
1036+ stop_frame_info .duration_seconds / 2
1037+ )
1038+ samples = decoder .get_samples_played_in_range (
1039+ start_seconds = start_seconds , stop_seconds = stop_seconds
1040+ )
1041+
1042+ reference_frames = asset .get_frame_data_by_range (
1043+ start = start_frame_index , stop = stop_frame_index + 1
1044+ )
1045+
1046+ assert samples .pts_seconds == start_seconds
1047+ num_samples = samples .data .shape [1 ]
1048+ assert num_samples < reference_frames .shape [1 ]
1049+ assert (
1050+ num_samples == (stop_seconds - start_seconds ) * decoder .metadata .sample_rate
1051+ )
1052+ assert samples .sample_rate == asset .sample_rate
1053+
1054+ @pytest .mark .parametrize ("asset" , (NASA_AUDIO , NASA_AUDIO_MP3 ))
1055+ def test_start_equals_stop (self , asset ):
1056+ decoder = AudioDecoder (asset .path )
1057+ samples = decoder .get_samples_played_in_range (start_seconds = 3 , stop_seconds = 3 )
1058+ assert samples .data .shape == (0 , 0 )
1059+
1060+ def test_frame_start_is_not_zero (self ):
1061+ # For NASA_AUDIO_MP3, the first frame is not at 0, it's at 0.072 [1].
1062+ # So if we request start = 0.05, we shouldn't be truncating anything.
1063+ #
1064+ # [1] well, really it's at 0.138125, not 0.072 (see
1065+ # https://github.com/pytorch/torchcodec/issues/553), but for the purpose
1066+ # of this test it doesn't matter.
1067+
1068+ asset = NASA_AUDIO_MP3
1069+ start_seconds = 0.05 # this is less than the first frame's pts
1070+ stop_frame_index = 10
1071+ stop_seconds = asset .get_frame_info (stop_frame_index ).pts_seconds
1072+
1073+ decoder = AudioDecoder (asset .path )
1074+
1075+ samples = decoder .get_samples_played_in_range (
1076+ start_seconds = start_seconds , stop_seconds = stop_seconds
1077+ )
1078+
1079+ reference_frames = asset .get_frame_data_by_range (start = 0 , stop = stop_frame_index )
1080+ torch .testing .assert_close (samples .data , reference_frames )
0 commit comments