4 changes: 4 additions & 0 deletions conftest.py
@@ -3,6 +3,10 @@
import pytest


+# Using pytest_plugins causes `tests/link/c/test_cmodule.py::test_cache_versioning` to fail
+# pytest_plugins = ["tests.fixtures"]
+
+
def pytest_sessionstart(session):
os.environ["PYTENSOR_FLAGS"] = ",".join(
[
2 changes: 1 addition & 1 deletion doc/extending/creating_a_numba_jax_op.rst
@@ -228,7 +228,7 @@ Here's an example for :class:`DimShuffle`:
# E No match.
# ...(on this line)...
# E shuffle_shape = res.shape[: len(shuffle)]
-@numba_basic.numba_njit(inline="always")
+@numba_basic.numba_njit
def dimshuffle(x):
    return dimshuffle_inner(np.asarray(x), shuffle)

1 change: 0 additions & 1 deletion environment-osx-arm64.yml
@@ -26,7 +26,6 @@ dependencies:
- diff-cover
- mypy
- types-setuptools
-  - scipy-stubs
- pytest
- pytest-cov
- pytest-xdist
1 change: 0 additions & 1 deletion environment.yml
@@ -28,7 +28,6 @@ dependencies:
- diff-cover
- mypy
- types-setuptools
-  - scipy-stubs
- pytest
- pytest-cov
- pytest-xdist
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -150,7 +150,7 @@ lines-after-imports = 2
"pytensor/misc/check_duplicate_key.py" = ["T201"]
"pytensor/misc/check_blas.py" = ["T201"]
"pytensor/bin/pytensor_cache.py" = ["T201"]
-# For the tests we skip because `pytest.importorskip` is used:
+# For the tests we skip `E402` because `pytest.importorskip` is used:
"tests/link/jax/test_scalar.py" = ["E402"]
"tests/link/jax/test_tensor_basic.py" = ["E402"]
"tests/link/numba/test_basic.py" = ["E402"]
8 changes: 7 additions & 1 deletion pytensor/bin/pytensor_cache.py
@@ -1,5 +1,6 @@
import logging
import os
+import shutil
import sys
from pathlib import Path

@@ -74,7 +75,10 @@ def main():
'You can also call "pytensor-cache purge" to '
"remove everything from that directory."
)
_logger.debug(f"Remaining elements ({len(items)}): {', '.join(items)}")
_logger.debug(f"Remaining elements ({len(items)}): {items}")
numba_cache_dir: Path = config.base_compiledir / "numba"
shutil.rmtree(numba_cache_dir, ignore_errors=True)

elif sys.argv[1] == "list":
pytensor.compile.compiledir.print_compiledir_content()
elif sys.argv[1] == "cleanup":
@@ -86,6 +90,8 @@ def main():
print("Lock successfully removed!")
elif sys.argv[1] == "purge":
pytensor.compile.compiledir.compiledir_purge()
+numba_cache_dir: Path = config.base_compiledir / "numba"
+shutil.rmtree(numba_cache_dir, ignore_errors=True)
elif sys.argv[1] == "basecompiledir":
# Simply print the base_compiledir
print(pytensor.config.base_compiledir)
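For reference, the purge path now also clears the Numba on-disk cache. A minimal sketch of the equivalent Python, using only names that appear in this diff:

import shutil
from pathlib import Path

from pytensor.configdefaults import config

# The same cleanup `pytensor-cache purge` now performs for the Numba cache
numba_cache_dir: Path = config.base_compiledir / "numba"
shutil.rmtree(numba_cache_dir, ignore_errors=True)
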
5 changes: 5 additions & 0 deletions pytensor/compile/mode.py
@@ -461,6 +461,9 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
)

C = Mode("c", "fast_run")
C_VM = Mode("cvm", "fast_run")

NUMBA = Mode(
NumbaLinker(),
RewriteDatabaseQuery(
@@ -524,6 +527,8 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
predefined_modes = {
"FAST_COMPILE": FAST_COMPILE,
"FAST_RUN": FAST_RUN,
"C": C,
"C_VM": C_VM,
"JAX": JAX,
"NUMBA": NUMBA,
"PYTORCH": PYTORCH,
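Once registered in `predefined_modes`, the new modes should be selectable by name like the existing ones. A sketch, assuming string modes resolve through `predefined_modes` as "FAST_RUN" and "NUMBA" do (the graph is illustrative):

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
# "C_VM" now resolves to the Mode("cvm", "fast_run") added above
f = pytensor.function([x], 2 * x, mode="C_VM")
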
22 changes: 19 additions & 3 deletions pytensor/configdefaults.py
@@ -45,8 +45,12 @@ def _filter_mode(val):
"NanGuardMode",
"FAST_COMPILE",
"DEBUG_MODE",
"CVM",
"C",
"JAX",
"NUMBA",
"PYTORCH",
"MLX",
]
if val in str_options:
return val
@@ -367,13 +371,25 @@ def add_compile_configvars():
)
del param

default_linker = "cvm"

if rc == 0 and config.cxx != "":
# Keep the default linker the same as the one for the mode FAST_RUN
linker_options = ["c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"]
linker_options = [
"c|py",
"py",
"c",
"c|py_nogc",
"vm",
"vm_nogc",
"cvm_nogc",
"numba",
"jax",
]
else:
# g++ is not present or the user disabled it,
# linker should default to python only.
linker_options = ["py", "vm_nogc"]
linker_options = ["py", "vm", "vm_nogc", "numba", "jax"]
if type(config).cxx.is_default:
# If the user provided an empty value for cxx, do not warn.
_logger.warning(
@@ -387,7 +403,7 @@
"linker",
"Default linker used if the pytensor flags mode is Mode",
# Not mutable because the default mode is cached after the first use.
EnumStr("cvm", linker_options, mutable=False),
EnumStr(default_linker, linker_options, mutable=False),
in_c_key=False,
)

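With "numba" and "jax" accepted as linker values, the default-mode linker can be chosen via flags, mirroring how `conftest.py` sets `PYTENSOR_FLAGS`. A sketch, assuming the standard flags mechanism:

import os

# Must be set before pytensor is first imported; the linker option is not mutable
os.environ["PYTENSOR_FLAGS"] = "linker=numba"

import pytensor

print(pytensor.config.linker)  # expected: "numba"
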
4 changes: 3 additions & 1 deletion pytensor/configparser.py
@@ -410,7 +410,9 @@ def __get__(self, cls, type_, delete_key=False):
f"The config parameter '{self.name}' was registered on a different instance of the PyTensorConfigParser."
f" It is not accessible through the instance with id '{id(cls)}' because of safeguarding."
)
if not hasattr(self, "val"):
try:
return self.val
except AttributeError:
try:
val_str = cls.fetch_val_for_key(self.name, delete_key=delete_key)
self.is_default = False
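The `hasattr` guard becomes EAFP, so the hot path does a single attribute lookup. A standalone sketch of the pattern (the class and names are illustrative, not PyTensor's API):

class LazyVal:
    def get(self):
        try:
            # Hot path: the value was already cached on the instance
            return self.val
        except AttributeError:
            # Fallback: compute once, cache, then return
            self.val = self._compute()
            return self.val

    def _compute(self):
        return 42
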
130 changes: 130 additions & 0 deletions pytensor/link/numba/cache.py
@@ -0,0 +1,130 @@
from collections.abc import Callable
from hashlib import sha256
from pathlib import Path
from pickle import dump
from tempfile import NamedTemporaryFile
from typing import Any
from weakref import WeakKeyDictionary

from numba.core.caching import CacheImpl, _CacheLocator

from pytensor.configdefaults import config


NUMBA_CACHE_PATH = config.base_compiledir / "numba"
NUMBA_CACHE_PATH.mkdir(exist_ok=True)
CACHED_SRC_FUNCTIONS: WeakKeyDictionary[Callable, str] = WeakKeyDictionary()


class NumbaPyTensorCacheLocator(_CacheLocator):
"""Locator for Numba functions defined from PyTensor-generated source code.

It uses an internally-defined hash to disambiguate functions.

Functions returned by the PyTensor dispatchers are cached in the CACHED_SRC_FUNCTIONS
weakref dictionary when `compile_numba_function_src` is called with a `cache_key`.
When numba later attempts to find a cache for such a function, this locator gets triggered
and directs numba to the PyTensor Numba cache directory, using the provided hash as disambiguator.

It is not necessary that the python functions be cached by the dispatchers.
As long as the key is the same, numba will be directed to the same cache entry, even if the function is fresh.
Conversely, if the function changed but the key is the same, numba will still use the old cache.
"""

def __init__(self, py_func, py_file, hash):
self._py_func = py_func
self._py_file = py_file
self._hash = hash

def ensure_cache_path(self):
"""We ensured this when the module was loaded.

It's too slow to run every time a cache is needed.
"""
pass

def get_cache_path(self):
"""Return the directory the function is cached in."""
return NUMBA_CACHE_PATH

def get_source_stamp(self):
"""Get a timestamp representing the source code's freshness.
Can return any picklable Python object.

Bumping the constant returned here can be used to invalidate all caches created by previous PyTensor releases.
"""
return 0

def get_disambiguator(self):
"""Get a string disambiguator for this locator's function.
It should allow disambiguating different but similarly-named functions.
"""
return self._hash

@classmethod
def from_function(cls, py_func, py_file):
"""Create a locator instance for functions stored in CACHED_SRC_FUNCTIONS."""
if config.numba__cache and py_func in CACHED_SRC_FUNCTIONS:
return cls(py_func, Path(py_file).parent, CACHED_SRC_FUNCTIONS[py_func])


# Register our locator at the front of Numba's locator list
CacheImpl._locator_classes.insert(0, NumbaPyTensorCacheLocator)


def hash_from_pickle_dump(obj: Any) -> str:
"""Create a sha256 hash from the pickle dump of an object."""

# Stream pickle directly into the hasher to avoid a large temporary bytes object
hasher = sha256()

class HashFile:
def write(self, b):
hasher.update(b)

dump(obj, HashFile())
return hasher.hexdigest()


def compile_numba_function_src(
src: str,
function_name: str,
global_env: dict[Any, Any] | None = None,
local_env: dict[Any, Any] | None = None,
write_to_disk: bool = False,
cache_key: str | None = None,
) -> Callable:
"""Compile (and optionally cache) a function from source code for use with Numba.

This function compiles the provided source code string into a Python function
with the specified name. If `write_to_disk` is True, the source code is written
to a temporary file before compilation. The compiled module code is then executed
in the provided global and local environments, and the named function is returned.

If a `cache_key` is provided the function is registered in a `CACHED_SRC_FUNCTIONS`
weak reference dictionary, to be used by the `NumbaPyTensorCacheLocator` for caching.

"""
if write_to_disk:
with NamedTemporaryFile(delete=False) as f:
filename = f.name
f.write(src.encode())
else:
filename = "<string>"

if global_env is None:
global_env = {}

if local_env is None:
local_env = {}

mod_code = compile(src, filename, mode="exec")
exec(mod_code, global_env, local_env)

res = local_env[function_name]
res.__source__ = src

if cache_key is not None:
CACHED_SRC_FUNCTIONS[res] = cache_key

return res # type: ignore
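To see how the module's pieces fit together, a minimal sketch of compiling generated source with a cache key; the source string and key below are made up for illustration:

from pytensor.link.numba.cache import (
    compile_numba_function_src,
    hash_from_pickle_dump,
)

src = "def add_one(x):\n    return x + 1\n"
key = hash_from_pickle_dump(("add_one", src))

# Registering a cache_key stores the function in CACHED_SRC_FUNCTIONS, so
# NumbaPyTensorCacheLocator can later point Numba at PyTensor's cache dir.
add_one = compile_numba_function_src(src, "add_one", cache_key=key)
assert add_one(1) == 2
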
1 change: 1 addition & 0 deletions pytensor/link/numba/dispatch/__init__.py
@@ -3,6 +3,7 @@

# Load dispatch specializations
import pytensor.link.numba.dispatch.blockwise
+import pytensor.link.numba.dispatch.compile_ops
import pytensor.link.numba.dispatch.elemwise
import pytensor.link.numba.dispatch.extra_ops
import pytensor.link.numba.dispatch.nlinalg