 #include <torch/types.h>
 #include <mutex>
 
+#include "src/torchcodec/_core/Cache.h"
 #include "src/torchcodec/_core/CudaDeviceInterface.h"
 #include "src/torchcodec/_core/FFMPEGCommon.h"
 
@@ -44,49 +45,11 @@ const int MAX_CUDA_GPUS = 128;
 // Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching.
 // Set to a positive number to have a cache of that size.
 const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
-std::vector<AVBufferRef*> g_cached_hw_device_ctxs[MAX_CUDA_GPUS];
-std::mutex g_cached_hw_device_mutexes[MAX_CUDA_GPUS];
-
-torch::DeviceIndex getFFMPEGCompatibleDeviceIndex(const torch::Device& device) {
-  torch::DeviceIndex deviceIndex = device.index();
-  deviceIndex = std::max<at::DeviceIndex>(deviceIndex, 0);
-  TORCH_CHECK(deviceIndex >= 0, "Device index out of range");
-  // FFMPEG cannot handle negative device indices.
-  // For single GPU machines libtorch returns -1 for the device index. So for
-  // that case we set the device index to 0.
-  // TODO: Double check if this works for multi-GPU machines correctly.
-  return deviceIndex;
-}
-
-void addToCacheIfCacheHasCapacity(
-    const torch::Device& device,
-    AVBufferRef* hwContext) {
-  torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device);
-  if (static_cast<int>(deviceIndex) >= MAX_CUDA_GPUS) {
-    return;
-  }
-  std::scoped_lock lock(g_cached_hw_device_mutexes[deviceIndex]);
-  if (MAX_CONTEXTS_PER_GPU_IN_CACHE >= 0 &&
-      g_cached_hw_device_ctxs[deviceIndex].size() >=
-          MAX_CONTEXTS_PER_GPU_IN_CACHE) {
-    return;
-  }
-  g_cached_hw_device_ctxs[deviceIndex].push_back(av_buffer_ref(hwContext));
-}
-
-AVBufferRef* getFromCache(const torch::Device& device) {
-  torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device);
-  if (static_cast<int>(deviceIndex) >= MAX_CUDA_GPUS) {
-    return nullptr;
-  }
-  std::scoped_lock lock(g_cached_hw_device_mutexes[deviceIndex]);
-  if (g_cached_hw_device_ctxs[deviceIndex].size() > 0) {
-    AVBufferRef* hw_device_ctx = g_cached_hw_device_ctxs[deviceIndex].back();
-    g_cached_hw_device_ctxs[deviceIndex].pop_back();
-    return hw_device_ctx;
-  }
-  return nullptr;
-}
+PerGpuCache<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>
+    g_cached_hw_device_ctxs(MAX_CUDA_GPUS, MAX_CONTEXTS_PER_GPU_IN_CACHE);
+PerGpuCache<NppStreamContext> g_cached_npp_ctxs(
+    MAX_CUDA_GPUS,
+    MAX_CONTEXTS_PER_GPU_IN_CACHE);
 
 #if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
 
@@ -143,14 +106,13 @@ AVBufferRef* getFFMPEGContextFromNewCudaContext(
 
 #endif
 
-AVBufferRef* getCudaContext(const torch::Device& device) {
+UniqueAVBufferRef getCudaContext(const torch::Device& device) {
   enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
   TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
-  torch::DeviceIndex nonNegativeDeviceIndex =
-      getFFMPEGCompatibleDeviceIndex(device);
+  torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);
 
-  AVBufferRef* hw_device_ctx = getFromCache(device);
-  if (hw_device_ctx != nullptr) {
+  UniqueAVBufferRef hw_device_ctx = g_cached_hw_device_ctxs.get(device);
+  if (hw_device_ctx) {
     return hw_device_ctx;
   }
@@ -161,15 +123,23 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
   // 58.26.100 of avutil.
   // https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265
 #if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
-  return getFFMPEGContextFromExistingCudaContext(
-      device, nonNegativeDeviceIndex, type);
+  return UniqueAVBufferRef(getFFMPEGContextFromExistingCudaContext(
+      device, nonNegativeDeviceIndex, type));
 #else
-  return getFFMPEGContextFromNewCudaContext(
-      device, nonNegativeDeviceIndex, type);
+  return UniqueAVBufferRef(
+      getFFMPEGContextFromNewCudaContext(device, nonNegativeDeviceIndex, type));
 #endif
 }
 
-NppStreamContext createNppStreamContext(int deviceIndex) {
+std::unique_ptr<NppStreamContext> getNppStreamContext(
+    const torch::Device& device) {
+  torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);
+
+  std::unique_ptr<NppStreamContext> nppCtx = g_cached_npp_ctxs.get(device);
+  if (nppCtx) {
+    return nppCtx;
+  }
+
   // From 12.9, NPP recommends using a user-created NppStreamContext and using
   // the `_Ctx()` calls:
   // https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#npp-release-12-9-update-1
@@ -178,30 +148,21 @@ NppStreamContext createNppStreamContext(int deviceIndex) {
   // properties:
   // https://github.com/NVIDIA/CUDALibrarySamples/blob/d97803a40fab83c058bb3d68b6c38bd6eebfff43/NPP/README.md?plain=1#L54-L72
 
-  NppStreamContext nppCtx{};
+  nppCtx = std::make_unique<NppStreamContext>();
   cudaDeviceProp prop{};
-  cudaError_t err = cudaGetDeviceProperties(&prop, deviceIndex);
+  cudaError_t err = cudaGetDeviceProperties(&prop, nonNegativeDeviceIndex);
   TORCH_CHECK(
       err == cudaSuccess,
       "cudaGetDeviceProperties failed: ",
       cudaGetErrorString(err));
 
-  nppCtx.nCudaDeviceId = deviceIndex;
-  nppCtx.nMultiProcessorCount = prop.multiProcessorCount;
-  nppCtx.nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor;
-  nppCtx.nMaxThreadsPerBlock = prop.maxThreadsPerBlock;
-  nppCtx.nSharedMemPerBlock = prop.sharedMemPerBlock;
-  nppCtx.nCudaDevAttrComputeCapabilityMajor = prop.major;
-  nppCtx.nCudaDevAttrComputeCapabilityMinor = prop.minor;
-
-  // TODO when implementing the cache logic, move these out. See other TODO
-  // below.
-  nppCtx.hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream();
-  err = cudaStreamGetFlags(nppCtx.hStream, &nppCtx.nStreamFlags);
-  TORCH_CHECK(
-      err == cudaSuccess,
-      "cudaStreamGetFlags failed: ",
-      cudaGetErrorString(err));
+  nppCtx->nCudaDeviceId = nonNegativeDeviceIndex;
+  nppCtx->nMultiProcessorCount = prop.multiProcessorCount;
+  nppCtx->nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor;
+  nppCtx->nMaxThreadsPerBlock = prop.maxThreadsPerBlock;
+  nppCtx->nSharedMemPerBlock = prop.sharedMemPerBlock;
+  nppCtx->nCudaDevAttrComputeCapabilityMajor = prop.major;
+  nppCtx->nCudaDevAttrComputeCapabilityMinor = prop.minor;
 
   return nppCtx;
 }
@@ -217,8 +178,10 @@ CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
 
 CudaDeviceInterface::~CudaDeviceInterface() {
   if (ctx_) {
-    addToCacheIfCacheHasCapacity(device_, ctx_);
-    av_buffer_unref(&ctx_);
+    g_cached_hw_device_ctxs.addIfCacheHasCapacity(device_, std::move(ctx_));
+  }
+  if (nppCtx_) {
+    g_cached_npp_ctxs.addIfCacheHasCapacity(device_, std::move(nppCtx_));
   }
 }
 
@@ -231,7 +194,8 @@ void CudaDeviceInterface::initializeContext(AVCodecContext* codecContext) {
   torch::Tensor dummyTensorForCudaInitialization = torch::empty(
       {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
   ctx_ = getCudaContext(device_);
-  codecContext->hw_device_ctx = av_buffer_ref(ctx_);
+  nppCtx_ = getNppStreamContext(device_);
+  codecContext->hw_device_ctx = av_buffer_ref(ctx_.get());
   return;
 }
 
@@ -310,13 +274,14 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
     dst = allocateEmptyHWCTensor(height, width, device_);
   }
 
-  // TODO cache the NppStreamContext! It currently gets re-recated for every
-  // single frame. The cache should be per-device, similar to the existing
-  // hw_device_ctx cache. When implementing the cache logic, the
-  // NppStreamContext hStream and nStreamFlags should not be part of the cache
-  // because they may change across calls.
-  NppStreamContext nppCtx = createNppStreamContext(
-      static_cast<int>(getFFMPEGCompatibleDeviceIndex(device_)));
+  torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device_);
+  nppCtx_->hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream();
+  cudaError_t err =
+      cudaStreamGetFlags(nppCtx_->hStream, &nppCtx_->nStreamFlags);
+  TORCH_CHECK(
+      err == cudaSuccess,
+      "cudaStreamGetFlags failed: ",
+      cudaGetErrorString(err));
 
   NppiSize oSizeROI = {width, height};
   Npp8u* yuvData[2] = {avFrame->data[0], avFrame->data[1]};
@@ -342,7 +307,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
       dst.stride(0),
       oSizeROI,
       bt709FullRangeColorTwist,
-      nppCtx);
+      *nppCtx_);
     } else {
       // If not full range, we assume studio limited range.
       // The color conversion matrix for BT.709 limited range should be:
@@ -359,7 +324,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
       static_cast<Npp8u*>(dst.data_ptr()),
       dst.stride(0),
       oSizeROI,
-      nppCtx);
+      *nppCtx_);
     }
   } else {
     // TODO we're assuming BT.601 color space (and probably limited range) by
@@ -371,7 +336,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
         static_cast<Npp8u*>(dst.data_ptr()),
         dst.stride(0),
         oSizeROI,
-        nppCtx);
+        *nppCtx_);
   }
   TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame.");
 }
0 commit comments