@@ -275,7 +275,32 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
275275 }
276276
277277 torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex (device_);
278- nppCtx_->hStream = at::cuda::getCurrentCUDAStream (deviceIndex).stream ();
278+
279+ // Create a CUDA event and attach it to the AVFrame's CUDA stream. That's the
280+ // NVDEC stream, i.e. the CUDA stream that the frame was decoded on.
281+ // We will be waiting for this event to complete before calling the NPP
282+ // functions, to ensure NVDEC has finished decoding the frame before running
283+ // the NPP color-conversion.
284+ // Note that our code is generic and assumes that the NVDEC's stream can be
285+ // arbitrary, but unfortunately we know it's hardcoded to be the default
286+ // stream by FFmpeg:
287+ // https://github.com/FFmpeg/FFmpeg/blob/66e40840d15b514f275ce3ce2a4bf72ec68c7311/libavutil/hwcontext_cuda.c#L387-L388
288+ TORCH_CHECK (
289+ hwFramesCtx->device_ctx != nullptr ,
290+ " The AVFrame's hw_frames_ctx does not have a device_ctx. " );
291+ auto cudaDeviceCtx =
292+ static_cast <AVCUDADeviceContext*>(hwFramesCtx->device_ctx ->hwctx );
293+ at::cuda::CUDAEvent nvdecDoneEvent;
294+ at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad.
295+ c10::cuda::getStreamFromExternal (cudaDeviceCtx->stream , deviceIndex);
296+ nvdecDoneEvent.record (nvdecStream);
297+
298+ // Don't start NPP work before NVDEC is done decoding the frame!
299+ at::cuda::CUDAStream nppStream = at::cuda::getCurrentCUDAStream (deviceIndex);
300+ nvdecDoneEvent.block (nppStream);
301+
302+ // Create the NPP context if we haven't yet.
303+ nppCtx_->hStream = nppStream.stream ();
279304 cudaError_t err =
280305 cudaStreamGetFlags (nppCtx_->hStream , &nppCtx_->nStreamFlags );
281306 TORCH_CHECK (
0 commit comments