1212#include <umf.h>
1313#include <umf/providers/provider_cuda.h>
1414
15+ #include "provider_cuda_internal.h"
16+ #include "utils_load_library.h"
1517#include "utils_log.h"
1618
// Handle to the CUDA shared library. It stays NULL until
// init_cu_global_state() has opened the library and found every
// required symbol; only then is the handle stored here.
static void *cu_lib_handle = NULL;

// Releases the CUDA shared library handle held by this file, if any.
// The handle is reset to NULL afterwards, so a repeated call does nothing.
void fini_cu_global_state(void) {
    if (cu_lib_handle == NULL) {
        // nothing was stored, so there is nothing to close
        return;
    }

    utils_close_library(cu_lib_handle);
    cu_lib_handle = NULL;
}
27+
1728#if defined(UMF_NO_CUDA_PROVIDER )
1829
1930umf_result_t umfCUDAMemoryProviderParamsCreate (
@@ -88,7 +99,6 @@ umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {
8899#include "utils_assert.h"
89100#include "utils_common.h"
90101#include "utils_concurrency.h"
91- #include "utils_load_library.h"
92102#include "utils_log.h"
93103#include "utils_sanitizers.h"
94104
@@ -180,37 +190,45 @@ static void init_cu_global_state(void) {
180190#else
181191 const char * lib_name = "libcuda.so" ;
182192#endif
183- // check if CUDA shared library is already loaded
184- // we pass 0 as a handle to search the global symbol table
193+ // The CUDA shared library should be already loaded by the user
194+ // of the CUDA provider. UMF just wants to reuse it
195+ // and increase the reference count to the CUDA shared library.
196+ void * lib_handle =
197+ utils_open_library (lib_name , UMF_UTIL_OPEN_LIBRARY_NO_LOAD );
198+ if (!lib_handle ) {
199+ LOG_ERR ("Failed to open CUDA shared library" );
200+ Init_cu_global_state_failed = true;
201+ return ;
202+ }
185203
186204 // NOTE: some symbols defined in the lib have _vX postfixes - it is
187205 // important to load the proper version of functions
188- * (void * * )& g_cu_ops .cuMemGetAllocationGranularity =
189- utils_get_symbol_addr ( 0 , "cuMemGetAllocationGranularity" , lib_name );
206+ * (void * * )& g_cu_ops .cuMemGetAllocationGranularity = utils_get_symbol_addr (
207+ lib_handle , "cuMemGetAllocationGranularity" , lib_name );
190208 * (void * * )& g_cu_ops .cuMemAlloc =
191- utils_get_symbol_addr (0 , "cuMemAlloc_v2" , lib_name );
209+ utils_get_symbol_addr (lib_handle , "cuMemAlloc_v2" , lib_name );
192210 * (void * * )& g_cu_ops .cuMemHostAlloc =
193- utils_get_symbol_addr (0 , "cuMemHostAlloc" , lib_name );
211+ utils_get_symbol_addr (lib_handle , "cuMemHostAlloc" , lib_name );
194212 * (void * * )& g_cu_ops .cuMemAllocManaged =
195- utils_get_symbol_addr (0 , "cuMemAllocManaged" , lib_name );
213+ utils_get_symbol_addr (lib_handle , "cuMemAllocManaged" , lib_name );
196214 * (void * * )& g_cu_ops .cuMemFree =
197- utils_get_symbol_addr (0 , "cuMemFree_v2" , lib_name );
215+ utils_get_symbol_addr (lib_handle , "cuMemFree_v2" , lib_name );
198216 * (void * * )& g_cu_ops .cuMemFreeHost =
199- utils_get_symbol_addr (0 , "cuMemFreeHost" , lib_name );
217+ utils_get_symbol_addr (lib_handle , "cuMemFreeHost" , lib_name );
200218 * (void * * )& g_cu_ops .cuGetErrorName =
201- utils_get_symbol_addr (0 , "cuGetErrorName" , lib_name );
219+ utils_get_symbol_addr (lib_handle , "cuGetErrorName" , lib_name );
202220 * (void * * )& g_cu_ops .cuGetErrorString =
203- utils_get_symbol_addr (0 , "cuGetErrorString" , lib_name );
221+ utils_get_symbol_addr (lib_handle , "cuGetErrorString" , lib_name );
204222 * (void * * )& g_cu_ops .cuCtxGetCurrent =
205- utils_get_symbol_addr (0 , "cuCtxGetCurrent" , lib_name );
223+ utils_get_symbol_addr (lib_handle , "cuCtxGetCurrent" , lib_name );
206224 * (void * * )& g_cu_ops .cuCtxSetCurrent =
207- utils_get_symbol_addr (0 , "cuCtxSetCurrent" , lib_name );
225+ utils_get_symbol_addr (lib_handle , "cuCtxSetCurrent" , lib_name );
208226 * (void * * )& g_cu_ops .cuIpcGetMemHandle =
209- utils_get_symbol_addr (0 , "cuIpcGetMemHandle" , lib_name );
227+ utils_get_symbol_addr (lib_handle , "cuIpcGetMemHandle" , lib_name );
210228 * (void * * )& g_cu_ops .cuIpcOpenMemHandle =
211- utils_get_symbol_addr (0 , "cuIpcOpenMemHandle_v2" , lib_name );
229+ utils_get_symbol_addr (lib_handle , "cuIpcOpenMemHandle_v2" , lib_name );
212230 * (void * * )& g_cu_ops .cuIpcCloseMemHandle =
213- utils_get_symbol_addr (0 , "cuIpcCloseMemHandle" , lib_name );
231+ utils_get_symbol_addr (lib_handle , "cuIpcCloseMemHandle" , lib_name );
214232
215233 if (!g_cu_ops .cuMemGetAllocationGranularity || !g_cu_ops .cuMemAlloc ||
216234 !g_cu_ops .cuMemHostAlloc || !g_cu_ops .cuMemAllocManaged ||
@@ -221,7 +239,10 @@ static void init_cu_global_state(void) {
221239 !g_cu_ops .cuIpcCloseMemHandle ) {
222240 LOG_FATAL ("Required CUDA symbols not found." );
223241 Init_cu_global_state_failed = true;
242+ utils_close_library (lib_handle );
243+ return ;
224244 }
245+ cu_lib_handle = lib_handle ;
225246}
226247
227248umf_result_t umfCUDAMemoryProviderParamsCreate (
@@ -327,7 +348,7 @@ static umf_result_t cu_memory_provider_initialize(void *params,
327348 utils_init_once (& cu_is_initialized , init_cu_global_state );
328349 if (Init_cu_global_state_failed ) {
329350 LOG_FATAL ("Loading CUDA symbols failed" );
330- return UMF_RESULT_ERROR_UNKNOWN ;
351+ return UMF_RESULT_ERROR_DEPENDENCY_UNAVAILABLE ;
331352 }
332353
333354 cu_memory_provider_t * cu_provider =
0 commit comments