@@ -55,6 +55,14 @@ umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
     return UMF_RESULT_ERROR_NOT_SUPPORTED;
 }
 
+umf_result_t umfCUDAMemoryProviderParamsSetAllocFlags(
+    umf_cuda_memory_provider_params_handle_t hParams, unsigned int flags) {
+    (void)hParams;
+    (void)flags;
+    LOG_ERR("CUDA provider is disabled (UMF_BUILD_CUDA_PROVIDER is OFF)!");
+    return UMF_RESULT_ERROR_NOT_SUPPORTED;
+}
+
 umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {
     // not supported
     LOG_ERR("CUDA provider is disabled (UMF_BUILD_CUDA_PROVIDER is OFF)!");
@@ -89,21 +97,30 @@ typedef struct cu_memory_provider_t {
     CUdevice device;
     umf_usm_memory_type_t memory_type;
     size_t min_alignment;
+    unsigned int alloc_flags;
 } cu_memory_provider_t;
 
 // CUDA Memory Provider settings struct
 typedef struct umf_cuda_memory_provider_params_t {
-    void *cuda_context_handle;         ///< Handle to the CUDA context
-    int cuda_device_handle;            ///< Handle to the CUDA device
-    umf_usm_memory_type_t memory_type; ///< Allocation memory type
+    // Handle to the CUDA context
+    void *cuda_context_handle;
+
+    // Handle to the CUDA device
+    int cuda_device_handle;
+
+    // Allocation memory type
+    umf_usm_memory_type_t memory_type;
+
+    // Allocation flags for cuMemHostAlloc/cuMemAllocManaged
+    unsigned int alloc_flags;
 } umf_cuda_memory_provider_params_t;
 
 typedef struct cu_ops_t {
     CUresult (*cuMemGetAllocationGranularity)(
         size_t *granularity, const CUmemAllocationProp *prop,
         CUmemAllocationGranularity_flags option);
     CUresult (*cuMemAlloc)(CUdeviceptr *dptr, size_t bytesize);
-    CUresult (*cuMemAllocHost)(void **pp, size_t bytesize);
+    CUresult (*cuMemHostAlloc)(void **pp, size_t bytesize, unsigned int flags);
     CUresult (*cuMemAllocManaged)(CUdeviceptr *dptr, size_t bytesize,
                                   unsigned int flags);
     CUresult (*cuMemFree)(CUdeviceptr dptr);
@@ -172,8 +189,8 @@ static void init_cu_global_state(void) {
         utils_get_symbol_addr(0, "cuMemGetAllocationGranularity", lib_name);
     *(void **)&g_cu_ops.cuMemAlloc =
         utils_get_symbol_addr(0, "cuMemAlloc_v2", lib_name);
-    *(void **)&g_cu_ops.cuMemAllocHost =
-        utils_get_symbol_addr(0, "cuMemAllocHost_v2", lib_name);
+    *(void **)&g_cu_ops.cuMemHostAlloc =
+        utils_get_symbol_addr(0, "cuMemHostAlloc", lib_name);
     *(void **)&g_cu_ops.cuMemAllocManaged =
         utils_get_symbol_addr(0, "cuMemAllocManaged", lib_name);
     *(void **)&g_cu_ops.cuMemFree =
@@ -196,7 +213,7 @@ static void init_cu_global_state(void) {
         utils_get_symbol_addr(0, "cuIpcCloseMemHandle", lib_name);
 
     if (!g_cu_ops.cuMemGetAllocationGranularity || !g_cu_ops.cuMemAlloc ||
-        !g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemAllocManaged ||
+        !g_cu_ops.cuMemHostAlloc || !g_cu_ops.cuMemAllocManaged ||
         !g_cu_ops.cuMemFree || !g_cu_ops.cuMemFreeHost ||
         !g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString ||
         !g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent ||
@@ -225,6 +242,7 @@ umf_result_t umfCUDAMemoryProviderParamsCreate(
     params_data->cuda_context_handle = NULL;
     params_data->cuda_device_handle = -1;
     params_data->memory_type = UMF_MEMORY_TYPE_UNKNOWN;
+    params_data->alloc_flags = 0;
 
     *hParams = params_data;
 
@@ -275,6 +293,18 @@ umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
     return UMF_RESULT_SUCCESS;
 }
 
+umf_result_t umfCUDAMemoryProviderParamsSetAllocFlags(
+    umf_cuda_memory_provider_params_handle_t hParams, unsigned int flags) {
+    if (!hParams) {
+        LOG_ERR("CUDA Memory Provider params handle is NULL");
+        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
+    }
+
+    hParams->alloc_flags = flags;
+
+    return UMF_RESULT_SUCCESS;
+}
+
 static umf_result_t cu_memory_provider_initialize(void *params,
                                                    void **provider) {
     if (params == NULL) {
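
To show how the new setter is intended to be used, here is a minimal consumer sketch. It is hedged: the `umfMemoryProviderCreate()` and `umfCUDAMemoryProviderParamsDestroy()` calls, the header paths, and the `CU_MEMHOSTALLOC_PORTABLE` flag come from the wider UMF and CUDA driver APIs and are not part of this diff.

```c
// Hypothetical usage sketch (not part of this change): configure a CUDA
// provider whose host allocations are portable pinned memory obtained
// through cuMemHostAlloc().
#include <cuda.h>
#include <umf/memory_provider.h>
#include <umf/providers/provider_cuda.h>

static umf_result_t
make_pinned_host_provider(umf_memory_provider_handle_t *out) {
    umf_cuda_memory_provider_params_handle_t params = NULL;
    umf_result_t res = umfCUDAMemoryProviderParamsCreate(&params);
    if (res != UMF_RESULT_SUCCESS) {
        return res;
    }

    // Setting the CUDA context/device handles is omitted here for brevity.
    umfCUDAMemoryProviderParamsSetMemoryType(params, UMF_MEMORY_TYPE_HOST);
    // Host allocations now go through cuMemHostAlloc(), so its flags apply.
    umfCUDAMemoryProviderParamsSetAllocFlags(params, CU_MEMHOSTALLOC_PORTABLE);

    res = umfMemoryProviderCreate(umfCUDAMemoryProviderOps(), params, out);
    umfCUDAMemoryProviderParamsDestroy(params);
    return res;
}
```
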
@@ -325,6 +355,17 @@ static umf_result_t cu_memory_provider_initialize(void *params,
     cu_provider->memory_type = cu_params->memory_type;
     cu_provider->min_alignment = min_alignment;
 
+    // If the memory type is shared (CUDA managed), the allocation flags must
+    // be set. NOTE: we do not check here if the flags are valid -
+    // this will be done by CUDA runtime.
+    if (cu_params->memory_type == UMF_MEMORY_TYPE_SHARED &&
+        cu_params->alloc_flags == 0) {
+        // the default setting is CU_MEM_ATTACH_GLOBAL
+        cu_provider->alloc_flags = CU_MEM_ATTACH_GLOBAL;
+    } else {
+        cu_provider->alloc_flags = cu_params->alloc_flags;
+    }
+
     *provider = cu_provider;
 
     return UMF_RESULT_SUCCESS;
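
For shared (managed) memory, the same setter chooses the cuMemAllocManaged attach flag, with CU_MEM_ATTACH_GLOBAL as the fallback when it is left at 0, as the initialization code above shows. A small hedged fragment, reusing a params handle obtained as in the previous sketch:

```c
// Hypothetical continuation of the previous sketch: ask for host-attached
// managed memory instead of the CU_MEM_ATTACH_GLOBAL default. The params
// handle is assumed to come from umfCUDAMemoryProviderParamsCreate().
static void request_host_attached_managed(
    umf_cuda_memory_provider_params_handle_t params) {
    umfCUDAMemoryProviderParamsSetMemoryType(params, UMF_MEMORY_TYPE_SHARED);
    umfCUDAMemoryProviderParamsSetAllocFlags(params, CU_MEM_ATTACH_HOST);
}
```
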
@@ -381,7 +422,8 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
     CUresult cu_result = CUDA_SUCCESS;
     switch (cu_provider->memory_type) {
     case UMF_MEMORY_TYPE_HOST: {
-        cu_result = g_cu_ops.cuMemAllocHost(resultPtr, size);
+        cu_result =
+            g_cu_ops.cuMemHostAlloc(resultPtr, size, cu_provider->alloc_flags);
         break;
     }
     case UMF_MEMORY_TYPE_DEVICE: {
@@ -390,7 +432,7 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
     }
     case UMF_MEMORY_TYPE_SHARED: {
         cu_result = g_cu_ops.cuMemAllocManaged((CUdeviceptr *)resultPtr, size,
-                                               CU_MEM_ATTACH_GLOBAL);
+                                               cu_provider->alloc_flags);
         break;
     }
     default:
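
For reference, a stand-alone sketch of the driver-level calls the switch above now issues; the helper name is illustrative and error handling is trimmed, so this is not provider code.

```c
#include <cuda.h>
#include <stddef.h>

// Illustrative helper mirroring the provider's new dispatch: pinned host
// memory via cuMemHostAlloc() with caller-supplied flags, managed memory
// via cuMemAllocManaged() with a configurable attach flag. Assumes a
// current CUDA context has already been set on the calling thread.
static CUresult alloc_with_flags(size_t size, unsigned int flags, int shared,
                                 void **out) {
    if (shared) {
        CUdeviceptr dptr = 0;
        CUresult res = cuMemAllocManaged(&dptr, size, flags);
        *out = (void *)dptr; // managed pointers are valid on host and device
        return res;
    }
    return cuMemHostAlloc(out, size, flags); // replaces cuMemAllocHost()
}
```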