@@ -315,31 +315,85 @@ class BACKEND(_Enum):
315315 CUDA = _Enum_Type (2 )
316316 OPENCL = _Enum_Type (4 )
317317
318- class _clibrary (object ):
318+ def _setup ():
319+ import platform
320+ import os
321+
322+ platform_name = platform .system ()
323+
324+ try :
325+ AF_SEARCH_PATH = os .environ ['AF_PATH' ]
326+ except :
327+ AF_SEARCH_PATH = None
328+ pass
329+
330+ try :
331+ CUDA_PATH = os .environ ['CUDA_PATH' ]
332+ except :
333+ CUDA_PATH = None
334+ pass
335+
336+ CUDA_EXISTS = False
337+
338+ assert (len (platform_name ) >= 3 )
339+ if platform_name == 'Windows' or platform_name [:3 ] == 'CYG' :
340+
341+ ## Windows specific setup
342+ pre = ''
343+ post = '.dll'
344+ if platform_name == "Windows" :
345+ '''
346+ Supressing crashes caused by missing dlls
347+ http://stackoverflow.com/questions/8347266/missing-dll-print-message-instead-of-launching-a-popup
348+ https://msdn.microsoft.com/en-us/library/windows/desktop/ms680621.aspx
349+ '''
350+ ct .windll .kernel32 .SetErrorMode (0x0001 | 0x0002 )
351+
352+ if AF_SEARCH_PATH is None :
353+ AF_SEARCH_PATH = "C:/Program Files/ArrayFire/v3/"
354+
355+ if CUDA_PATH is not None :
356+ CUDA_EXISTS = os .path .isdir (CUDA_PATH + '/bin' ) and os .path .isdir (CUDA_PATH + '/nvvm/bin/' )
357+
358+ elif platform_name == 'Darwin' :
359+
360+ ## OSX specific setup
361+ pre = 'lib'
362+ post = '.dylib'
363+
364+ if AF_SEARCH_PATH is None :
365+ AF_SEARCH_PATH = '/usr/local/'
366+
367+ if CUDA_PATH is None :
368+ CUDA_PATH = '/usr/local/cuda/'
369+
370+ CUDA_EXISTS = os .path .isdir (CUDA_PATH + '/lib' ) and os .path .isdir (CUDA_PATH + '/nvvm/lib' )
319371
320- def __libname (self , name ):
321- platform_name = platform .system ()
322- assert (len (platform_name ) >= 3 )
323-
324- libname = 'libaf' + name
325- if platform_name == 'Linux' :
326- libname += '.so'
327- elif platform_name == 'Darwin' :
328- libname += '.dylib'
329- elif platform_name == "Windows" or platform_name [:3 ] == "CYG" :
330- libname += '.dll'
331- libname = libname [3 :] # remove 'lib'
332- if platform_name == "Windows" :
333- '''
334- Supressing crashes caused by missing dlls
335- http://stackoverflow.com/questions/8347266/missing-dll-print-message-instead-of-launching-a-popup
336- https://msdn.microsoft.com/en-us/library/windows/desktop/ms680621.aspx
337- '''
338- ct .windll .kernel32 .SetErrorMode (0x0001 | 0x0002 );
372+ elif platform_name == 'Linux' :
373+ pre = 'lib'
374+ post = '.so'
375+
376+ if AF_SEARCH_PATH is None :
377+ AF_SEARCH_PATH = '/opt/arrayfire-3/'
378+
379+ if CUDA_PATH is None :
380+ CUDA_PATH = '/usr/local/cuda/'
381+
382+ if platform .architecture ()[0 ][:2 ] == 64 :
383+ CUDA_EXISTS = os .path .isdir (CUDA_PATH + '/lib64' ) and os .path .isdir (CUDA_PATH + '/nvvm/lib64' )
339384 else :
340- raise OSError (platform_name + ' not supported' )
385+ CUDA_EXISTS = os .path .isdir (CUDA_PATH + '/lib' ) and os .path .isdir (CUDA_PATH + '/nvvm/lib' )
386+ else :
387+ raise OSError (platform_name + ' not supported' )
341388
342- return libname
389+ return pre , post , AF_SEARCH_PATH , CUDA_EXISTS
390+
391+ class _clibrary (object ):
392+
393+ def __libname (self , name , head = 'af' ):
394+ libname = self .__pre + head + name + self .__post
395+ libname_full = self .AF_SEARCH_PATH + '/lib/' + libname
396+ return (libname , libname_full )
343397
344398 def set_unsafe (self , name ):
345399 lib = self .__clibs [name ]
@@ -348,6 +402,15 @@ def set_unsafe(self, name):
348402 self .__name = name
349403
350404 def __init__ (self ):
405+
406+ more_info_str = "Please look at https://github.com/arrayfire/arrayfire-python/wiki for more information."
407+
408+ pre , post , AF_SEARCH_PATH , CUDA_EXISTS = _setup ()
409+
410+ self .__pre = pre
411+ self .__post = post
412+ self .AF_SEARCH_PATH = AF_SEARCH_PATH
413+
351414 self .__name = None
352415
353416 self .__clibs = {'cuda' : None ,
@@ -365,18 +428,29 @@ def __init__(self):
365428 'cuda' : 2 ,
366429 'opencl' : 4 }
367430
368- # Iterate in reverse order of preference
369- for name in ('cpu' , 'opencl' , 'cuda' , '' ):
431+ # Try to pre-load forge library if it exists
432+ libnames = self .__libname ('forge' , '' )
433+ for libname in libnames :
370434 try :
371- libname = self .__libname (name )
372435 ct .cdll .LoadLibrary (libname )
373- self .__clibs [name ] = ct .CDLL (libname )
374- self .__name = name
375436 except :
376437 pass
377438
439+ # Iterate in reverse order of preference
440+ for name in ('cpu' , 'opencl' , 'cuda' , '' ):
441+ libnames = self .__libname (name )
442+ for libname in libnames :
443+ try :
444+ ct .cdll .LoadLibrary (libname )
445+ self .__clibs [name ] = ct .CDLL (libname )
446+ self .__name = name
447+ break ;
448+ except :
449+ pass
450+
378451 if (self .__name is None ):
379- raise RuntimeError ("Could not load any ArrayFire libraries" )
452+ raise RuntimeError ("Could not load any ArrayFire libraries.\n " +
453+ more_info_str )
380454
381455 def get_id (self , name ):
382456 return self .__backend_name_map [name ]
0 commit comments