@@ -521,34 +521,27 @@ def __init__(self):
521521 except :
522522 pass
523523
524+ c_dim4 = c_dim_t * 4
525+ out = ct .c_void_p (0 )
526+ dims = c_dim4 (10 , 10 , 1 , 1 )
527+
524528 # Iterate in reverse order of preference
525529 for name in ('cpu' , 'opencl' , 'cuda' , '' ):
526530 libnames = self .__libname (name )
527531 for libname in libnames :
528532 try :
529533 ct .cdll .LoadLibrary (libname )
530534 __name = 'unified' if name == '' else name
531- self .__clibs [__name ] = ct .CDLL (libname )
532- self .__name = __name
535+ clib = ct .CDLL (libname )
536+ self .__clibs [__name ] = clib
537+ err = clib .af_randu (ct .pointer (out ), 4 , ct .pointer (dims ), Dtype .f32 .value )
538+ if (err == ERR .NONE .value ):
539+ self .__name = __name
540+ clib .af_release_array (out )
533541 break ;
534542 except :
535543 pass
536544
537- c_dim4 = c_dim_t * 4
538-
539- out = c_dim_t (0 )
540- dims = c_dim4 (10 , 10 , 10 , 10 )
541-
542- for key , value in self .__clibs :
543- err = value .af_randu (ct .pointer (out ), 4 , ct .pointer (dims ), 0 )
544- if (err == ERR .NONE .value ):
545- if (self .__name != key ):
546- self .__name = key
547- break
548- else :
549- self .__name = None
550- pass
551-
552545 if (self .__name is None ):
553546 raise RuntimeError ("Could not load any ArrayFire libraries.\n " + more_info_str )
554547
0 commit comments