Commit 9c1ee06

More hacking around

committed
1 parent 4ef29b1 commit 9c1ee06

File tree

10 files changed: +107 -95 lines changed

pytensor/link/numba/cache.py
19 additions, 18 deletions

@@ -1,6 +1,6 @@
 from collections.abc import Callable
 from pathlib import Path
-from tempfile import NamedTemporaryFile, TemporaryFile
+from tempfile import NamedTemporaryFile
 from typing import Any

 from numba.core.caching import CacheImpl, _CacheLocator
@@ -9,7 +9,9 @@


 NUMBA_PYTENSOR_CACHE_ENABLED = True
-COMPILED_SRC_FUNCTIONS = {}
+NUMBA_CACHE_PATH = config.base_compiledir / "numba"
+NUMBA_CACHE_PATH.mkdir(exist_ok=True)
+CACHED_SRC_FUNCTIONS = {}


 def compile_and_cache_numba_function_src(
@@ -20,9 +22,7 @@ def compile_and_cache_numba_function_src(
     key: str | None = None,
 ) -> Callable:
     if key is not None:
-        numba_path = config.base_compiledir / "numba"
-        numba_path.mkdir(exist_ok=True)
-        filename = numba_path / key
+        filename = NUMBA_CACHE_PATH / key
         with filename.open("wb") as f:
             f.write(src.encode())
     else:
@@ -43,10 +43,19 @@ def compile_and_cache_numba_function_src(
     res.__source__ = src  # type: ignore

     if key is not None:
-        COMPILED_SRC_FUNCTIONS[res] = key
+        CACHED_SRC_FUNCTIONS[res] = key
     return res


+def cache_numba_function(
+    fn,
+    key: str | None = None,
+) -> Callable:
+    if key is not None:
+        CACHED_SRC_FUNCTIONS[fn] = key
+    return fn
+
+
 class NumbaPyTensorCacheLocator(_CacheLocator):
     def __init__(self, py_func, py_file, hash):
         # print(f"New locator {py_func=}, {py_file=}, {hash=}")
@@ -57,34 +66,26 @@ def __init__(self, py_func, py_file, hash):
         # self._hash = hash((src_hash, py_file, pytensor.__version__))

     def ensure_cache_path(self):
-        # print("ensure_cache_path called")
-        path = self.get_cache_path()
-        path.mkdir(exist_ok=True)
-        # Ensure the directory is writable by trying to write a temporary file
-        TemporaryFile(dir=path).close()
+        pass

     def get_cache_path(self):
         """
         Return the directory the function is cached in.
         """
-        # print("get_cache_path called")
-        return self._py_file
+        return NUMBA_CACHE_PATH

     def get_source_stamp(self):
         """
         Get a timestamp representing the source code's freshness.
         Can return any picklable Python object.
         """
         return 0
-        # print("get_source_stamp called")
-        return self._hash

     def get_disambiguator(self):
         """
         Get a string disambiguator for this locator's function.
         It should allow disambiguating different but similarly-named functions.
         """
-        # print("get_disambiguator called")
         return self._hash

     @classmethod
@@ -94,9 +95,9 @@ def from_function(cls, py_func, py_file):
         """
         # py_file = Path(py_file).parent
         # if py_file == (config.base_compiledir / "numba"):
-        if NUMBA_PYTENSOR_CACHE_ENABLED and py_func in COMPILED_SRC_FUNCTIONS:
+        if NUMBA_PYTENSOR_CACHE_ENABLED and py_func in CACHED_SRC_FUNCTIONS:
            # print(f"Applies to {py_file}")
-            return cls(py_func, Path(py_file).parent, COMPILED_SRC_FUNCTIONS[py_func])
+            return cls(py_func, Path(py_file).parent, CACHED_SRC_FUNCTIONS[py_func])


 CacheImpl._locator_classes.insert(0, NumbaPyTensorCacheLocator)
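Note on the mechanism (not part of the commit): Numba resolves a cache location by walking CacheImpl._locator_classes in order and using the first class whose from_function returns a locator instead of None, which is why the locator above is inserted at position 0. A minimal, self-contained sketch with hypothetical names (MyLocator, MY_CACHE_DIR, REGISTERED_FUNCS):

# Minimal sketch of a custom cache locator; names below are hypothetical
# stand-ins, not PyTensor's actual implementation.
from pathlib import Path

from numba.core.caching import CacheImpl, _CacheLocator

MY_CACHE_DIR = Path("/tmp/my_numba_cache")
REGISTERED_FUNCS: dict = {}  # py_func -> stable key, filled by the compiler


class MyLocator(_CacheLocator):
    def __init__(self, py_func, py_file, key):
        self._py_file = py_file
        self._key = key

    def ensure_cache_path(self):
        MY_CACHE_DIR.mkdir(parents=True, exist_ok=True)

    def get_cache_path(self):
        # Directory where Numba writes its .nbi/.nbc cache files
        return MY_CACHE_DIR

    def get_source_stamp(self):
        # Constant stamp: freshness is already encoded in the key
        return 0

    def get_disambiguator(self):
        # Distinguishes same-named functions sharing one cache directory
        return self._key

    @classmethod
    def from_function(cls, py_func, py_file):
        if py_func in REGISTERED_FUNCS:
            return cls(py_func, py_file, REGISTERED_FUNCS[py_func])
        return None  # fall through to Numba's default locators


CacheImpl._locator_classes.insert(0, MyLocator)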

pytensor/link/numba/dispatch/basic.py
55 additions, 16 deletions

@@ -3,6 +3,7 @@
 import warnings
 from copy import copy
 from functools import singledispatch
+from hashlib import sha256
 from textwrap import dedent

 import numba
@@ -11,11 +12,11 @@
 import scipy
 import scipy.special
 from llvmlite import ir
+from numba import njit as _njit
 from numba import types
 from numba.core.errors import NumbaWarning, TypingError
 from numba.cpython.unsafe.tuple import tuple_setitem  # noqa: F401
-from numba.extending import box, overload
-from numba.extending import register_jitable as _register_jitable
+from numba.extending import box, overload, register_jitable

 from pytensor import In, config
 from pytensor.compile import NUMBA
@@ -26,7 +27,9 @@
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.type import Type
 from pytensor.ifelse import IfElse
-from pytensor.link.numba.cache import compile_and_cache_numba_function_src
+from pytensor.link.numba.cache import (
+    compile_and_cache_numba_function_src,
+)
 from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
 from pytensor.link.utils import fgraph_to_python
 from pytensor.scalar.basic import ScalarType
@@ -50,11 +53,7 @@ def global_numba_func(func):
     return func


-def numba_njit(*args, fastmath=None, register_jitable: bool = True, **kwargs):
-    kwargs.setdefault("cache", True)
-    kwargs.setdefault("no_cpython_wrapper", False)
-    kwargs.setdefault("no_cfunc_wrapper", False)
-
+def numba_njit(*args, fastmath=None, final_function: bool = False, **kwargs):
     if fastmath is None:
         if config.numba__fastmath:
             # Opinionated default on fastmath flags
@@ -69,6 +68,12 @@ def numba_njit(*args, fastmath=None, register_jitable: bool = True, **kwargs):
     else:
         fastmath = False

+    if final_function:
+        kwargs.setdefault("cache", True)
+    # else:
+    #     kwargs.setdefault("no_cpython_wrapper", True)
+    #     kwargs.setdefault("no_cfunc_wrapper", True)
+
     # Suppress cache warning for internal functions
     # We have to add an ansi escape code for optional bold text by numba
     warnings.filterwarnings(
@@ -82,7 +87,7 @@ def numba_njit(*args, fastmath=None, register_jitable: bool = True, **kwargs):
         category=NumbaWarning,
     )

-    func = _register_jitable if register_jitable else numba.njit
+    func = register_jitable if final_function else _njit
     if len(args) > 0 and callable(args[0]):
         return func(*args[1:], fastmath=fastmath, **kwargs)(args[0])
     else:
@@ -384,8 +389,43 @@ def numba_funcify_FunctionGraph(
     **kwargs,
 ):
     def numba_funcify_njit(op, node, **kwargs):
-        jitable_func = numba_funcify(op, node=node, **kwargs)
-        return numba_njit(lambda *args: jitable_func(*args), register_jitable=False)
+        jitable_func_and_key = numba_funcify(op, node=node, **kwargs)
+        from collections.abc import Callable
+
+        match jitable_func_and_key:
+            case (Callable(), str()):
+                jitable_func, key = jitable_func_and_key
+            case (Callable(), int()):
+                # Default key for Ops that return an integer
+                jitable_func, int_key = jitable_func_and_key
+                key = sha256(
+                    str((type(op), op._props_dict(), int_key)).encode()
+                ).hexdigest()
+            case Callable():
+                jitable_func, key = jitable_func_and_key, None
+                warnings.warn(
+                    f"No cache key returned by numba_funcify of op {op}. This function won't be cached by Numba"
+                )
+            case _:
+                raise TypeError(
+                    f"numpy_funcify should return a callable or a callable, key pair, got {jitable_func_and_key}"
+                )
+
+        if 0 and key is not None:
+            # To force numba to use our cache, we must compile the function so that any closure
+            # becomes a global variable...
+            op_name = op.__class__.__name__
+            cached_func = compile_and_cache_numba_function_src(
+                src=f"def {op_name}(*args): return jitable_func(*args)",
+                function_name=op_name,
+                global_env=globals() | dict(jitable_func=jitable_func),
+                key=key,
+            )
+            return numba_njit(cached_func, final_function=True, cache=True)
+        else:
+            return numba_njit(
+                lambda *args: jitable_func(*args), final_function=True, cache=False
+            )

     return fgraph_to_python(
         fgraph,
@@ -410,7 +450,7 @@ def dispatch_deepcopyop(x):

 @numba_funcify.register(DeepCopyOp)
 def numba_funcify_DeepCopyOp(op, node, **kwargs):
-    return deepcopyop
+    return deepcopyop, 0


 @numba_funcify.register(MakeSlice)
@@ -439,7 +479,7 @@ def numba_funcify_Shape_i(op, **kwargs):
     def shape_i(x):
         return np.asarray(np.shape(x)[i])

-    return shape_i
+    return shape_i, 0


 @numba_funcify.register(SortOp)
@@ -543,7 +583,7 @@ def reshape(x, shape):
             numba_ndarray.to_fixed_tuple(shape, ndim),
         )

-    return reshape
+    return reshape, 0


 @numba_funcify.register(SpecifyShape)
@@ -571,9 +611,8 @@ def specify_shape(x, {create_arg_string(shape_input_names)}):
         func,
         "specify_shape",
         globals(),
-        key=hash_from_code(func),
     )
-    return numba_njit(specify_shape)
+    return numba_njit(specify_shape), hash_from_code(func)


 def int_to_float_fn(inputs, out_dtype):
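The match statement in numba_funcify_njit establishes the convention the remaining files follow: a registration may return a bare callable (no caching, with a warning), a (callable, str) pair whose key is used verbatim, or a (callable, int) pair whose key is derived from the op. A small sketch of the int-key derivation, with a hypothetical DummyOp standing in for a real Op:

from hashlib import sha256


class DummyOp:
    # Hypothetical Op; real PyTensor Ops expose _props_dict() when __props__ is set
    def _props_dict(self):
        return {"axis": 0}


def jitable_func(x):
    return x


op = DummyOp()
int_key = 0  # what registrations like `return shape_i, 0` provide
key = sha256(str((type(op), op._props_dict(), int_key)).encode()).hexdigest()
# `key` is a stable hex digest, usable as the Numba cache disambiguator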

pytensor/link/numba/dispatch/elemwise.py
4 additions, 26 deletions

@@ -288,10 +288,7 @@ def numba_funcify_Elemwise(op, node, **kwargs):
             ),
         )
     )
-    core_op_key = sha256(core_op_key.encode()).hexdigest()
-    core_op_fn = store_core_outputs(
-        scalar_op_fn, nin=nin, nout=nout, core_op_key=core_op_key
-    )
+    core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout)

     input_bc_patterns = tuple(inp.type.broadcastable for inp in node.inputs)
     output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs)
@@ -342,27 +339,8 @@ def elemwise(*inputs):
     def ov_elemwise(*inputs):
         return elemwise_wrapper

-    # TODO: Also input dtypes in key
-    elemwise_key = "_".join(
-        map(
-            str,
-            (
-                "Elemwise",
-                core_op_key,
-                input_bc_patterns,
-                inplace_pattern,
-            ),
-        )
-    )
-    elemwise_key = sha256(elemwise_key.encode()).hexdigest()
-    f = compile_and_cache_numba_function_src(
-        "def f(*inputs): return elemwise(*inputs)",
-        "f",
-        {**globals(), **{"elemwise": elemwise}},
-        key=elemwise_key,
-    )
-
-    return numba_njit(f)
+    elemwise_key = sha256(f"Elemwise2{core_op_key}".encode()).hexdigest()
+    return elemwise, elemwise_key


 @numba_funcify.register(Sum)
@@ -470,7 +448,7 @@ def dimshuffle(x):

         return as_strided(x, shape=new_shape, strides=new_strides)

-    return dimshuffle
+    return dimshuffle, 0


 @numba_funcify.register(Softmax)
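The Elemwise key is now a digest of just the core op key behind an "Elemwise2" prefix (which reads like a version tag); the broadcast and inplace patterns the deleted code folded in, and the TODO about input dtypes, no longer contribute to the hash. A sketch with a hypothetical core_op_key value:

from hashlib import sha256

core_op_key = "Add_(float64, float64)"  # hypothetical value built earlier in the function
elemwise_key = sha256(f"Elemwise2{core_op_key}".encode()).hexdigest()
# returned alongside `elemwise`; numba_funcify_njit treats the str as a verbatim key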

pytensor/link/numba/dispatch/extra_ops.py
1 addition, 1 deletion

@@ -367,4 +367,4 @@ def check_and_raise(x, *conditions):
             raise error(msg)
         return x

-    return check_and_raise
+    return check_and_raise, 0

pytensor/link/numba/dispatch/scalar.py
1 addition, 7 deletions

@@ -136,13 +136,7 @@ def {scalar_op_fn_name}({', '.join(input_names)}):

     # signature = create_numba_signature(node, force_scalar=True)

-    return numba_basic.numba_njit(
-        # signature,
-        # Functions that call a function pointer can't be cached
-        no_cfunc_wrapper=True,
-        no_cpython_wrapper=True,
-        register_jitable=False,
-    )(scalar_op_fn)
+    return numba_basic.numba_njit(scalar_op_fn)


 @numba_funcify.register(Switch)

pytensor/link/numba/dispatch/subtensor.py
3 additions, 4 deletions

@@ -101,9 +101,8 @@ def {function_name}({", ".join(input_names)}):
         subtensor_def_src,
         function_name=function_name,
         global_env=globals() | {"np": np},
-        key=hash_from_code(subtensor_def_src),
     )
-    return numba_njit(func, boundscheck=True)
+    return numba_njit(func, boundscheck=True), hash_from_code(subtensor_def_src)


 @numba_funcify.register(AdvancedSubtensor)
@@ -350,7 +349,7 @@ def advancedincsubtensor1_inplace(x, vals, idxs):
         return x

     if inplace:
-        return advancedincsubtensor1_inplace
+        return advancedincsubtensor1_inplace, 0

     else:

@@ -359,4 +358,4 @@ def advancedincsubtensor1(x, vals, idxs):
             x = x.copy()
             return advancedincsubtensor1_inplace(x, vals, idxs)

-        return advancedincsubtensor1
+        return advancedincsubtensor1, 0
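As elsewhere in this commit, the generated-source hash moves from the compile helper's key= argument to the second element of the returned pair. A sketch of the idea, with a stand-in for PyTensor's hash_from_code helper and a hypothetical source string:

from hashlib import sha256


def hash_from_code(src: str) -> str:
    # Stand-in for PyTensor's helper of the same name (assumed behavior:
    # a stable digest of the source text)
    return sha256(src.encode()).hexdigest()


subtensor_def_src = "def subtensor(x, i): return x[i]"  # hypothetical generated source
key = hash_from_code(subtensor_def_src)
# numba_funcify now returns (compiled_fn, key); the FunctionGraph-level
# wrapper decides whether and how to cache using that key.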
