@@ -1197,22 +1197,52 @@ def test_pts_to_dts_fallback(self, seek_mode):
11971197 torch .testing .assert_close (decoder [0 ], decoder [10 ])
11981198
11991199 @needs_cuda
1200- @pytest .mark .parametrize ("asset" , (H264_10BITS , H265_10BITS ))
1201- def test_10bit_videos_cuda (self , asset ):
1200+ def test_10bit_videos_cuda (self ):
12021201 # Assert that we raise proper error on different kinds of 10bit videos.
12031202
12041203 # TODO we should investigate how to support 10bit videos on GPU.
12051204 # See https://github.com/pytorch/torchcodec/issues/776
12061205
1207- decoder = VideoDecoder ( asset . path , device = "cuda" )
1206+ asset = H265_10BITS
12081207
1209- if asset is H265_10BITS :
1210- match = "The AVFrame is p010le, but we expected AV_PIX_FMT_NV12."
1211- else :
1212- match = "Expected format to be AV_PIX_FMT_CUDA, got yuv420p10le."
1213- with pytest . raises ( RuntimeError , match = match ):
1208+ decoder = VideoDecoder ( asset . path , device = "cuda" )
1209+ with pytest . raises (
1210+ RuntimeError ,
1211+ match = "The AVFrame is p010le, but we expected AV_PIX_FMT_NV12." ,
1212+ ):
12141213 decoder .get_frame_at (0 )
12151214
1215+ @needs_cuda
1216+ def test_10bit_gpu_fallsback_to_cpu (self ):
1217+ # Test for 10-bit videos that aren't supported by NVDEC: we decode and
1218+ # do the color conversion on the CPU.
1219+ # Here we just assert that the GPU results are the same as the CPU
1220+ # results.
1221+ # TODO see other TODO below in test_10bit_videos_cpu: we should validate
1222+ # the frames against a reference.
1223+
1224+ # We know from previous tests that the H264_10BITS video isn't supported
1225+ # by NVDEC, so decoding falls back to the CPU.
1226+ asset = H264_10BITS
1227+
1228+ decoder_gpu = VideoDecoder (asset .path , device = "cuda" )
1229+ decoder_cpu = VideoDecoder (asset .path )
1230+
1231+ frame_indices = [0 , 10 , 20 , 5 ]
1232+ for frame_index in frame_indices :
1233+ frame_gpu = decoder_gpu .get_frame_at (frame_index ).data
1234+ assert frame_gpu .device .type == "cuda"
1235+ frame_cpu = decoder_cpu .get_frame_at (frame_index ).data
1236+ assert_frames_equal (frame_gpu .cpu (), frame_cpu )
1237+
1238+ # We also check a batch API just to be on the safe side, making sure the
1239+ # pre-allocated tensor is passed down correctly to the CPU
1240+ # implementation.
1241+ frames_gpu = decoder_gpu .get_frames_at (frame_indices ).data
1242+ assert frames_gpu .device .type == "cuda"
1243+ frames_cpu = decoder_cpu .get_frames_at (frame_indices ).data
1244+ assert_frames_equal (frames_gpu .cpu (), frames_cpu )
1245+
12161246 @pytest .mark .parametrize ("asset" , (H264_10BITS , H265_10BITS ))
12171247 def test_10bit_videos_cpu (self , asset ):
12181248 # This just validates that we can decode 10-bit videos on CPU.
0 commit comments