3232 llvm::errs () << " '" << #expr << " ' failed with '" << name << " '\n " ; \
3333 }(expr)
3434
35- // Static initialization of CUDA context for device ordinal 0.
36- static auto InitializeCtx = [] {
35+ // Static reference to CUDA primary context for device ordinal 0.
36+ static CUcontext Context = [] {
3737 CUDA_REPORT_IF_ERROR (cuInit (/* flags=*/ 0 ));
3838 CUdevice device;
3939 CUDA_REPORT_IF_ERROR (cuDeviceGet (&device, /* ordinal=*/ 0 ));
4040 CUcontext context;
41- CUDA_REPORT_IF_ERROR (cuCtxCreate (&context, /* flags= */ 0 , device));
42- return 0 ;
41+ CUDA_REPORT_IF_ERROR (cuDevicePrimaryCtxRetain (&context, device));
42+ return context ;
4343}();
4444
45+ // Sets the `Context` for the duration of the instance and restores the previous
46+ // context on destruction.
47+ class ScopedContext {
48+ public:
49+ ScopedContext () {
50+ CUDA_REPORT_IF_ERROR (cuCtxGetCurrent (&previous));
51+ CUDA_REPORT_IF_ERROR (cuCtxSetCurrent (Context));
52+ }
53+
54+ ~ScopedContext () { CUDA_REPORT_IF_ERROR (cuCtxSetCurrent (previous)); }
55+
56+ private:
57+ CUcontext previous;
58+ };
59+
4560extern " C" CUmodule mgpuModuleLoad (void *data) {
61+ ScopedContext scopedContext;
4662 CUmodule module = nullptr ;
4763 CUDA_REPORT_IF_ERROR (cuModuleLoadData (&module , data));
4864 return module ;
@@ -66,12 +82,14 @@ extern "C" void mgpuLaunchKernel(CUfunction function, intptr_t gridX,
6682 intptr_t blockX, intptr_t blockY,
6783 intptr_t blockZ, int32_t smem, CUstream stream,
6884 void **params, void **extra) {
85+ ScopedContext scopedContext;
6986 CUDA_REPORT_IF_ERROR (cuLaunchKernel (function, gridX, gridY, gridZ, blockX,
7087 blockY, blockZ, smem, stream, params,
7188 extra));
7289}
7390
7491extern " C" CUstream mgpuStreamCreate () {
92+ ScopedContext scopedContext;
7593 CUstream stream = nullptr ;
7694 CUDA_REPORT_IF_ERROR (cuStreamCreate (&stream, CU_STREAM_NON_BLOCKING));
7795 return stream;
@@ -90,6 +108,7 @@ extern "C" void mgpuStreamWaitEvent(CUstream stream, CUevent event) {
90108}
91109
92110extern " C" CUevent mgpuEventCreate () {
111+ ScopedContext scopedContext;
93112 CUevent event = nullptr ;
94113 CUDA_REPORT_IF_ERROR (cuEventCreate (&event, CU_EVENT_DISABLE_TIMING));
95114 return event;
@@ -108,6 +127,7 @@ extern "C" void mgpuEventRecord(CUevent event, CUstream stream) {
108127}
109128
110129extern " C" void *mgpuMemAlloc (uint64_t sizeBytes, CUstream /* stream*/ ) {
130+ ScopedContext scopedContext;
111131 CUdeviceptr ptr;
112132 CUDA_REPORT_IF_ERROR (cuMemAlloc (&ptr, sizeBytes));
113133 return reinterpret_cast <void *>(ptr);
0 commit comments