@@ -199,12 +199,127 @@ void CudaDeviceInterface::initializeContext(AVCodecContext* codecContext) {
199199 return;
200200}
201201
202+ std::unique_ptr<FiltersContext> CudaDeviceInterface::initializeFiltersContext(
203+ const VideoStreamOptions& videoStreamOptions,
204+ const UniqueAVFrame& avFrame,
205+ const AVRational& timeBase) {
206+ // We need FFmpeg filters to handle the conversion cases that are not
207+ // directly implemented in the CUDA device interface, or in the CPU device
208+ // interface in case of a fallback.
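// There are three possible outcomes below: the frame is already on the CPU
// (return nullptr and let the CPU device interface handle it), the frame is
// on the GPU in NV12 (return nullptr, NPP handles color conversion directly),
// or a FiltersContext describing the required filter graph is returned.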
209+ enum AVPixelFormat frameFormat =
210+ static_cast<enum AVPixelFormat>(avFrame->format);
211+
212+ // The input frame is on the CPU: we just pass it to the CPU device
213+ // interface, so we skip the filters context as the CPU device interface
214+ // will handle everything for us.
215+ if (avFrame->format != AV_PIX_FMT_CUDA) {
216+ return nullptr;
217+ }
218+
219+ TORCH_CHECK(
220+ avFrame->hw_frames_ctx != nullptr,
221+ "The AVFrame does not have a hw_frames_ctx. "
222+ "That's unexpected, please report this to the TorchCodec repo.");
223+
224+ auto hwFramesCtx =
225+ reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
226+ AVPixelFormat actualFormat = hwFramesCtx->sw_format;
227+
228+ // NV12 conversion is implemented directly with NPP, no need for filters.
229+ if (actualFormat == AV_PIX_FMT_NV12) {
230+ return nullptr;
231+ }
232+
233+ auto frameDims =
234+ getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
235+ int height = frameDims.height;
236+ int width = frameDims.width;
237+
238+ AVPixelFormat outputFormat;
239+ std::stringstream filters;
240+
241+ unsigned version_int = avfilter_version();
242+ if (version_int < AV_VERSION_INT(8, 0, 103)) {
243+ // Color conversion support (the 'format=' option) was added to scale_cuda
244+ // in FFmpeg n5.0. With earlier versions of FFmpeg we have no choice but to
245+ // use CPU filters. See:
246+ // https://github.com/FFmpeg/FFmpeg/commit/62dc5df941f5e196164c151691e4274195523e95
247+ outputFormat = AV_PIX_FMT_RGB24;
248+
249+ auto actualFormatName = av_get_pix_fmt_name(actualFormat);
250+ TORCH_CHECK(
251+ actualFormatName != nullptr,
252+ "The actual format of the frame is unknown to FFmpeg. "
253+ "That's unexpected, please report this to the TorchCodec repo.");
254+
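// The filter spec built below looks like, e.g. (illustrative values for a
// p010le source scaled to 1920x1080):
//   hwdownload,format=p010le,scale=1920:1080:sws_flags=bilinear
// i.e. download the frame to system memory, then scale and convert on the CPU.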
255+ filters << "hwdownload,format=" << actualFormatName;
256+ filters << ",scale=" << width << ":" << height;
257+ filters << ":sws_flags=bilinear";
258+ } else {
259+ // The actual output color format will be set via the filter options.
260+ outputFormat = AV_PIX_FMT_CUDA;
261+
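// The filter spec built below looks like, e.g. (illustrative values for a
// 1920x1080 output): scale_cuda=1920:1080:format=nv12:interp_algo=bilinear,
// i.e. scaling and conversion to NV12 stay on the GPU, and NPP then handles
// the NV12 to RGB conversion as in the non-filtered path.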
262+ filters << "scale_cuda=" << width << ":" << height;
263+ filters << ":format=nv12:interp_algo=bilinear";
264+ }
265+
266+ return std::make_unique<FiltersContext>(
267+ avFrame->width,
268+ avFrame->height,
269+ frameFormat,
270+ avFrame->sample_aspect_ratio,
271+ width,
272+ height,
273+ outputFormat,
274+ filters.str(),
275+ timeBase,
276+ av_buffer_ref(avFrame->hw_frames_ctx));
277+ }
278+
202279void CudaDeviceInterface::convertAVFrameToFrameOutput(
203280 const VideoStreamOptions& videoStreamOptions,
204281 [[maybe_unused]] const AVRational& timeBase,
205- UniqueAVFrame& avFrame,
282+ UniqueAVFrame& avInputFrame,
206283 FrameOutput& frameOutput,
207284 std::optional<torch::Tensor> preAllocatedOutputTensor) {
285+ std::unique_ptr<FiltersContext> newFiltersContext =
286+ initializeFiltersContext(videoStreamOptions, avInputFrame, timeBase);
287+ UniqueAVFrame avFilteredFrame;
288+ if (newFiltersContext) {
289+ // We need to compare the current filter context with our previous filter
290+ // context. If they are different, then we need to re-create a filter
291+ // graph. We create a filter graph late so that we don't have to depend
292+ // on the unreliable metadata in the header. And we sometimes re-create
293+ // it because it's possible for frame resolution to change mid-stream.
294+ // Finally, we want to reuse the filter graph as much as possible for
295+ // performance reasons.
296+ if (!filterGraph_ || *filtersContext_ != *newFiltersContext) {
297+ filterGraph_ =
298+ std::make_unique<FilterGraph>(*newFiltersContext, videoStreamOptions);
299+ filtersContext_ = std::move(newFiltersContext);
300+ }
301+ avFilteredFrame = filterGraph_->convert(avInputFrame);
302+
303+ // If this check fails, it means the frame wasn't reshaped to its
304+ // expected dimensions by the filter graph.
305+ TORCH_CHECK(
306+ (avFilteredFrame->width == filtersContext_->outputWidth) &&
307+ (avFilteredFrame->height == filtersContext_->outputHeight),
308+ "Expected frame from filter graph of ",
309+ filtersContext_->outputWidth,
310+ "x",
311+ filtersContext_->outputHeight,
312+ ", got ",
313+ avFilteredFrame->width,
314+ "x",
315+ avFilteredFrame->height);
316+ }
317+
318+ UniqueAVFrame& avFrame = (avFilteredFrame) ? avFilteredFrame : avInputFrame;
319+
320+ // The filtered frame might be on the CPU if a CPU fallback happened at the
321+ // filter graph level. For example, that's how we handle color format
322+ // conversion on FFmpeg 4.4, where scale_cuda did not support it yet.
208323 if (avFrame->format != AV_PIX_FMT_CUDA) {
209324 // The frame's format is AV_PIX_FMT_CUDA if and only if its content is on
210325 // the GPU. In this branch, the frame is on the CPU: this is what NVDEC
@@ -232,8 +347,6 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
232347 // Above we checked that the AVFrame was on GPU, but that's not enough, we
233348 // also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits),
234349 // because this is what the NPP color conversion routines expect.
235- // TODO: we should investigate how to can perform color conversion for
236- // non-8bit videos. This is supported on CPU.
237350 TORCH_CHECK(
238351 avFrame->hw_frames_ctx != nullptr,
239352 "The AVFrame does not have a hw_frames_ctx. "
@@ -242,16 +355,14 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
242355 auto hwFramesCtx =
243356 reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
244357 AVPixelFormat actualFormat = hwFramesCtx->sw_format;
358+
245359 TORCH_CHECK(
246360 actualFormat == AV_PIX_FMT_NV12,
247361 "The AVFrame is ",
248362 (av_get_pix_fmt_name(actualFormat) ? av_get_pix_fmt_name(actualFormat)
249363 : "unknown"),
250- ", but we expected AV_PIX_FMT_NV12. This typically happens when "
251- "the video isn't 8bit, which is not supported on CUDA at the moment. "
252- "Try using the CPU device instead. "
253- "If the video is 10bit, we are tracking 10bit support in "
254- "https://github.com/pytorch/torchcodec/issues/776");
364+ ", but we expected AV_PIX_FMT_NV12. "
365+ "That's unexpected, please report this to the TorchCodec repo.");
255366
256367 auto frameDims =
257368 getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);