@@ -199,12 +199,121 @@ void CudaDeviceInterface::initializeContext(AVCodecContext* codecContext) {
199199 return ;
200200}
201201
202+ std::unique_ptr<FiltersContext> CudaDeviceInterface::initializeFiltersContext (
203+ const VideoStreamOptions& videoStreamOptions,
204+ const UniqueAVFrame& avFrame,
205+ const AVRational& timeBase) {
206+ // We need FFmpeg filters to handle those conversion cases which are not
207+ // directly implemented in CUDA or CPU device interface (in case of a
208+ // fallback).
209+ enum AVPixelFormat frameFormat =
210+ static_cast <enum AVPixelFormat>(avFrame->format );
211+
212+ // Input frame is on CPU, we will just pass it to CPU device interface, so
213+ // skipping filters context as CPU device interface will handle everythong for
214+ // us.
215+ if (avFrame->format != AV_PIX_FMT_CUDA) {
216+ return nullptr ;
217+ }
218+
219+ TORCH_CHECK (
220+ avFrame->hw_frames_ctx != nullptr ,
221+ " The AVFrame does not have a hw_frames_ctx. "
222+ " That's unexpected, please report this to the TorchCodec repo." );
223+
224+ auto hwFramesCtx =
225+ reinterpret_cast <AVHWFramesContext*>(avFrame->hw_frames_ctx ->data );
226+ AVPixelFormat actualFormat = hwFramesCtx->sw_format ;
227+
228+ // NV12 conversion is implemented directly with NPP, no need for filters.
229+ if (actualFormat == AV_PIX_FMT_NV12) {
230+ return nullptr ;
231+ }
232+
233+ auto frameDims =
234+ getHeightAndWidthFromOptionsOrAVFrame (videoStreamOptions, avFrame);
235+ int height = frameDims.height ;
236+ int width = frameDims.width ;
237+
238+ AVPixelFormat outputFormat;
239+ std::stringstream filters;
240+
241+ unsigned version_int = avfilter_version ();
242+ if (version_int < AV_VERSION_INT (8 , 0 , 103 )) {
243+ // Color conversion support ('format=' option) was added to scale_cuda from
244+ // n5.0. With the earlier version of ffmpeg we have no choice but use CPU
245+ // filters. See:
246+ // https://github.com/FFmpeg/FFmpeg/commit/62dc5df941f5e196164c151691e4274195523e95
247+ outputFormat = AV_PIX_FMT_RGB24;
248+
249+ filters << " hwdownload,format=" << av_pix_fmt_desc_get (actualFormat)->name ;
250+ filters << " ,scale=" << width << " :" << height;
251+ filters << " :sws_flags=bilinear" ;
252+ } else {
253+ // Actual output color format will be set via filter options
254+ outputFormat = AV_PIX_FMT_CUDA;
255+
256+ filters << " scale_cuda=" << width << " :" << height;
257+ filters << " :format=nv12:interp_algo=bilinear" ;
258+ }
259+
260+ return std::make_unique<FiltersContext>(
261+ avFrame->width ,
262+ avFrame->height ,
263+ frameFormat,
264+ avFrame->sample_aspect_ratio ,
265+ width,
266+ height,
267+ outputFormat,
268+ filters.str (),
269+ timeBase,
270+ av_buffer_ref (avFrame->hw_frames_ctx ));
271+ }
272+
202273void CudaDeviceInterface::convertAVFrameToFrameOutput (
203274 const VideoStreamOptions& videoStreamOptions,
204275 [[maybe_unused]] const AVRational& timeBase,
205- UniqueAVFrame& avFrame ,
276+ UniqueAVFrame& avInputFrame ,
206277 FrameOutput& frameOutput,
207278 std::optional<torch::Tensor> preAllocatedOutputTensor) {
279+ std::unique_ptr<FiltersContext> newFiltersContext =
280+ initializeFiltersContext (videoStreamOptions, avInputFrame, timeBase);
281+ UniqueAVFrame avFilteredFrame;
282+ if (newFiltersContext) {
283+ // We need to compare the current filter context with our previous filter
284+ // context. If they are different, then we need to re-create a filter
285+ // graph. We create a filter graph late so that we don't have to depend
286+ // on the unreliable metadata in the header. And we sometimes re-create
287+ // it because it's possible for frame resolution to change mid-stream.
288+ // Finally, we want to reuse the filter graph as much as possible for
289+ // performance reasons.
290+ if (!filterGraph_ || filtersContext_ != newFiltersContext) {
291+ filterGraph_ =
292+ std::make_unique<FilterGraph>(*newFiltersContext, videoStreamOptions);
293+ filtersContext_ = std::move (newFiltersContext);
294+ }
295+ avFilteredFrame = filterGraph_->convert (avInputFrame);
296+
297+ // If this check fails it means the frame wasn't
298+ // reshaped to its expected dimensions by filtergraph.
299+ TORCH_CHECK (
300+ (avFilteredFrame->width == filtersContext_->outputWidth ) &&
301+ (avFilteredFrame->height == filtersContext_->outputHeight ),
302+ " Expected frame from filter graph of " ,
303+ filtersContext_->outputWidth ,
304+ " x" ,
305+ filtersContext_->outputHeight ,
306+ " , got " ,
307+ avFilteredFrame->width ,
308+ " x" ,
309+ avFilteredFrame->height );
310+ }
311+
312+ UniqueAVFrame& avFrame = (avFilteredFrame) ? avFilteredFrame : avInputFrame;
313+
314+ // The filtered frame might be on CPU if a CPU fallback has happened at the
315+ // filter graph level. For example, that's how we handle color format
316+ // conversion on FFmpeg 4.4, where scale_cuda did not yet support it.
208317 if (avFrame->format != AV_PIX_FMT_CUDA) {
209318 // The frame's format is AV_PIX_FMT_CUDA if and only if its content is on
210319 // the GPU. In this branch, the frame is on the CPU: this is what NVDEC
@@ -232,8 +341,6 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
232341 // Above we checked that the AVFrame was on GPU, but that's not enough, we
233342 // also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits),
234343 // because this is what the NPP color conversion routines expect.
235- // TODO: we should investigate how to can perform color conversion for
236- // non-8bit videos. This is supported on CPU.
237344 TORCH_CHECK (
238345 avFrame->hw_frames_ctx != nullptr ,
239346 " The AVFrame does not have a hw_frames_ctx. "
@@ -242,16 +349,14 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
242349 auto hwFramesCtx =
243350 reinterpret_cast <AVHWFramesContext*>(avFrame->hw_frames_ctx ->data );
244351 AVPixelFormat actualFormat = hwFramesCtx->sw_format ;
352+
245353 TORCH_CHECK (
246354 actualFormat == AV_PIX_FMT_NV12,
247355 " The AVFrame is " ,
248356 (av_get_pix_fmt_name (actualFormat) ? av_get_pix_fmt_name (actualFormat)
249357 : " unknown" ),
250- " , but we expected AV_PIX_FMT_NV12. This typically happens when "
251- " the video isn't 8bit, which is not supported on CUDA at the moment. "
252- " Try using the CPU device instead. "
253- " If the video is 10bit, we are tracking 10bit support in "
254- " https://github.com/pytorch/torchcodec/issues/776" );
358+ " , but we expected AV_PIX_FMT_NV12. "
359+ " That's unexpected, please report this to the TorchCodec repo." );
255360
256361 auto frameDims =
257362 getHeightAndWidthFromOptionsOrAVFrame (videoStreamOptions, avFrame);
0 commit comments