@@ -18,6 +18,7 @@ struct libcu_ops {
1818 CUresult (*cuCtxCreate)(CUcontext *pctx, unsigned int flags, CUdevice dev);
1919 CUresult (*cuCtxDestroy)(CUcontext ctx);
2020 CUresult (*cuCtxGetCurrent)(CUcontext *pctx);
21+ CUresult (*cuCtxSetCurrent)(CUcontext ctx);
2122 CUresult (*cuDeviceGet)(CUdevice *device, int ordinal);
2223 CUresult (*cuMemAlloc)(CUdeviceptr *dptr, size_t size);
2324 CUresult (*cuMemFree)(CUdeviceptr dptr);
@@ -34,6 +35,7 @@ struct libcu_ops {
3435 CUpointer_attribute *attributes,
3536 void **data, CUdeviceptr ptr);
3637 CUresult (*cuStreamSynchronize)(CUstream hStream);
38+ CUresult (*cuCtxSynchronize)(void );
3739} libcu_ops;
3840
3941#if USE_DLOPEN
@@ -48,7 +50,7 @@ struct DlHandleCloser {
4850std::unique_ptr<void , DlHandleCloser> cuDlHandle = nullptr ;
4951int InitCUDAOps () {
5052#ifdef _WIN32
51- const char *lib_name = " cudart .dll" ;
53+ const char *lib_name = " nvcuda .dll" ;
5254#else
5355 const char *lib_name = " libcuda.so" ;
5456#endif
@@ -84,6 +86,12 @@ int InitCUDAOps() {
8486 fprintf (stderr, " cuCtxGetCurrent symbol not found in %s\n " , lib_name);
8587 return -1 ;
8688 }
89+ *(void **)&libcu_ops.cuCtxSetCurrent =
90+ utils_get_symbol_addr (cuDlHandle.get (), " cuCtxSetCurrent" , lib_name);
91+ if (libcu_ops.cuCtxSetCurrent == nullptr ) {
92+ fprintf (stderr, " cuCtxSetCurrent symbol not found in %s\n " , lib_name);
93+ return -1 ;
94+ }
8795 *(void **)&libcu_ops.cuDeviceGet =
8896 utils_get_symbol_addr (cuDlHandle.get (), " cuDeviceGet" , lib_name);
8997 if (libcu_ops.cuDeviceGet == nullptr ) {
@@ -153,6 +161,12 @@ int InitCUDAOps() {
153161 lib_name);
154162 return -1 ;
155163 }
164+ *(void **)&libcu_ops.cuCtxSynchronize =
165+ utils_get_symbol_addr (cuDlHandle.get (), " cuCtxSynchronize" , lib_name);
166+ if (libcu_ops.cuCtxSynchronize == nullptr ) {
167+ fprintf (stderr, " cuCtxSynchronize symbol not found in %s\n " , lib_name);
168+ return -1 ;
169+ }
156170
157171 return 0 ;
158172}
@@ -165,6 +179,7 @@ int InitCUDAOps() {
165179 libcu_ops.cuCtxCreate = cuCtxCreate;
166180 libcu_ops.cuCtxDestroy = cuCtxDestroy;
167181 libcu_ops.cuCtxGetCurrent = cuCtxGetCurrent;
182+ libcu_ops.cuCtxSetCurrent = cuCtxSetCurrent;
168183 libcu_ops.cuDeviceGet = cuDeviceGet;
169184 libcu_ops.cuMemAlloc = cuMemAlloc;
170185 libcu_ops.cuMemAllocHost = cuMemAllocHost;
@@ -176,11 +191,31 @@ int InitCUDAOps() {
176191 libcu_ops.cuPointerGetAttribute = cuPointerGetAttribute;
177192 libcu_ops.cuPointerGetAttributes = cuPointerGetAttributes;
178193 libcu_ops.cuStreamSynchronize = cuStreamSynchronize;
194+ libcu_ops.cuCtxSynchronize = cuCtxSynchronize;
179195
180196 return 0 ;
181197}
182198#endif // USE_DLOPEN
183199
200+ static CUresult set_context (CUcontext required_ctx, CUcontext *restore_ctx) {
201+ CUcontext current_ctx = NULL ;
202+ CUresult cu_result = libcu_ops.cuCtxGetCurrent (¤t_ctx);
203+ if (cu_result != CUDA_SUCCESS) {
204+ fprintf (stderr, " cuCtxGetCurrent() failed.\n " );
205+ return cu_result;
206+ }
207+
208+ *restore_ctx = current_ctx;
209+ if (current_ctx != required_ctx) {
210+ cu_result = libcu_ops.cuCtxSetCurrent (required_ctx);
211+ if (cu_result != CUDA_SUCCESS) {
212+ fprintf (stderr, " cuCtxSetCurrent() failed.\n " );
213+ }
214+ }
215+
216+ return cu_result;
217+ }
218+
184219static int init_cuda_lib (void ) {
185220 CUresult result = libcu_ops.cuInit (0 );
186221 if (result != CUDA_SUCCESS) {
@@ -191,8 +226,6 @@ static int init_cuda_lib(void) {
191226
192227int cuda_fill (CUcontext context, CUdevice device, void *ptr, size_t size,
193228 const void *pattern, size_t pattern_size) {
194-
195- (void )context;
196229 (void )device;
197230 (void )pattern_size;
198231
@@ -202,23 +235,40 @@ int cuda_fill(CUcontext context, CUdevice device, void *ptr, size_t size,
202235 return -1 ;
203236 }
204237
238+ // set required context
239+ CUcontext curr_context = nullptr ;
240+ set_context (context, &curr_context);
241+
205242 int ret = 0 ;
206243 CUresult res =
207244 libcu_ops.cuMemsetD32 ((CUdeviceptr)ptr, *(unsigned int *)pattern,
208245 size / sizeof (unsigned int ));
209246 if (res != CUDA_SUCCESS) {
210- fprintf (stderr, " cuMemsetD32() failed!\n " );
247+ fprintf (stderr, " cuMemsetD32(%llu, %u, %zu) failed!\n " ,
248+ (CUdeviceptr)ptr, *(unsigned int *)pattern,
249+ size / pattern_size);
250+ return -1 ;
251+ }
252+
253+ res = libcu_ops.cuCtxSynchronize ();
254+ if (res != CUDA_SUCCESS) {
255+ fprintf (stderr, " cuCtxSynchronize() failed!\n " );
211256 return -1 ;
212257 }
213258
259+ // restore context
260+ set_context (curr_context, &curr_context);
214261 return ret;
215262}
216263
217- int cuda_copy (CUcontext context, CUdevice device, void *dst_ptr, void *src_ptr,
218- size_t size) {
219- (void )context;
264+ int cuda_copy (CUcontext context, CUdevice device, void *dst_ptr,
265+ const void *src_ptr, size_t size) {
220266 (void )device;
221267
268+ // set required context
269+ CUcontext curr_context = nullptr ;
270+ set_context (context, &curr_context);
271+
222272 int ret = 0 ;
223273 CUresult res =
224274 libcu_ops.cuMemcpy ((CUdeviceptr)dst_ptr, (CUdeviceptr)src_ptr, size);
@@ -227,12 +277,14 @@ int cuda_copy(CUcontext context, CUdevice device, void *dst_ptr, void *src_ptr,
227277 return -1 ;
228278 }
229279
230- res = libcu_ops.cuStreamSynchronize ( 0 );
280+ res = libcu_ops.cuCtxSynchronize ( );
231281 if (res != CUDA_SUCCESS) {
232- fprintf (stderr, " cuStreamSynchronize () failed!\n " );
282+ fprintf (stderr, " cuCtxSynchronize () failed!\n " );
233283 return -1 ;
234284 }
235285
286+ // restore context
287+ set_context (curr_context, &curr_context);
236288 return ret;
237289}
238290
0 commit comments