@@ -53,8 +53,14 @@ typedef struct cu_ops_t {
5353 CUresult (* cuGetErrorString )(CUresult error , const char * * pStr );
5454 CUresult (* cuCtxGetCurrent )(CUcontext * pctx );
5555 CUresult (* cuCtxSetCurrent )(CUcontext ctx );
56+ CUresult (* cuIpcGetMemHandle )(CUipcMemHandle * pHandle , CUdeviceptr dptr );
57+ CUresult (* cuIpcOpenMemHandle )(CUdeviceptr * pdptr , CUipcMemHandle handle ,
58+ unsigned int Flags );
59+ CUresult (* cuIpcCloseMemHandle )(CUdeviceptr dptr );
5660} cu_ops_t ;
5761
62+ typedef CUipcMemHandle cu_ipc_data_t ;
63+
5864static cu_ops_t g_cu_ops ;
5965static UTIL_ONCE_FLAG cu_is_initialized = UTIL_ONCE_FLAG_INIT ;
6066static bool Init_cu_global_state_failed ;
@@ -123,12 +129,20 @@ static void init_cu_global_state(void) {
123129 utils_get_symbol_addr (0 , "cuCtxGetCurrent" , lib_name );
124130 * (void * * )& g_cu_ops .cuCtxSetCurrent =
125131 utils_get_symbol_addr (0 , "cuCtxSetCurrent" , lib_name );
132+ * (void * * )& g_cu_ops .cuIpcGetMemHandle =
133+ utils_get_symbol_addr (0 , "cuIpcGetMemHandle" , lib_name );
134+ * (void * * )& g_cu_ops .cuIpcOpenMemHandle =
135+ utils_get_symbol_addr (0 , "cuIpcOpenMemHandle_v2" , lib_name );
136+ * (void * * )& g_cu_ops .cuIpcCloseMemHandle =
137+ utils_get_symbol_addr (0 , "cuIpcCloseMemHandle" , lib_name );
126138
127139 if (!g_cu_ops .cuMemGetAllocationGranularity || !g_cu_ops .cuMemAlloc ||
128140 !g_cu_ops .cuMemAllocHost || !g_cu_ops .cuMemAllocManaged ||
129141 !g_cu_ops .cuMemFree || !g_cu_ops .cuMemFreeHost ||
130142 !g_cu_ops .cuGetErrorName || !g_cu_ops .cuGetErrorString ||
131- !g_cu_ops .cuCtxGetCurrent || !g_cu_ops .cuCtxSetCurrent ) {
143+ !g_cu_ops .cuCtxGetCurrent || !g_cu_ops .cuCtxSetCurrent ||
144+ !g_cu_ops .cuIpcGetMemHandle || !g_cu_ops .cuIpcOpenMemHandle ||
145+ !g_cu_ops .cuIpcCloseMemHandle ) {
132146 LOG_ERR ("Required CUDA symbols not found." );
133147 Init_cu_global_state_failed = true;
134148 }
@@ -404,6 +418,97 @@ static const char *cu_memory_provider_get_name(void *provider) {
404418 return "CUDA" ;
405419}
406420
421+ static umf_result_t cu_memory_provider_get_ipc_handle_size (void * provider ,
422+ size_t * size ) {
423+ if (provider == NULL || size == NULL ) {
424+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
425+ }
426+
427+ * size = sizeof (cu_ipc_data_t );
428+ return UMF_RESULT_SUCCESS ;
429+ }
430+
431+ static umf_result_t cu_memory_provider_get_ipc_handle (void * provider ,
432+ const void * ptr ,
433+ size_t size ,
434+ void * providerIpcData ) {
435+ (void )size ;
436+
437+ if (provider == NULL || ptr == NULL || providerIpcData == NULL ) {
438+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
439+ }
440+
441+ CUresult cu_result ;
442+ cu_ipc_data_t * cu_ipc_data = (cu_ipc_data_t * )providerIpcData ;
443+
444+ cu_result = g_cu_ops .cuIpcGetMemHandle (cu_ipc_data , (CUdeviceptr )ptr );
445+ if (cu_result != CUDA_SUCCESS ) {
446+ LOG_ERR ("cuIpcGetMemHandle() failed." );
447+ return cu2umf_result (cu_result );
448+ }
449+
450+ return UMF_RESULT_SUCCESS ;
451+ }
452+
453+ static umf_result_t cu_memory_provider_put_ipc_handle (void * provider ,
454+ void * providerIpcData ) {
455+ if (provider == NULL || providerIpcData == NULL ) {
456+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
457+ }
458+
459+ return UMF_RESULT_SUCCESS ;
460+ }
461+
462+ static umf_result_t cu_memory_provider_open_ipc_handle (void * provider ,
463+ void * providerIpcData ,
464+ void * * ptr ) {
465+ if (provider == NULL || ptr == NULL || providerIpcData == NULL ) {
466+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
467+ }
468+
469+ cu_memory_provider_t * cu_provider = (cu_memory_provider_t * )provider ;
470+
471+ CUresult cu_result ;
472+ cu_ipc_data_t * cu_ipc_data = (cu_ipc_data_t * )providerIpcData ;
473+
474+ // Remember current context and set the one from the provider
475+ CUcontext restore_ctx = NULL ;
476+ umf_result_t umf_result = set_context (cu_provider -> context , & restore_ctx );
477+ if (umf_result != UMF_RESULT_SUCCESS ) {
478+ return umf_result ;
479+ }
480+
481+ cu_result = g_cu_ops .cuIpcOpenMemHandle ((CUdeviceptr * )ptr , * cu_ipc_data ,
482+ CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS );
483+
484+ if (cu_result != CUDA_SUCCESS ) {
485+ LOG_ERR ("cuIpcOpenMemHandle() failed." );
486+ }
487+
488+ set_context (restore_ctx , & restore_ctx );
489+
490+ return cu2umf_result (cu_result );
491+ }
492+
493+ static umf_result_t
494+ cu_memory_provider_close_ipc_handle (void * provider , void * ptr , size_t size ) {
495+ (void )size ;
496+
497+ if (provider == NULL || ptr == NULL ) {
498+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
499+ }
500+
501+ CUresult cu_result ;
502+
503+ cu_result = g_cu_ops .cuIpcCloseMemHandle ((CUdeviceptr )ptr );
504+ if (cu_result != CUDA_SUCCESS ) {
505+ LOG_ERR ("cuIpcCloseMemHandle() failed." );
506+ return cu2umf_result (cu_result );
507+ }
508+
509+ return UMF_RESULT_SUCCESS ;
510+ }
511+
407512static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
408513 .version = UMF_VERSION_CURRENT ,
409514 .initialize = cu_memory_provider_initialize ,
@@ -420,12 +525,12 @@ static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
420525 .ext.purge_force = cu_memory_provider_purge_force,
421526 .ext.allocation_merge = cu_memory_provider_allocation_merge,
422527 .ext.allocation_split = cu_memory_provider_allocation_split,
528+ */
423529 .ipc .get_ipc_handle_size = cu_memory_provider_get_ipc_handle_size ,
424530 .ipc .get_ipc_handle = cu_memory_provider_get_ipc_handle ,
425531 .ipc .put_ipc_handle = cu_memory_provider_put_ipc_handle ,
426532 .ipc .open_ipc_handle = cu_memory_provider_open_ipc_handle ,
427533 .ipc .close_ipc_handle = cu_memory_provider_close_ipc_handle ,
428- */
429534};
430535
431536umf_memory_provider_ops_t * umfCUDAMemoryProviderOps (void ) {
0 commit comments