#include <stdio.h>
#include <string.h>
#include <functional>
#include <clBLAS.h>
23+ #include " mutex.h"
2324#include " AutoGemmIncludes/AutoGemmKernelSelection.h"
2425#include " GemmSpecialCases.h"
2526
2627 #include < functor.h>
2728// #include <functor_selector.h>
2829#include " xgemm.h"
2930
31+ #ifdef _WIN32
32+ // #include <thread>
33+ #else
34+ #include < pthread.h>
35+ #endif
36+
3037/* *****************************************************************************
3138 * Row major -> column major
3239 *****************************************************************************/
@@ -120,22 +127,54 @@ static char *getKernelName(cl_kernel clKernel)
120127 return kernelName;
121128}
122129
130+ typedef struct kernel_map_key_ {
131+ cl_context context; // address of context
132+ cl_device_id device; // address of device
133+ const char *kernelSource; // address of kernel source
134+ } kernel_map_key;
135+
136+ bool operator <(const kernel_map_key & l, const kernel_map_key & r) {
137+ if (l.context < r.context ) {
138+ return true ;
139+ } else if (r.context < l.context ) {
140+ return false ;
141+ }
142+ if (l.device < r.device ) {
143+ return true ;
144+ } else if (r.device < l.device ) {
145+ return false ;
146+ }
147+ if (l.kernelSource < r.kernelSource ) {
148+ return true ;
149+ } else if (r.kernelSource < l.kernelSource ) {
150+ return false ;
151+ }
152+ return false ;
153+ }
154+
155+
123156/* *****************************************************************************
124157 * Make Gemm Kernel
125158 *****************************************************************************/
126159// FIXME: This function should be returning an error.
127160void makeGemmKernel (
128- cl_kernel *clKernel,
161+ cl_kernel *clKernel, // ignored as input; returns as output only
129162 cl_command_queue clQueue,
130163 const char *kernelSource,
131164 const char *sourceBuildOptions,
132165 const unsigned char **kernelBinary,
133166 size_t *kernelBinarySize,
134167 const char *binaryBuildOptions)
135168{
136- // TODO: This will need to be converted to thread local when making clBLAS thread safe
137- typedef std::map<std::string, cl_kernel> kernel_map_t ;
138- static kernel_map_t kernel_map;
169+ typedef std::map<kernel_map_key, cl_kernel> kernel_map_t ;
170+ #if defined( _WIN32 )
171+ __declspec ( thread ) static kernel_map_t *kernel_map = 0 ;
172+ #else
173+ __thread static kernel_map_t *kernel_map = 0 ;
174+ #endif
175+ if (!kernel_map) {
176+ kernel_map = new kernel_map_t ();
177+ }
139178
140179 cl_context clContext;
141180 cl_device_id clDevice;
@@ -146,33 +185,20 @@ void makeGemmKernel(
146185 err = clGetCommandQueueInfo ( clQueue, CL_QUEUE_DEVICE, sizeof (clDevice), &clDevice, NULL );
147186 CL_CHECK (err)
148187
149- std::stringstream ss;
150- ss << clDevice << " _" << clContext;
151- std::string prefix = ss.str ();
152-
153- if (*clKernel) {
154- char *kernelName = getKernelName (*clKernel);
155- // kernel has already been built, return
156- #ifdef AUTOGEMM_PRINT_DEBUG
157- printf (" makeGemmKernel: \" %s\" already built; returning.\n " , kernelName);
158- #endif
159-
160- // Check if kernel exists for this device
161- std::string key = prefix + " _" + kernelName;
162- kernel_map_t ::iterator idx = kernel_map.find (key);
163-
164-
165- // If kernel not found for this device, set to NULL
166- if (idx == kernel_map.end ()) {
167- *clKernel = NULL ;
168- } else {
169- *clKernel = idx->second ;
170- }
171-
172- delete[] kernelName;
188+ // is kernel already compiled?
189+ kernel_map_key key;
190+ key.kernelSource = kernelSource;
191+ key.context = clContext;
192+ key.device = clDevice;
193+ kernel_map_t ::iterator idx = kernel_map->find (key);
194+ if (idx == kernel_map->end ()) {
195+ *clKernel = NULL ;
196+ } else {
197+ *clKernel = idx->second ;
198+ return ;
173199 }
174200
175- if (!*clKernel) {
201+ if (true /* !*clKernel*/ ) { // since kernel wasn't found in map
176202 // kernel has not been built, so build it (from binary, preferably)
177203 cl_program clProgram;
178204 cl_int clBinaryStatus;
@@ -244,17 +270,13 @@ void makeGemmKernel(
244270 err = clReleaseProgram (clProgram);
245271 CL_CHECK (err)
246272
247- char *kernelName = getKernelName (*clKernel);
248-
249273#ifdef AUTOGEMM_PRINT_DEBUG
250274 printf (" makeGemmKernel: \" %s\" now built; returning.\n " , kernelName);
251275#endif
252276
253- std::string key = prefix + " _" + kernelName;
254- kernel_map[key] = *clKernel;
255- delete[] kernelName;
277+ // put kernel in map
278+ (*kernel_map)[key] = *clKernel;
256279 }
257-
258280 return ;
259281}
260282
@@ -439,10 +461,10 @@ clblasGemm(
439461 size_t *colKernelBinarySize = 0 ;
440462 size_t *cornerKernelBinarySize = 0 ;
441463 const char *binaryBuildOptions = NULL ;
442- cl_kernel *tileClKernel = NULL ;
443- cl_kernel *rowClKernel = NULL ;
444- cl_kernel *colClKernel = NULL ;
445- cl_kernel *cornerClKernel = NULL ;
464+ cl_kernel *tileClKernelDummy = NULL ; // no longer used; broke thread safety
465+ cl_kernel *rowClKernelDummy = NULL ; // no longer used; broke thread safety
466+ cl_kernel *colClKernelDummy = NULL ; // no longer used; broke thread safety
467+ cl_kernel *cornerClKernelDummy = NULL ; // no longer used; broke thread safety
446468 unsigned int workGroupNumRows;
447469 unsigned int workGroupNumCols;
448470 unsigned int microTileNumRows;
@@ -467,10 +489,10 @@ clblasGemm(
467489 &colKernelBinarySize,
468490 &cornerKernelBinarySize,
469491 &binaryBuildOptions,
470- &tileClKernel ,
471- &rowClKernel ,
472- &colClKernel ,
473- &cornerClKernel ,
492+ &tileClKernelDummy ,
493+ &rowClKernelDummy ,
494+ &colClKernelDummy ,
495+ &cornerClKernelDummy ,
474496 &workGroupNumRows,
475497 &workGroupNumCols,
476498 µTileNumRows,
@@ -508,10 +530,10 @@ clblasGemm(
508530 &colKernelBinarySize,
509531 &cornerKernelBinarySize,
510532 &binaryBuildOptions,
511- &tileClKernel ,
512- &rowClKernel ,
513- &colClKernel ,
514- &cornerClKernel ,
533+ &tileClKernelDummy ,
534+ &rowClKernelDummy ,
535+ &colClKernelDummy ,
536+ &cornerClKernelDummy ,
515537 &workGroupNumRows,
516538 &workGroupNumCols,
517539 µTileNumRows,
@@ -544,10 +566,16 @@ clblasGemm(
544566/* *****************************************************************************
545567 * Build kernels
546568 *****************************************************************************/
547- if (needTileKernel) makeGemmKernel ( tileClKernel, commandQueues[0 ], tileKernelSource, sourceBuildOptions, &tileKernelBinary, tileKernelBinarySize, binaryBuildOptions);
548- if (needRowKernel) makeGemmKernel ( rowClKernel, commandQueues[0 ], rowKernelSource, sourceBuildOptions, &rowKernelBinary, rowKernelBinarySize, binaryBuildOptions);
549- if (needColKernel) makeGemmKernel ( colClKernel, commandQueues[0 ], colKernelSource, sourceBuildOptions, &colKernelBinary, colKernelBinarySize, binaryBuildOptions);
550- if (needCornerKernel) makeGemmKernel (cornerClKernel, commandQueues[0 ], cornerKernelSource, sourceBuildOptions, &cornerKernelBinary, cornerKernelBinarySize, binaryBuildOptions);
569+
570+
571+ cl_kernel tileClKernel = NULL ;
572+ cl_kernel rowClKernel = NULL ;
573+ cl_kernel colClKernel = NULL ;
574+ cl_kernel cornerClKernel = NULL ;
575+ if (needTileKernel) makeGemmKernel ( &tileClKernel, commandQueues[0 ], tileKernelSource, sourceBuildOptions, &tileKernelBinary, tileKernelBinarySize, binaryBuildOptions);
576+ if (needRowKernel) makeGemmKernel ( &rowClKernel, commandQueues[0 ], rowKernelSource, sourceBuildOptions, &rowKernelBinary, rowKernelBinarySize, binaryBuildOptions);
577+ if (needColKernel) makeGemmKernel ( &colClKernel, commandQueues[0 ], colKernelSource, sourceBuildOptions, &colKernelBinary, colKernelBinarySize, binaryBuildOptions);
578+ if (needCornerKernel) makeGemmKernel (&cornerClKernel, commandQueues[0 ], cornerKernelSource, sourceBuildOptions, &cornerKernelBinary, cornerKernelBinarySize, binaryBuildOptions);
551579 const size_t localWorkSize[2 ] = { workGroupNumRows, workGroupNumCols };
552580 unsigned int numKernelsEnqueued = 0 ;
553581
@@ -576,7 +604,7 @@ clblasGemm(
576604 if (needTileKernel) {
577605 // printf("enqueueing tile kernel\n");
578606 size_t globalWorkSize[2 ] = {(M/macroTileNumRows)*workGroupNumRows, (N/macroTileNumCols)*workGroupNumCols };
579- err = enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], * tileClKernel,
607+ err = enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], tileClKernel,
580608 gemmKernelArgs, gemmKernelArgSizes, numGemmKernelArgs,
581609 globalWorkSize, localWorkSize,
582610 numEventsInWaitList, eventWaitList,
@@ -591,7 +619,7 @@ clblasGemm(
591619 if (needRowKernel) {
592620 // printf("enqueueing row kernel\n");
593621 size_t globalWorkSize[2 ] = {1 *workGroupNumRows, (N/macroTileNumCols)*workGroupNumCols };
594- err = enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], * rowClKernel,
622+ err = enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], rowClKernel,
595623 gemmKernelArgs, gemmKernelArgSizes, numGemmKernelArgs,
596624 globalWorkSize, localWorkSize,
597625 numEventsInWaitList, eventWaitList,
@@ -606,7 +634,7 @@ clblasGemm(
606634 if (needColKernel) {
607635 // printf("enqueueing col kernel\n");
608636 size_t globalWorkSize[2 ] = { (M/macroTileNumRows)*workGroupNumRows, 1 *workGroupNumCols };
609- err = enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], * colClKernel,
637+ err = enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], colClKernel,
610638 gemmKernelArgs, gemmKernelArgSizes, numGemmKernelArgs,
611639 globalWorkSize, localWorkSize,
612640 numEventsInWaitList, eventWaitList,
@@ -621,7 +649,7 @@ clblasGemm(
621649 if (needCornerKernel) {
622650 // printf("enqueueing corner kernel\n");
623651 size_t globalWorkSize[2 ] = { 1 *workGroupNumRows, 1 *workGroupNumCols };
624- err = enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], * cornerClKernel,
652+ err = enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], cornerClKernel,
625653 gemmKernelArgs, gemmKernelArgSizes, numGemmKernelArgs,
626654 globalWorkSize, localWorkSize,
627655 numEventsInWaitList, eventWaitList,
0 commit comments