Skip to content

Commit 41f0694

Browse files
committed
Manual control of numba caching
1 parent 4802bbb commit 41f0694

File tree

9 files changed

+635
-64
lines changed

9 files changed

+635
-64
lines changed

doc/extending/creating_a_numba_jax_op.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ Here's an example for :class:`DimShuffle`:
228228
# E No match.
229229
# ...(on this line)...
230230
# E shuffle_shape = res.shape[: len(shuffle)]
231-
@numba_basic.numba_njit(inline="always")
231+
@numba_basic.numba_njit
232232
def dimshuffle(x):
233233
return dimshuffle_inner(np.asarray(x), shuffle)
234234

pytensor/bin/pytensor_cache.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import os
3+
import shutil
34
import sys
45
from pathlib import Path
56

@@ -74,7 +75,10 @@ def main():
7475
'You can also call "pytensor-cache purge" to '
7576
"remove everything from that directory."
7677
)
77-
_logger.debug(f"Remaining elements ({len(items)}): {', '.join(items)}")
78+
_logger.debug(f"Remaining elements ({len(items)}): {items}")
79+
numba_cache_dir: Path = config.base_compiledir / "numba"
80+
shutil.rmtree(numba_cache_dir, ignore_errors=True)
81+
7882
elif sys.argv[1] == "list":
7983
pytensor.compile.compiledir.print_compiledir_content()
8084
elif sys.argv[1] == "cleanup":
@@ -86,6 +90,8 @@ def main():
8690
print("Lock successfully removed!")
8791
elif sys.argv[1] == "purge":
8892
pytensor.compile.compiledir.compiledir_purge()
93+
numba_cache_dir: Path = config.base_compiledir / "numba"
94+
shutil.rmtree(numba_cache_dir, ignore_errors=True)
8995
elif sys.argv[1] == "basecompiledir":
9096
# Simply print the base_compiledir
9197
print(pytensor.config.base_compiledir)

pytensor/link/numba/cache.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
from collections.abc import Callable
2+
from hashlib import sha256
3+
from pathlib import Path
4+
from pickle import dump
5+
from tempfile import NamedTemporaryFile
6+
from typing import Any
7+
from weakref import WeakKeyDictionary
8+
9+
from numba.core.caching import CacheImpl, _CacheLocator
10+
11+
from pytensor.configdefaults import config
12+
13+
14+
# Directory, under PyTensor's base compile directory, that holds the Numba cache files.
NUMBA_CACHE_PATH = config.base_compiledir / "numba"
# Created once at import time. `parents=True` also creates `base_compiledir` itself
# when it is missing, instead of failing with FileNotFoundError.
NUMBA_CACHE_PATH.mkdir(parents=True, exist_ok=True)
# Maps functions compiled from source to their cache key. Keys are held weakly,
# so an entry disappears once its function is garbage collected.
CACHED_SRC_FUNCTIONS: WeakKeyDictionary[Callable, str] = WeakKeyDictionary()
17+
18+
19+
class NumbaPyTensorCacheLocator(_CacheLocator):
    """Cache locator for Numba functions built from PyTensor-generated source.

    Functions are told apart by a key chosen by PyTensor. Whenever
    `compile_numba_function_src` is called with a `cache_key`, the resulting
    function is recorded in the `CACHED_SRC_FUNCTIONS` weak-key dictionary.
    When Numba later looks for a cache for such a function, this locator
    directs it to the PyTensor Numba cache directory (`NUMBA_CACHE_PATH`) and
    reports the recorded key as the disambiguator.

    The Python function object does not have to be reused by the dispatchers:
    a freshly created function with the same key is directed to the same cache
    entry. The flip side is that a function whose code changed while its key
    stayed the same will still be served from the old cache entry.
    """

    def __init__(self, py_func, py_file, hash):
        # `hash` is the PyTensor-provided cache key, not the builtin function.
        self._py_func, self._py_file, self._hash = py_func, py_file, hash

    def ensure_cache_path(self):
        """Do nothing: the cache directory is created when the module is loaded.

        Re-checking it every time a cache is needed would be too slow.
        """

    def get_cache_path(self):
        """Return the directory in which the function is cached."""
        return NUMBA_CACHE_PATH

    def get_source_stamp(self):
        """Return a stamp describing the freshness of the source code.

        Any picklable Python object may be returned. The value is a constant;
        changing it can be used to invalidate all caches written by previous
        PyTensor releases.
        """
        return 0

    def get_disambiguator(self):
        """Return the string that tells this function apart from similarly-named ones."""
        return self._hash

    @classmethod
    def from_function(cls, py_func, py_file):
        """Build a locator for a function registered in `CACHED_SRC_FUNCTIONS`.

        Returns None when caching is disabled through `config.numba__cache` or
        when `py_func` was never registered with a cache key.
        """
        if not config.numba__cache:
            return None
        if py_func not in CACHED_SRC_FUNCTIONS:
            return None
        cache_key = CACHED_SRC_FUNCTIONS[py_func]
        return cls(py_func, Path(py_file).parent, cache_key)
69+
70+
71+
# Register our locator at the front of Numba's locator list
72+
# Index 0 places this locator ahead of every locator class already in the list.
# NOTE(review): presumably Numba tries the locator classes in list order and moves
# on to the next one when `from_function` returns None — confirm against
# numba.core.caching.
CacheImpl._locator_classes.insert(0, NumbaPyTensorCacheLocator)
73+
74+
75+
def hash_from_pickle_dump(obj: Any) -> str:
    """Return the hexadecimal SHA-256 digest of the pickle stream of ``obj``.

    The pickler writes its output chunks straight into the hash object, so
    the complete pickle never has to exist as one large temporary bytes object.
    """
    digest = sha256()

    class _DigestSink:
        """Write-only stand-in for a file; each written chunk updates the digest."""

        def __init__(self, consume):
            # `dump` only needs a `write` attribute on the object it is given.
            self.write = consume

    dump(obj, _DigestSink(digest.update))
    return digest.hexdigest()
87+
88+
89+
def compile_numba_function_src(
90+
src: str,
91+
function_name: str,
92+
global_env: dict[Any, Any] | None = None,
93+
local_env: dict[Any, Any] | None = None,
94+
write_to_disk: bool = False,
95+
cache_key: str | None = None,
96+
) -> Callable:
97+
"""Compile (and optionally cache) a function from source code for use with Numba.
98+
99+
This function compiles the provided source code string into a Python function
100+
with the specified name. If `store_to_disk` is True, the source code is written
101+
to a temporary file before compilation. The compiled function is then executed
102+
in the provided global and local environments.
103+
104+
If a `cache_key` is provided the function is registered in a `CACHED_SRC_FUNCTIONS`
105+
weak reference dictionary, to be used by the `NumbaPyTensorCacheLocator` for caching.
106+
107+
"""
108+
if write_to_disk:
109+
with NamedTemporaryFile(delete=False) as f:
110+
filename = f.name
111+
f.write(src.encode())
112+
else:
113+
filename = "<string>"
114+
115+
if global_env is None:
116+
global_env = {}
117+
118+
if local_env is None:
119+
local_env = {}
120+
121+
mod_code = compile(src, filename, mode="exec")
122+
exec(mod_code, global_env, local_env)
123+
124+
res = local_env[function_name]
125+
res.__source__ = src
126+
127+
if cache_key is not None:
128+
CACHED_SRC_FUNCTIONS[res] = cache_key
129+
130+
return res # type: ignore

0 commit comments

Comments
 (0)