@@ -179,6 +179,14 @@ def __init__(self):
179179 self .incStr += "#include <CL/cl.h>\n "
180180 self .incStr += "#endif\n "
181181 self .incStr += "\n "
182+ self .incStr += "#ifdef __cplusplus\n "
183+ self .incStr += "extern \" C\" {\n "
184+ self .incStr += "#endif\n "
185+ self .incStr += " void initAutoGemmClKernels(void);\n " ;
186+ self .incStr += "#ifdef __cplusplus\n "
187+ self .incStr += "}\n " ;
188+ self .incStr += "#endif\n "
189+ self .incStr += "\n " ;
182190
183191 self .cppName = Common .getIncludePath () + "AutoGemmClKernels.cpp"
184192 self .cppFile = open (self .cppName , "w" )
@@ -190,29 +198,50 @@ def __init__(self):
190198 self .cppStr += "#endif\n "
191199 self .cppStr += "\n "
192200
201+
202+ self .initFunction = "" ;
203+ self .initFunction += "extern \" C\" {\n " ;
204+ self .initFunction += " void initAutoGemmClKernels(void);\n " ;
205+ self .initFunction += "}\n " ;
206+ self .initFunction += "\n " ;
207+ self .initFunction += "void initAutoGemmClKernels(void) {\n " ;
208+ self .defines = "" ;
209+
def addKernel(self, kernel):
    """Register the four cl_kernel handles belonging to one generated kernel.

    For each of the kernel's main, row, column and corner variant names
    this appends:

    * an ``extern cl_kernel <name>_clKernel;`` declaration to the header
      text (``self.incStr``),
    * the matching ``cl_kernel <name>_clKernel = NULL;`` definition to
      ``self.defines``,
    * a release-and-reset stanza for that handle to the body of the
      generated init function (``self.initFunction``).

    The header text is flushed to ``self.incFile`` right away and then
    cleared.  The definitions and init-function body are only accumulated
    here; ``writeToFile`` is what emits them into the .cpp output.

    Args:
        kernel: object providing ``getName()``, ``getRowName()``,
            ``getColName()`` and ``getCornerName()``, each returning the
            identifier prefix for one kernel variant.
    """
    variantNames = (
        kernel.getName(),
        kernel.getRowName(),
        kernel.getColName(),
        kernel.getCornerName(),
    )

    for kernelName in variantNames:
        self.incStr += "extern cl_kernel %s_clKernel;\n " % kernelName
        self.defines += "cl_kernel %s_clKernel = NULL;\n " % kernelName
        # Generated C: release the handle if it is live, then null it.
        self.initFunction += (
            " if(%s_clKernel != NULL) {\n "
            " clReleaseKernel(%s_clKernel);\n "
            " %s_clKernel = NULL;\n "
            " }\n "
        ) % (kernelName, kernelName, kernelName)

    # Declarations go to the header immediately; the .cpp side is deferred
    # so all definitions can be written ahead of the init function.
    self.incFile.write(self.incStr)
    self.incStr = ""
211231
def writeToFile(self):
    """Finish both generated files and close them.

    Header: writes any pending ``self.incStr`` followed by a final
    ``#endif`` (presumably closing the header's include guard - the
    opening guard is not written by this method), then closes
    ``self.incFile``.

    Source: terminates the generated ``initAutoGemmClKernels`` body with
    ``}``, appends the accumulated kernel definitions and then the init
    function to ``self.cppStr``, writes the result to ``self.cppFile``
    and closes it.  ``self.defines`` and ``self.initFunction`` are reset
    to empty strings once consumed.

    Both file handles are closed even if a write raises; the exception
    still propagates to the caller.
    """
    try:
        try:
            self.incFile.write(self.incStr)
            self.incFile.write("\n #endif\n ")
        finally:
            self.incFile.close()

        # Definitions must precede the init function, which refers to them.
        self.initFunction += "}\n "
        self.cppStr += self.defines + "\n "
        self.defines = ""
        self.cppStr += self.initFunction + "\n "
        self.initFunction = ""

        self.cppFile.write(self.cppStr)
    finally:
        self.cppFile.close()
218247
0 commit comments