2020#include < stdio.h>
2121#include < string.h>
2222#include < clBLAS.h>
23+ #include " mutex.h"
2324#include " AutoGemmIncludes/AutoGemmKernelSelection.h"
2425#include " GemmSpecialCases.h"
2526
@@ -126,22 +127,47 @@ static char *getKernelName(cl_kernel clKernel)
126127 return kernelName;
127128}
128129
130+ typedef struct kernel_map_key_ {
131+ cl_context context; // address of context
132+ cl_device_id device; // address of device
133+ const char *kernelSource; // address of kernel source
134+ } kernel_map_key;
135+
136+ bool operator <(const kernel_map_key & l, const kernel_map_key & r) {
137+ if (l.context < r.context ) {
138+ return true ;
139+ } else if (r.context < l.context ) {
140+ return false ;
141+ }
142+ if (l.device < r.device ) {
143+ return true ;
144+ } else if (r.device < l.device ) {
145+ return false ;
146+ }
147+ if (l.kernelSource < r.kernelSource ) {
148+ return true ;
149+ } else if (r.kernelSource < l.kernelSource ) {
150+ return false ;
151+ }
152+ return false ;
153+ }
154+
155+
129156/* *****************************************************************************
130157 * Make Gemm Kernel
131158 *****************************************************************************/
132159// FIXME: This function should be returning an error.
133160void makeGemmKernel (
134- cl_kernel *clKernel,
161+ cl_kernel *clKernel, // ignored as input; returns as output
135162 cl_command_queue clQueue,
136163 const char *kernelSource,
137164 const char *sourceBuildOptions,
138165 const unsigned char **kernelBinary,
139166 size_t *kernelBinarySize,
140167 const char *binaryBuildOptions)
141168{
142- typedef std::map<std::string, cl_kernel> kernel_map_t ;
143-
144- #if defined( _WIN32 )
169+ typedef std::map<kernel_map_key, cl_kernel> kernel_map_t ;
170+ #if defined( _WIN32 )
145171 __declspec ( thread ) static kernel_map_t *kernel_map = 0 ;
146172#else
147173 __thread static kernel_map_t *kernel_map = 0 ;
@@ -159,33 +185,20 @@ void makeGemmKernel(
159185 err = clGetCommandQueueInfo ( clQueue, CL_QUEUE_DEVICE, sizeof (clDevice), &clDevice, NULL );
160186 CL_CHECK (err)
161187
162- std::stringstream ss;
163- ss << clDevice << " _" << clContext;
164- std::string prefix = ss.str ();
165-
166- if (*clKernel) {
167- char *kernelName = getKernelName (*clKernel);
168- // kernel has already been built, return
169- #ifdef AUTOGEMM_PRINT_DEBUG
170- printf (" makeGemmKernel: \" %s\" already built; returning.\n " , kernelName);
171- #endif
172-
173- // Check if kernel exists for this device
174- std::string key = prefix + " _" + kernelName;
175- kernel_map_t ::iterator idx = kernel_map->find (key);
176-
177-
178- // If kernel not found for this device, set to NULL
179- if (idx == kernel_map->end ()) {
180- *clKernel = NULL ;
181- } else {
182- *clKernel = idx->second ;
183- }
184-
185- delete[] kernelName;
188+ // is kernel already compiled?
189+ kernel_map_key key;
190+ key.kernelSource = kernelSource;
191+ key.context = clContext;
192+ key.device = clDevice;
193+ kernel_map_t ::iterator idx = kernel_map->find (key);
194+ if (idx == kernel_map->end ()) {
195+ *clKernel = NULL ;
196+ } else {
197+ *clKernel = idx->second ;
198+ return ;
186199 }
187200
188- if (!*clKernel) {
201+ if (true /* !*clKernel*/ ) { // since kernel wasn't found in map
189202 // kernel has not been built, so build it (from binary, preferably)
190203 cl_program clProgram;
191204 cl_int clBinaryStatus;
@@ -257,17 +270,13 @@ void makeGemmKernel(
257270 err = clReleaseProgram (clProgram);
258271 CL_CHECK (err)
259272
260- char *kernelName = getKernelName (*clKernel);
261-
262273#ifdef AUTOGEMM_PRINT_DEBUG
263274 printf (" makeGemmKernel: \" %s\" now built; returning.\n " , kernelName);
264275#endif
265276
266- std::string key = prefix + " _ " + kernelName;
277+ // put kernel in map
267278 (*kernel_map)[key] = *clKernel;
268- delete[] kernelName;
269279 }
270-
271280 return ;
272281}
273282
@@ -557,6 +566,11 @@ clblasGemm(
557566/* *****************************************************************************
558567 * Build kernels
559568 *****************************************************************************/
569+
570+ tileClKernel = NULL ;
571+ rowClKernel = NULL ;
572+ colClKernel = NULL ;
573+ cornerClKernel = NULL ;
560574 if (needTileKernel) makeGemmKernel ( tileClKernel, commandQueues[0 ], tileKernelSource, sourceBuildOptions, &tileKernelBinary, tileKernelBinarySize, binaryBuildOptions);
561575 if (needRowKernel) makeGemmKernel ( rowClKernel, commandQueues[0 ], rowKernelSource, sourceBuildOptions, &rowKernelBinary, rowKernelBinarySize, binaryBuildOptions);
562576 if (needColKernel) makeGemmKernel ( colClKernel, commandQueues[0 ], colKernelSource, sourceBuildOptions, &colKernelBinary, colKernelBinarySize, binaryBuildOptions);
0 commit comments