Skip to content

Commit ff6d422

Browse files
committed
Add missing linker and mode options in config
1 parent 1cedbc7 commit ff6d422

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

pytensor/configdefaults.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def _filter_mode(val):
4747
"DEBUG_MODE",
4848
"JAX",
4949
"NUMBA",
50+
"PYTORCH",
51+
"MLX",
5052
]
5153
if val in str_options:
5254
return val
@@ -367,13 +369,25 @@ def add_compile_configvars():
367369
)
368370
del param
369371

372+
default_linker = "cvm"
373+
370374
if rc == 0 and config.cxx != "":
371375
# Keep the default linker the same as the one for the mode FAST_RUN
372-
linker_options = ["c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"]
376+
linker_options = [
377+
"c|py",
378+
"py",
379+
"c",
380+
"c|py_nogc",
381+
"vm",
382+
"vm_nogc",
383+
"cvm_nogc",
384+
"numba",
385+
"jax",
386+
]
373387
else:
374388
# g++ is not present or the user disabled it,
375389
# linker should default to python only.
376-
linker_options = ["py", "vm_nogc"]
390+
linker_options = ["py", "vm", "vm_nogc", "numba", "jax"]
377391
if type(config).cxx.is_default:
378392
# If the user provided an empty value for cxx, do not warn.
379393
_logger.warning(
@@ -387,7 +401,7 @@ def add_compile_configvars():
387401
"linker",
388402
"Default linker used if the pytensor flags mode is Mode",
389403
# Not mutable because the default mode is cached after the first use.
390-
EnumStr("cvm", linker_options, mutable=False),
404+
EnumStr(default_linker, linker_options, mutable=False),
391405
in_c_key=False,
392406
)
393407

0 commit comments

Comments
 (0)