@@ -165,13 +165,22 @@ def test_flag_detection():
165165
166166@pytest .fixture (
167167 scope = "module" ,
168- params = ["mkl_intel" , "mkl_gnu" , "openblas" , "lapack" , "blas" , "no_blas" ],
168+ params = [
169+ "mkl_intel" ,
170+ "mkl_gnu" ,
171+ "accelerate" ,
172+ "openblas" ,
173+ "lapack" ,
174+ "blas" ,
175+ "no_blas" ,
176+ ],
169177)
170178def blas_libs (request ):
171179 key = request .param
172180 libs = {
173181 "mkl_intel" : ["mkl_core" , "mkl_rt" , "mkl_intel_thread" , "iomp5" , "pthread" ],
174182 "mkl_gnu" : ["mkl_core" , "mkl_rt" , "mkl_gnu_thread" , "gomp" , "pthread" ],
183+ "accelerate" : ["vecLib_placeholder" ],
175184 "openblas" : ["openblas" , "gfortran" , "gomp" , "m" ],
176185 "lapack" : ["lapack" , "blas" , "cblas" , "m" ],
177186 "blas" : ["blas" , "cblas" ],
@@ -190,25 +199,37 @@ def mock_system(request):
190199def cxx_search_dirs (blas_libs , mock_system ):
191200 libext = {"Linux" : "so" , "Windows" : "dll" , "Darwin" : "dylib" }
192201 libraries = []
202+ enabled_accelerate_framework = False
193203 with tempfile .TemporaryDirectory () as d :
194204 flags = None
195205 for lib in blas_libs :
196- lib_path = Path (d ) / f"{ lib } .{ libext [mock_system ]} "
197- lib_path .write_bytes (b"1" )
198- libraries .append (lib_path )
199- if flags is None :
200- flags = f"-l{ lib } "
206+ if lib == "vecLib_placeholder" :
207+ if mock_system != "Darwin" :
208+ flags = ""
209+ else :
210+ flags = "-framework Accelerate"
211+ enabled_accelerate_framework = True
201212 else :
202- flags += f" -l{ lib } "
213+ lib_path = Path (d ) / f"{ lib } .{ libext [mock_system ]} "
214+ lib_path .write_bytes (b"1" )
215+ libraries .append (lib_path )
216+ if flags is None :
217+ flags = f"-l{ lib } "
218+ else :
219+ flags += f" -l{ lib } "
203220 if "gomp" in blas_libs and "mkl_gnu_thread" not in blas_libs :
204221 flags += " -fopenmp"
205222 if len (blas_libs ) == 0 :
206223 flags = ""
207- yield f"libraries: ={ d } " .encode (sys .stdout .encoding ), flags
224+ yield (
225+ f"libraries: ={ d } " .encode (sys .stdout .encoding ),
226+ flags ,
227+ enabled_accelerate_framework ,
228+ )
208229
209230
210231@pytest .fixture (
211- scope = "function" , params = [False , True ], ids = ["Working_CXX" , "Broken_CXX" ]
232+ scope = "function" , params = [True , False ], ids = ["Working_CXX" , "Broken_CXX" ]
212233)
213234def cxx_search_dirs_status (request ):
214235 return request .param
@@ -219,22 +240,39 @@ def cxx_search_dirs_status(request):
219240def test_default_blas_ldflags (
220241 mock_std_lib_dirs , mock_check_mkl_openmp , cxx_search_dirs , cxx_search_dirs_status
221242):
222- cxx_search_dirs , expected_blas_ldflags = cxx_search_dirs
243+ cxx_search_dirs , expected_blas_ldflags , enabled_accelerate_framework = (
244+ cxx_search_dirs
245+ )
223246 mock_process = MagicMock ()
224247 if cxx_search_dirs_status :
225248 error_message = ""
226249 mock_process .communicate = lambda * args , ** kwargs : (cxx_search_dirs , b"" )
227250 mock_process .returncode = 0
228251 else :
252+ enabled_accelerate_framework = False
229253 error_message = "Unsupported argument -print-search-dirs"
230254 error_message_bytes = error_message .encode (sys .stderr .encoding )
231255 mock_process .communicate = lambda * args , ** kwargs : (b"" , error_message_bytes )
232256 mock_process .returncode = 1
257+
258+ def patched_compile_tmp (* args , ** kwargs ):
259+ def wrapped (test_code , tmp_prefix , flags , try_run , output ):
260+ if len (flags ) >= 2 and flags [:2 ] == ["-framework" , "Accelerate" ]:
261+ print (enabled_accelerate_framework )
262+ if enabled_accelerate_framework :
263+ return (True , True )
264+ else :
265+ return (False , False , "" , "Invalid flags -framework Accelerate" )
266+ else :
267+ return (True , True )
268+
269+ return wrapped
270+
233271 with patch ("pytensor.link.c.cmodule.subprocess_Popen" , return_value = mock_process ):
234272 with patch .object (
235273 pytensor .link .c .cmodule .GCC_compiler ,
236274 "try_compile_tmp" ,
237- return_value = ( True , True ) ,
275+ new_callable = patched_compile_tmp ,
238276 ):
239277 if cxx_search_dirs_status :
240278 assert set (default_blas_ldflags ().split (" " )) == set (
@@ -267,6 +305,9 @@ def windows_conda_libs(blas_libs):
267305 subdir .mkdir (exist_ok = True , parents = True )
268306 flags = f'-L"{ subdir } "'
269307 for lib in blas_libs :
308+ if lib == "vecLib_placeholder" :
309+ flags = ""
310+ break
270311 lib_path = subdir / f"{ lib } .dll"
271312 lib_path .write_bytes (b"1" )
272313 libraries .append (lib_path )
@@ -287,6 +328,16 @@ def test_default_blas_ldflags_conda_windows(
287328 mock_process = MagicMock ()
288329 mock_process .communicate = lambda * args , ** kwargs : (b"" , b"" )
289330 mock_process .returncode = 0
331+
332+ def patched_compile_tmp (* args , ** kwargs ):
333+ def wrapped (test_code , tmp_prefix , flags , try_run , output ):
334+ if len (flags ) >= 2 and flags [:2 ] == ["-framework" , "Accelerate" ]:
335+ return (False , False , "" , "Invalid flags -framework Accelerate" )
336+ else :
337+ return (True , True )
338+
339+ return wrapped
340+
290341 with patch ("sys.platform" , "win32" ):
291342 with patch ("sys.prefix" , mock_sys_prefix ):
292343 with patch (
@@ -295,7 +346,7 @@ def test_default_blas_ldflags_conda_windows(
295346 with patch .object (
296347 pytensor .link .c .cmodule .GCC_compiler ,
297348 "try_compile_tmp" ,
298- return_value = ( True , True ) ,
349+ new_callable = patched_compile_tmp ,
299350 ):
300351 assert set (default_blas_ldflags ().split (" " )) == set (
301352 expected_blas_ldflags .split (" " )
0 commit comments