@@ -51,6 +51,8 @@ typedef struct cu_ops_t {
5151
5252 CUresult (* cuGetErrorName )(CUresult error , const char * * pStr );
5353 CUresult (* cuGetErrorString )(CUresult error , const char * * pStr );
54+ CUresult (* cuCtxGetCurrent )(CUcontext * pctx );
55+ CUresult (* cuCtxSetCurrent )(CUcontext ctx );
5456} cu_ops_t ;
5557
5658static cu_ops_t g_cu_ops ;
@@ -117,11 +119,16 @@ static void init_cu_global_state(void) {
117119 utils_get_symbol_addr (0 , "cuGetErrorName" , lib_name );
118120 * (void * * )& g_cu_ops .cuGetErrorString =
119121 utils_get_symbol_addr (0 , "cuGetErrorString" , lib_name );
122+ * (void * * )& g_cu_ops .cuCtxGetCurrent =
123+ utils_get_symbol_addr (0 , "cuCtxGetCurrent" , lib_name );
124+ * (void * * )& g_cu_ops .cuCtxSetCurrent =
125+ utils_get_symbol_addr (0 , "cuCtxSetCurrent" , lib_name );
120126
121127 if (!g_cu_ops .cuMemGetAllocationGranularity || !g_cu_ops .cuMemAlloc ||
122128 !g_cu_ops .cuMemAllocHost || !g_cu_ops .cuMemAllocManaged ||
123129 !g_cu_ops .cuMemFree || !g_cu_ops .cuMemFreeHost ||
124- !g_cu_ops .cuGetErrorName || !g_cu_ops .cuGetErrorString ) {
130+ !g_cu_ops .cuGetErrorName || !g_cu_ops .cuGetErrorString ||
131+ !g_cu_ops .cuCtxGetCurrent || !g_cu_ops .cuCtxSetCurrent ) {
125132 LOG_ERR ("Required CUDA symbols not found." );
126133 Init_cu_global_state_failed = true;
127134 }
@@ -190,6 +197,31 @@ static void cu_memory_provider_finalize(void *provider) {
190197 umf_ba_global_free (provider );
191198}
192199
200+ /*
201+ * This function is used by the CUDA provider to make sure that
202+ * the required context is set. If the current context is
203+ * not the required one, it will be saved in restore_ctx.
204+ */
205+ static inline umf_result_t set_context (CUcontext required_ctx ,
206+ CUcontext * restore_ctx ) {
207+ CUcontext current_ctx = NULL ;
208+ CUresult cu_result = g_cu_ops .cuCtxGetCurrent (& current_ctx );
209+ if (cu_result != CUDA_SUCCESS ) {
210+ LOG_ERR ("cuCtxGetCurrent() failed." );
211+ return cu2umf_result (cu_result );
212+ }
213+ * restore_ctx = current_ctx ;
214+ if (current_ctx != required_ctx ) {
215+ cu_result = g_cu_ops .cuCtxSetCurrent (required_ctx );
216+ if (cu_result != CUDA_SUCCESS ) {
217+ LOG_ERR ("cuCtxSetCurrent() failed." );
218+ return cu2umf_result (cu_result );
219+ }
220+ }
221+
222+ return UMF_RESULT_SUCCESS ;
223+ }
224+
193225static umf_result_t cu_memory_provider_alloc (void * provider , size_t size ,
194226 size_t alignment ,
195227 void * * resultPtr ) {
@@ -205,6 +237,14 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
205237 return UMF_RESULT_ERROR_NOT_SUPPORTED ;
206238 }
207239
240+ // Remember current context and set the one from the provider
241+ CUcontext restore_ctx = NULL ;
242+ umf_result_t umf_result = set_context (cu_provider -> context , & restore_ctx );
243+ if (umf_result != UMF_RESULT_SUCCESS ) {
244+ LOG_ERR ("Failed to set CUDA context, ret = %d" , umf_result );
245+ return umf_result ;
246+ }
247+
208248 CUresult cu_result = CUDA_SUCCESS ;
209249 switch (cu_provider -> memory_type ) {
210250 case UMF_MEMORY_TYPE_HOST : {
@@ -224,17 +264,29 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
224264 // this shouldn't happen as we check the memory_type settings during
225265 // the initialization
226266 LOG_ERR ("unsupported USM memory type" );
267+ assert (false);
227268 return UMF_RESULT_ERROR_UNKNOWN ;
228269 }
229270
271+ umf_result = set_context (restore_ctx , & restore_ctx );
272+ if (umf_result != UMF_RESULT_SUCCESS ) {
273+ LOG_ERR ("Failed to restore CUDA context, ret = %d" , umf_result );
274+ }
275+
276+ umf_result = cu2umf_result (cu_result );
277+ if (umf_result != UMF_RESULT_SUCCESS ) {
278+ LOG_ERR ("Failed to allocate memory, cu_result = %d, ret = %d" ,
279+ cu_result , umf_result );
280+ return umf_result ;
281+ }
282+
230283 // check the alignment
231284 if (alignment > 0 && ((uintptr_t )(* resultPtr ) % alignment ) != 0 ) {
232285 cu_memory_provider_free (provider , * resultPtr , size );
233286 LOG_ERR ("unsupported alignment size" );
234287 return UMF_RESULT_ERROR_INVALID_ALIGNMENT ;
235288 }
236-
237- return cu2umf_result (cu_result );
289+ return umf_result ;
238290}
239291
240292static umf_result_t cu_memory_provider_free (void * provider , void * ptr ,
0 commit comments