@@ -161,6 +161,44 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
         device, nonNegativeDeviceIndex, type);
 #endif
 }
+
+NppStreamContext createNppStreamContext(int deviceIndex) {
+  // From 12.9, NPP recommends using a user-created NppStreamContext and using
+  // the `_Ctx()` calls:
+  // https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#npp-release-12-9-update-1
+  // And the nppGetStreamContext() helper is deprecated. We are explicitly
+  // supposed to create the NppStreamContext manually from the CUDA device
+  // properties:
+  // https://github.com/NVIDIA/CUDALibrarySamples/blob/d97803a40fab83c058bb3d68b6c38bd6eebfff43/NPP/README.md?plain=1#L54-L72
+
+  NppStreamContext nppCtx{};
+  cudaDeviceProp prop{};
+  cudaError_t err = cudaGetDeviceProperties(&prop, deviceIndex);
+  TORCH_CHECK(
+      err == cudaSuccess,
+      "cudaGetDeviceProperties failed: ",
+      cudaGetErrorString(err));
+
+  nppCtx.nCudaDeviceId = deviceIndex;
+  nppCtx.nMultiProcessorCount = prop.multiProcessorCount;
+  nppCtx.nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor;
+  nppCtx.nMaxThreadsPerBlock = prop.maxThreadsPerBlock;
+  nppCtx.nSharedMemPerBlock = prop.sharedMemPerBlock;
+  nppCtx.nCudaDevAttrComputeCapabilityMajor = prop.major;
+  nppCtx.nCudaDevAttrComputeCapabilityMinor = prop.minor;
+
+  // TODO when implementing the cache logic, move these out. See other TODO
+  // below.
+  nppCtx.hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream();
+  err = cudaStreamGetFlags(nppCtx.hStream, &nppCtx.nStreamFlags);
+  TORCH_CHECK(
+      err == cudaSuccess,
+      "cudaStreamGetFlags failed: ",
+      cudaGetErrorString(err));
+
+  return nppCtx;
+}
+
 } // namespace
 
 CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
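A note on the synchronization change in the hunk below (not part of the patch; a hedged reading of it): because createNppStreamContext() sets nppCtx.hStream to PyTorch's current stream, the `_Ctx()` variants launch their kernels directly on that stream, so later PyTorch ops on the output tensor are already stream-ordered and the old CUDAEvent handshake can go away. A minimal sketch of the call pattern, with illustrative placeholder names (inPtrs, inStride, outPtr, outStride, deviceIndex are assumptions, not code from this commit):

// Sketch only, assuming a valid NV12 input and an allocated packed-RGB output.
// The _Ctx variant enqueues on nppCtx.hStream, which createNppStreamContext()
// set to the current PyTorch stream, so no explicit event sync is needed
// before downstream PyTorch ops consume the output.
NppStreamContext nppCtx = createNppStreamContext(deviceIndex);
NppiSize roi = {width, height};
NppStatus status = nppiNV12ToRGB_8u_P2C3R_Ctx(
    inPtrs, // const Npp8u* const[2]: Y plane and interleaved UV plane
    inStride, // line step of the source planes, in bytes
    outPtr, // Npp8u*: packed RGB output
    outStride, // line step of the output, in bytes
    roi,
    nppCtx);
TORCH_CHECK(status == NPP_SUCCESS, "NV12 -> RGB conversion failed.");
// The output is now safe to use from work queued on the same current stream.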
@@ -265,37 +303,37 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
     dst = allocateEmptyHWCTensor(height, width, device_);
   }
 
-  // Use the user-requested GPU for running the NPP kernel.
-  c10::cuda::CUDAGuard deviceGuard(device_);
+  // TODO cache the NppStreamContext! It currently gets re-created for every
+  // single frame. The cache should be per-device, similar to the existing
+  // hw_device_ctx cache. When implementing the cache logic, the
+  // NppStreamContext hStream and nStreamFlags should not be part of the cache
+  // because they may change across calls.
+  NppStreamContext nppCtx = createNppStreamContext(
+      static_cast<int>(getFFMPEGCompatibleDeviceIndex(device_)));
 
   NppiSize oSizeROI = {width, height};
   Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};
 
   NppStatus status;
+
   if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) {
-    status = nppiNV12ToRGB_709CSC_8u_P2C3R(
+    status = nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx(
         input,
         avFrame->linesize[0],
         static_cast<Npp8u*>(dst.data_ptr()),
         dst.stride(0),
-        oSizeROI);
+        oSizeROI,
+        nppCtx);
   } else {
-    status = nppiNV12ToRGB_8u_P2C3R(
+    status = nppiNV12ToRGB_8u_P2C3R_Ctx(
         input,
         avFrame->linesize[0],
         static_cast<Npp8u*>(dst.data_ptr()),
         dst.stride(0),
-        oSizeROI);
+        oSizeROI,
+        nppCtx);
   }
   TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame.");
-
-  // Make the pytorch stream wait for the npp kernel to finish before using the
-  // output.
-  at::cuda::CUDAEvent nppDoneEvent;
-  at::cuda::CUDAStream nppStreamWrapper =
-      c10::cuda::getStreamFromExternal(nppGetStream(), device_.index());
-  nppDoneEvent.record(nppStreamWrapper);
-  nppDoneEvent.block(at::cuda::getCurrentCUDAStream());
 }
 
 // inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9
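A possible shape for the per-device cache the TODOs describe (a sketch under stated assumptions, not part of this commit; getCachedNppStreamContext, g_nppCtxCache, and g_nppCtxCacheMutex are hypothetical names): cache only the device-property part of the context per device, and refresh hStream/nStreamFlags on every lookup, since the current stream can change between calls:

// Hypothetical sketch of the per-device cache; assumes the
// createNppStreamContext() helper from this patch and the usual NPP /
// c10::cuda headers are available in this translation unit.
#include <map>
#include <mutex>

namespace {

std::mutex g_nppCtxCacheMutex;
std::map<int, NppStreamContext> g_nppCtxCache; // keyed by CUDA device index

NppStreamContext getCachedNppStreamContext(int deviceIndex) {
  std::lock_guard<std::mutex> lock(g_nppCtxCacheMutex);
  auto it = g_nppCtxCache.find(deviceIndex);
  if (it == g_nppCtxCache.end()) {
    // Device properties never change for a given device, so the
    // cudaGetDeviceProperties() work only happens once per device.
    it = g_nppCtxCache
             .emplace(deviceIndex, createNppStreamContext(deviceIndex))
             .first;
  }
  // Return a copy with the stream fields refreshed: per the TODOs, hStream
  // and nStreamFlags must not live in the cache because the current stream
  // (and its flags) may differ from call to call.
  NppStreamContext ctx = it->second;
  ctx.hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream();
  cudaError_t err = cudaStreamGetFlags(ctx.hStream, &ctx.nStreamFlags);
  TORCH_CHECK(
      err == cudaSuccess,
      "cudaStreamGetFlags failed: ",
      cudaGetErrorString(err));
  return ctx;
}

} // namespace

With something like this in place, the call site in convertAVFrameToFrameOutput would swap createNppStreamContext(...) for getCachedNppStreamContext(...), and createNppStreamContext() could drop its stream-related lines as the first TODO suggests.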