Commit f09517a

Implement Numba VM with caching

1 parent 051b32d
9 files changed: +1154 −46 lines changed

notebooks/numba_cache.ipynb

Lines changed: 937 additions & 0 deletions
Large diffs are not rendered by default.

pytensor/compile/mode.py

Lines changed: 6 additions & 0 deletions
@@ -472,6 +472,11 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
         ),
     )
 
+NUMBA_VM = Mode(
+    NumbaLinker(vm=True),
+    NUMBA._optimizer,
+)
+
 JAX = Mode(
     JAXLinker(),
     RewriteDatabaseQuery(
@@ -510,6 +515,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
     "FAST_RUN": FAST_RUN,
     "JAX": JAX,
     "NUMBA": NUMBA,
+    "NUMBA_VM": NUMBA_VM,
     "PYTORCH": PYTORCH,
 }
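Once registered in the predefined modes dict, the new mode can be requested by name. A minimal usage sketch (the graph here is illustrative, and this assumes a build with this commit applied):

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.exp(x) + 1

# "NUMBA_VM" resolves to Mode(NumbaLinker(vm=True), NUMBA._optimizer) via
# the registry extended above.
fn = pytensor.function([x], y, mode="NUMBA_VM")
fn([0.0, 1.0, 2.0])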

pytensor/link/numba/cache.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+from collections.abc import Callable
+from pathlib import Path
+from tempfile import NamedTemporaryFile, TemporaryFile
+from typing import Any
+
+from numba.core.caching import CacheImpl, _CacheLocator
+
+from pytensor import config
+
+
+NUMBA_PYTENSOR_CACHE_ENABLED = True
+COMPILED_SRC_FUNCTIONS = {}
+
+
+def compile_and_cache_numba_function_src(
+    src: str,
+    function_name: str,
+    global_env: dict[Any, Any] | None = None,
+    local_env: dict[Any, Any] | None = None,
+    key: str | None = None,
+) -> Callable:
+    if key is not None:
+        numba_path = config.base_compiledir / "numba"
+        numba_path.mkdir(exist_ok=True)
+        filename = numba_path / key
+        with filename.open("wb") as f:
+            f.write(src.encode())
+    else:
+        with NamedTemporaryFile(delete=False) as f:
+            filename = f.name
+            f.write(src.encode())
+
+    if global_env is None:
+        global_env = {}
+
+    if local_env is None:
+        local_env = {}
+
+    mod_code = compile(src, filename, mode="exec")
+    exec(mod_code, global_env, local_env)
+
+    res = local_env[function_name]
+    res.__source__ = src  # type: ignore
+
+    if key is not None:
+        COMPILED_SRC_FUNCTIONS[res] = key
+    return res
+
+
+class NumbaPyTensorCacheLocator(_CacheLocator):
+    def __init__(self, py_func, py_file, hash):
+        # print(f"New locator {py_func=}, {py_file=}, {hash=}")
+        self._py_func = py_func
+        self._py_file = py_file
+        self._hash = hash
+        # src_hash = hash(pytensor_loader._module_sources[self._py_file])
+        # self._hash = hash((src_hash, py_file, pytensor.__version__))
+
+    def ensure_cache_path(self):
+        # print("ensure_cache_path called")
+        path = self.get_cache_path()
+        path.mkdir(exist_ok=True)
+        # Ensure the directory is writable by trying to write a temporary file
+        TemporaryFile(dir=path).close()
+
+    def get_cache_path(self):
+        """
+        Return the directory the function is cached in.
+        """
+        # print("get_cache_path called")
+        return self._py_file
+
+    def get_source_stamp(self):
+        """
+        Get a timestamp representing the source code's freshness.
+        Can return any picklable Python object.
+        """
+        return 0
+        # print("get_source_stamp called")
+        return self._hash
+
+    def get_disambiguator(self):
+        """
+        Get a string disambiguator for this locator's function.
+        It should allow disambiguating different but similarly-named functions.
+        """
+        # print("get_disambiguator called")
+        return self._hash
+
+    @classmethod
+    def from_function(cls, py_func, py_file):
+        """
+        Create a locator instance for the given function located in the given file.
+        """
+        # py_file = Path(py_file).parent
+        # if py_file == (config.base_compiledir / "numba"):
+        if NUMBA_PYTENSOR_CACHE_ENABLED and py_func in COMPILED_SRC_FUNCTIONS:
+            # print(f"Applies to {py_file}")
+            return cls(py_func, Path(py_file).parent, COMPILED_SRC_FUNCTIONS[py_func])
+
+
+CacheImpl._locator_classes.insert(0, NumbaPyTensorCacheLocator)
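A rough sketch of how these pieces are meant to interact (the source string and key below are illustrative, not from the commit): compiling with a key writes the source to a stable path under the compile dir and records the function, so the locator above claims it when Numba resolves a cache location.

import numba

from pytensor.link.numba.cache import compile_and_cache_numba_function_src

src = "def add1(x):\n    return x + 1\n"

# Writes src to <base_compiledir>/numba/add1_demo and registers the function
# in COMPILED_SRC_FUNCTIONS, so NumbaPyTensorCacheLocator.from_function
# returns a locator for it ahead of Numba's default file-based locators.
add1 = compile_and_cache_numba_function_src(src, "add1", key="add1_demo")

jitted = numba.njit(cache=True)(add1)
jitted(41)  # first call compiles; the binary is persisted next to the source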

pytensor/link/numba/dispatch/basic.py

Lines changed: 10 additions & 8 deletions
@@ -14,7 +14,7 @@
 from numba import types
 from numba.core.errors import NumbaWarning, TypingError
 from numba.cpython.unsafe.tuple import tuple_setitem  # noqa: F401
-from numba.extending import box, overload
+from numba.extending import box, overload, register_jitable as _register_jitable
 
 from pytensor import In, config
 from pytensor.compile import NUMBA
@@ -50,10 +50,11 @@ def global_numba_func(func):
     return func
 
 
-def numba_njit(*args, fastmath=None, **kwargs):
-    kwargs.setdefault("cache", config.numba__cache)
-    kwargs.setdefault("no_cpython_wrapper", True)
-    kwargs.setdefault("no_cfunc_wrapper", True)
+def numba_njit(*args, fastmath=None, register_jitable: bool = False, **kwargs):
+    kwargs.setdefault("cache", True)
+    kwargs.setdefault("no_cpython_wrapper", False)
+    kwargs.setdefault("no_cfunc_wrapper", False)
+    # print(kwargs)
     if fastmath is None:
         if config.numba__fastmath:
             # Opinionated default on fastmath flags
@@ -81,10 +82,11 @@ def numba_njit(*args, fastmath=None, **kwargs):
             category=NumbaWarning,
         )
 
+    func = _register_jitable if register_jitable else numba.njit
     if len(args) > 0 and callable(args[0]):
-        return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0])
-
-    return numba.njit(*args, fastmath=fastmath, **kwargs)
+        return func(*args[1:], fastmath=fastmath, **kwargs)(args[0])
+    else:
+        return func(*args, fastmath=fastmath, **kwargs)
 
 
 def numba_vectorize(*args, **kwargs):
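With the new register_jitable flag, numba_njit either compiles eagerly via numba.njit or merely marks the function as inlinable into other jitted code via numba.extending.register_jitable. A hedged sketch of both call styles (the decorated functions are illustrative):

from pytensor.link.numba.dispatch.basic import numba_njit

# Default path: numba.njit with the new defaults
# (cache=True, CPython/cfunc wrappers enabled).
@numba_njit
def double(x):
    return 2 * x

# register_jitable path: no eager compilation; the function only becomes
# callable from within other jitted functions.
@numba_njit(register_jitable=True)
def add_one(x):
    return x + 1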

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 27 additions & 9 deletions
@@ -7,18 +7,17 @@
 from numpy.lib.stride_tricks import as_strided
 
 from pytensor.graph.op import Op
+from pytensor.link.numba.cache import compile_and_cache_numba_function_src
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
     numba_funcify,
     numba_njit,
 )
 from pytensor.link.numba.dispatch.vectorize_codegen import (
-    _jit_options,
     _vectorized,
     encode_literals,
     store_core_outputs,
 )
-from pytensor.link.utils import compile_function_src
 from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
 from pytensor.scalar.basic import (
     AND,
@@ -237,7 +236,7 @@ def {careduce_fn_name}(x):
     careduce_def_src += "\n\n"
     careduce_def_src += indent(f"return {return_obj}", " " * 4)
 
-    careduce_fn = compile_function_src(
+    careduce_fn = compile_and_cache_numba_function_src(
         careduce_def_src, careduce_fn_name, {**globals(), **global_env}
     )
 
@@ -264,19 +263,31 @@ def axis_apply_fn(x):
 
 @numba_funcify.register(Elemwise)
 def numba_funcify_Elemwise(op, node, **kwargs):
+    nin = len(node.inputs)
+    nout = len(node.outputs)
+
     scalar_inputs = [get_scalar_type(dtype=input.dtype)() for input in node.inputs]
     scalar_node = op.scalar_op.make_node(*scalar_inputs)
-
     scalar_op_fn = numba_funcify(
         op.scalar_op,
         node=scalar_node,
         parent_node=node,
         **kwargs,
     )
 
-    nin = len(node.inputs)
-    nout = len(node.outputs)
-    core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout)
+    # TODO: Proper key
+    key = "_".join(
+        map(
+            str,
+            (
+                op,
+                op.scalar_op,
+                tuple(op.inplace_pattern.items()),
+                tuple(getattr(op.scalar_op, "props_dict", lambda: {})().items()),
+            ),
+        )
+    )
+    core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout, core_op_key=key)
 
     input_bc_patterns = tuple(inp.type.broadcastable for inp in node.inputs)
     output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs)
@@ -333,11 +344,18 @@ def elemwise(*inputs):
             return tuple(outputs_summed)
         return outputs_summed[0]
 
-    @overload(elemwise, jit_options=_jit_options)
+    @overload(elemwise)
    def ov_elemwise(*inputs):
         return elemwise_wrapper
 
-    return elemwise
+    f = compile_and_cache_numba_function_src(
+        "def f(*inputs): return elemwise(*inputs)",
+        "f",
+        {**globals(), **{"elemwise": elemwise}},
+        key=f"Elemwise_{key}",
+    )
+
+    return numba_njit(f)
 
 
 @numba_funcify.register(Sum)
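For intuition about the cache key: it is just the underscore-joined reprs of the op, its scalar op, the inplace pattern, and the scalar op's props. A hedged illustration (the exact strings depend on each op's repr):

import pytensor.tensor as pt

node = (pt.vector("a") + pt.vector("b")).owner
op = node.op  # an Elemwise wrapping the scalar Add op

# Mirrors the key construction above; yields something like "Add_add_()_()".
key = "_".join(
    map(
        str,
        (
            op,
            op.scalar_op,
            tuple(op.inplace_pattern.items()),
            tuple(getattr(op.scalar_op, "props_dict", lambda: {})().items()),
        ),
    )
)
print(key)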

pytensor/link/numba/dispatch/scalar.py

Lines changed: 11 additions & 7 deletions
@@ -4,6 +4,7 @@
 
 from pytensor.compile.ops import TypeCastingOp
 from pytensor.graph.basic import Variable
+from pytensor.link.numba.cache import compile_and_cache_numba_function_src
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
     create_numba_signature,
@@ -12,7 +13,6 @@
 )
 from pytensor.link.numba.dispatch.cython_support import wrap_cython_function
 from pytensor.link.utils import (
-    compile_function_src,
     get_name_for_object,
     unique_name_generator,
 )
@@ -128,16 +128,20 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
     return direct_cast(scalar_func_numba({converted_call_args}, np.intc(1)), output_dtype)
     """
 
-    scalar_op_fn = compile_function_src(
-        scalar_op_src, scalar_op_fn_name, {**globals(), **global_env}
+    scalar_op_fn = compile_and_cache_numba_function_src(
+        scalar_op_src,
+        scalar_op_fn_name,
+        {**globals(), **global_env},
     )
 
-    signature = create_numba_signature(node, force_scalar=True)
+    # signature = create_numba_signature(node, force_scalar=True)
 
     return numba_basic.numba_njit(
-        signature,
+        # signature,
         # Functions that call a function pointer can't be cached
-        cache=False,
+        no_cfunc_wrapper=True,
+        no_cpython_wrapper=True,
+        register_jitable=False,
     )(scalar_op_fn)
@@ -164,7 +168,7 @@ def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op:
 def {binary_op_name}({input_signature}):
     return {output_expr}
 """
-    nary_fn = compile_function_src(nary_src, binary_op_name, globals())
+    nary_fn = compile_and_cache_numba_function_src(nary_src, binary_op_name, globals())
 
     return nary_fn
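End to end, recompiling the same graph in a fresh interpreter should now hit Numba's on-disk cache, since the generated sources live at stable, keyed paths. A hedged way to inspect what was persisted (the file names are illustrative; Numba writes .nbi index and .nbc data files alongside the sources):

from pytensor import config

cache_dir = config.base_compiledir / "numba"
# Expect keyed source files (e.g. Elemwise_...) plus Numba's cache artifacts.
print(sorted(p.name for p in cache_dir.iterdir()))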
