
Commit 980ae7c ("Saner defaults")
Parent: 271e271

File tree: 11 files changed (+314, -227 lines)


pytensor/link/numba/cache.py
Lines changed: 161 additions & 0 deletions
@@ -1,11 +1,17 @@
 import weakref
+from collections.abc import Callable
+from functools import singledispatch, wraps
 from hashlib import sha256
 from pathlib import Path
+from pickle import dumps
+from tempfile import NamedTemporaryFile
+from typing import Any
 
 from numba.core.caching import CacheImpl, _CacheLocator
 
 from pytensor import config
 from pytensor.graph.basic import Apply
+from pytensor.link.numba.compile import numba_funcify, numba_njit
 
 
 NUMBA_PYTENSOR_CACHE_ENABLED = True
@@ -74,3 +80,158 @@ def cache_node_key(node: Apply, extra_key="") -> str:
             ),
         ).encode()
     ).hexdigest()
+
+
+@singledispatch
+def numba_funcify_default_op_cache_key(
+    op, node=None, **kwargs
+) -> Callable | tuple[Callable, Any]:
+    """Funcify an Op and implement a default cache key.
+
+    The default cache key is based on the Op class and its properties.
+    It does not take into account the node inputs or other context.
+    Note that numba will use the array dtypes, rank and layout as part of the
+    cache key, but not the static shape or constant values.
+    If the funcify implementation exploits this information, this method should
+    not be used. Instead, dispatch directly on `numba_funcify_and_cache_key`
+    (or just `numba_funcify`, which won't use any cache key).
+    """
+    # Default cache key of None, which means "don't try to directly cache this function"
+    raise NotImplementedError()
+
+
+def register_funcify_default_op_cache_key(op_type):
+    """Register a funcify implementation for both cache and non-cache versions."""
+
+    def decorator(dispatch_func):
+        # Register with the cache key dispatcher
+        numba_funcify_default_op_cache_key.register(op_type)(dispatch_func)
+
+        # Create a wrapper for the non-cache dispatcher
+        @wraps(dispatch_func)
+        def dispatch_func_wrapper(*args, **kwargs):
+            func, key = dispatch_func(*args, **kwargs)
+            # Discard the key for the non-cache version
+            return func
+
+        # Register the wrapper with the non-cache dispatcher
+        numba_funcify.register(op_type)(dispatch_func_wrapper)
+
+        return dispatch_func
+
+    return decorator
+
+
+@singledispatch
+def numba_funcify_and_cache_key(op, node=None, **kwargs) -> tuple[Callable, str | None]:
+    # Default cache key of None, which means "don't try to directly cache this function"
+    if hasattr(op, "_props"):
+        try:
+            func_and_salt = numba_funcify_default_op_cache_key(op, node=node, **kwargs)
+        except NotImplementedError:
+            pass
+        else:
+            if isinstance(func_and_salt, tuple):
+                func, salt = func_and_salt
+            else:
+                func, salt = func_and_salt, "0"
+            props_dict = op._props_dict()
+            if not props_dict:
+                # Simple op, just use the type string as key
+                key_bytes = str((type(op), salt)).encode()
+            else:
+                # Simple props, can use string representation of props as key
+                simple_types = (str, bool, int, type(None), float)
+                container_types = (tuple, frozenset)
+                if all(
+                    isinstance(v, simple_types)
+                    or (
+                        isinstance(v, container_types)
+                        and all(isinstance(i, simple_types) for i in v)
+                    )
+                    for v in props_dict.values()
+                ):
+                    key_bytes = str(
+                        (type(op), tuple(props_dict.items()), salt)
+                    ).encode()
+                else:
+                    # Complex props, use pickle to serialize them
+                    key_bytes = dumps((str(type(op)), tuple(props_dict.items()), salt))
+            return func, sha256(key_bytes).hexdigest()
+
+    # Fallback
+    return numba_funcify(op, node=node, **kwargs), None
+
+
+def register_funcify_and_cache_key(op_type):
+    """Register a funcify implementation for both cache and non-cache versions."""
+
+    def decorator(dispatch_func):
+        # Register with the cache key dispatcher
+        numba_funcify_and_cache_key.register(op_type)(dispatch_func)
+
+        # Create a wrapper for the non-cache dispatcher
+        @wraps(dispatch_func)
+        def dispatch_func_wrapper(*args, **kwargs):
+            func, key = dispatch_func(*args, **kwargs)
+            # Discard the key for the non-cache version
+            return func
+
+        # Register the wrapper with the non-cache dispatcher
+        numba_funcify.register(op_type)(dispatch_func_wrapper)
+
+        return dispatch_func_wrapper
+
+    return decorator
+
+
+def numba_njit_and_cache(op, node, **kwargs):
+    jitable_func, key = numba_funcify_and_cache_key(op, node=node, **kwargs)
+
+    if key is not None:
+        # To force numba to use our cache, we must compile the function so that
+        # any closure becomes a global variable...
+        op_name = op.__class__.__name__
+        cached_func = compile_and_cache_numba_function_src(
+            src=f"def {op_name}(*args): return jitable_func(*args)",
+            function_name=op_name,
+            global_env=globals() | dict(jitable_func=jitable_func),
+            cache_key=key,
+        )
+        return numba_njit(cached_func, final_function=True, cache=True)
+    else:
+        return numba_njit(
+            lambda *args: jitable_func(*args), final_function=True, cache=False
+        )
+
+
+def compile_and_cache_numba_function_src(
+    src: str,
+    function_name: str,
+    global_env: dict[Any, Any] | None = None,
+    local_env: dict[Any, Any] | None = None,
+    store_to_disk: bool = False,
+    cache_key: str | None = None,
+) -> Callable:
+    if store_to_disk:
+        with NamedTemporaryFile(delete=False) as f:
+            filename = f.name
+            f.write(src.encode())
+    else:
+        filename = "<string>"
+
+    if global_env is None:
+        global_env = {}
+
+    if local_env is None:
+        local_env = {}
+
+    mod_code = compile(src, filename, mode="exec")
+    exec(mod_code, global_env, local_env)
+
+    res = local_env[function_name]
+    res.__source__ = src  # type: ignore
+
+    if cache_key is not None:
+        CACHED_SRC_FUNCTIONS[res] = cache_key
+    return res
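Note: to make the intended usage of cache.py concrete, here is a hedged sketch of how an Op implementation might opt into the default cache key. `MyAddOp` and its kernel are invented for illustration, and the decorator call pattern is inferred from the code above rather than taken from the commit:

    from pytensor.graph.op import Op
    from pytensor.link.numba.cache import (
        numba_njit_and_cache,
        register_funcify_default_op_cache_key,
    )
    from pytensor.link.numba.compile import numba_njit


    class MyAddOp(Op):
        """Hypothetical Op with no props; a real Op would define make_node/perform."""

        __props__ = ()


    @register_funcify_default_op_cache_key(MyAddOp)
    def numba_funcify_MyAddOp(op, node=None, **kwargs):
        @numba_njit
        def add(x, y):
            return x + y

        # The second element is a "salt" mixed into the default cache key;
        # bump it when the generated code changes without the Op's props changing.
        return add, "v1"

At linker time, `numba_njit_and_cache(op, node)` would then return a jitted wrapper whose on-disk cache key is derived from `type(op)`, its props, and the salt.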
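The "closure becomes a global variable" comment in `numba_njit_and_cache` is why `compile_and_cache_numba_function_src` generates source and execs it: a function built this way has no closure cells, which is what numba's file-based cache needs. Below is a minimal, stdlib-only sketch of that exec-into-a-namespace pattern; the names are illustrative, not from the codebase:

    from collections.abc import Callable
    from typing import Any


    def compile_function_src(src: str, function_name: str, global_env: dict[str, Any]) -> Callable:
        # Exec the generated source; the wrapped callable lives in the
        # new function's globals rather than in a closure cell.
        local_env: dict[str, Any] = {}
        code = compile(src, "<generated>", mode="exec")
        exec(code, global_env, local_env)
        return local_env[function_name]


    def jitable_func(x, y):  # stands in for the funcified Op kernel
        return x + y


    wrapper = compile_function_src(
        src="def Add(*args): return jitable_func(*args)",
        function_name="Add",
        global_env=globals() | {"jitable_func": jitable_func},
    )

    assert wrapper(1, 2) == 3
    assert wrapper.__closure__ is None  # no closure cells, so the function is cache-friendly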

pytensor/link/numba/compile.py
Lines changed: 13 additions & 33 deletions
@@ -1,7 +1,6 @@
 import warnings
 from collections.abc import Callable
-from tempfile import NamedTemporaryFile
-from typing import Any
+from functools import singledispatch
 
 import numba
 import numpy as np
@@ -11,8 +10,6 @@
 
 from pytensor import config
 from pytensor.graph import Apply, FunctionGraph, Type
-from pytensor.link.numba.cache import CACHED_SRC_FUNCTIONS
-from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
 from pytensor.scalar import ScalarType
 from pytensor.sparse import SparseTensorType
 from pytensor.tensor import TensorType
@@ -59,36 +56,17 @@ def numba_njit(*args, fastmath=None, final_function: bool = False, **kwargs):
         return func(*args, fastmath=fastmath, **kwargs)
 
 
-def compile_and_cache_numba_function_src(
-    src: str,
-    function_name: str,
-    global_env: dict[Any, Any] | None = None,
-    local_env: dict[Any, Any] | None = None,
-    store_to_disk: bool = False,
-    cache_key: str | None = None,
-) -> Callable:
-    if store_to_disk:
-        with NamedTemporaryFile(delete=False) as f:
-            filename = f.name
-            f.write(src.encode())
-    else:
-        filename = "<string>"
-
-    if global_env is None:
-        global_env = {}
-
-    if local_env is None:
-        local_env = {}
-
-    mod_code = compile(src, filename, mode="exec")
-    exec(mod_code, global_env, local_env)
+@singledispatch
+def numba_funcify(
+    typ, node=None, storage_map=None, **kwargs
+) -> Callable | tuple[Callable, str | int | None]:
+    """Generate a numba function for a given Op and Apply node (or FunctionGraph).
 
-    res = local_env[function_name]
-    res.__source__ = src  # type: ignore
-
-    if cache_key is not None:
-        CACHED_SRC_FUNCTIONS[res] = cache_key
-    return res
+    The resulting function will usually use the `no_cpython_wrapper`
+    argument in numba, so it cannot be called directly from Python,
+    but only from other jit functions.
+    """
+    raise NotImplementedError(f"Numba funcify not implemented for type {typ}")
 
 
 def get_numba_type(
@@ -124,6 +102,8 @@ def get_numba_type(
         numba_dtype = numba.from_dtype(dtype)
         return numba_dtype
     elif isinstance(pytensor_type, SparseTensorType):
+        from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
+
         dtype = pytensor_type.numpy_dtype
         numba_dtype = numba.from_dtype(dtype)
         if pytensor_type.format == "csr":