import weakref
from collections.abc import Callable
from functools import singledispatch, wraps
from hashlib import sha256
from pathlib import Path
from pickle import dumps
from tempfile import NamedTemporaryFile
from typing import Any

from numba.core.caching import CacheImpl, _CacheLocator

from pytensor import config
from pytensor.graph.basic import Apply
from pytensor.link.numba.compile import numba_funcify, numba_njit


NUMBA_PYTENSOR_CACHE_ENABLED = True

@@ -74,3 +80,158 @@ def cache_node_key(node: Apply, extra_key="") -> str:
            ),
        ).encode()
    ).hexdigest()


@singledispatch
def numba_funcify_default_op_cache_key(
    op, node=None, **kwargs
) -> Callable | tuple[Callable, Any]:
    """Funcify an Op and provide a default cache key.

    The default cache key is based on the Op class and its properties;
    it does not take into account the node inputs or any other context.
    Note that numba uses the array dtypes, rank, and layout as part of its own cache key,
    but not the static shapes or constant values.
    If a funcify implementation exploits that information, this method should not be used.
    Instead, dispatch directly on `numba_funcify_and_cache_key`, or on plain `numba_funcify`,
    which does not use any cache key.
    """
    # No default implementation: `numba_funcify_and_cache_key` catches this error and
    # falls back to `numba_funcify` with a key of None (the function is not directly cached).
    raise NotImplementedError()


def register_funcify_default_op_cache_key(op_type):
    """Register a funcify implementation for both cache and non-cache versions."""

    def decorator(dispatch_func):
        # Register with the cache key dispatcher
        numba_funcify_default_op_cache_key.register(op_type)(dispatch_func)

        # Create a wrapper for the non-cache dispatcher
        @wraps(dispatch_func)
        def dispatch_func_wrapper(*args, **kwargs):
            func_and_key = dispatch_func(*args, **kwargs)
            # Discard the key for the non-cache version
            # (implementations may return either `func` or `(func, salt)`)
            if isinstance(func_and_key, tuple):
                return func_and_key[0]
            return func_and_key

        # Register the wrapper with the non-cache dispatcher
        numba_funcify.register(op_type)(dispatch_func_wrapper)

        return dispatch_func

    return decorator
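

# Illustrative sketch (hypothetical names, not part of the original module): an Op whose
# numba implementation depends only on its props can be registered through
# `register_funcify_default_op_cache_key`; returning just the callable uses the default
# salt "0" and lets `numba_funcify_and_cache_key` (below) derive the key from the Op's props.
#
# @register_funcify_default_op_cache_key(MyScalarShift)
# def numba_funcify_MyScalarShift(op, node=None, **kwargs):
#     offset = op.offset
#
#     @numba_njit
#     def scalar_shift(x):
#         return x + offset
#
#     return scalar_shift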


@singledispatch
def numba_funcify_and_cache_key(op, node=None, **kwargs) -> tuple[Callable, str | None]:
    # A cache key of None means "don't try to directly cache this function"
    if hasattr(op, "_props"):
        try:
            func_and_salt = numba_funcify_default_op_cache_key(op, node=node, **kwargs)
        except NotImplementedError:
            pass
        else:
            if isinstance(func_and_salt, tuple):
                func, salt = func_and_salt
            else:
                func, salt = func_and_salt, "0"
            props_dict = op._props_dict()
            if not props_dict:
                # Simple op without props, just use the type string as key
                key_bytes = str((type(op), salt)).encode()
            else:
                simple_types = (str, bool, int, type(None), float)
                container_types = (tuple, frozenset)
                if all(
                    isinstance(v, simple_types)
                    or (
                        isinstance(v, container_types)
                        and all(isinstance(i, simple_types) for i in v)
                    )
                    for v in props_dict.values()
                ):
                    # Simple props, use their string representation as key
                    key_bytes = str(
                        (type(op), tuple(props_dict.items()), salt)
                    ).encode()
                else:
                    # Complex props, use pickle to serialize them
                    key_bytes = dumps((str(type(op)), tuple(props_dict.items()), salt))
            return func, sha256(key_bytes).hexdigest()

    # Fallback: funcify without a cache key
    return numba_funcify(op, node=node, **kwargs), None
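

# Worked example (illustrative, with a made-up op): for `MyScalarShift(offset=2)`, whose
# only prop is the integer `offset`, all props are "simple", so the key bytes are
# `str((type(op), (("offset", 2),), "0")).encode()` and the cache key is their sha256
# hexdigest. An op holding, say, a numpy array in its props would instead go through
# `pickle.dumps`.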


def register_funcify_and_cache_key(op_type):
    """Register a funcify implementation for both cache and non-cache versions."""

    def decorator(dispatch_func):
        # Register with the cache key dispatcher
        numba_funcify_and_cache_key.register(op_type)(dispatch_func)

        # Create a wrapper for the non-cache dispatcher
        @wraps(dispatch_func)
        def dispatch_func_wrapper(*args, **kwargs):
            func, key = dispatch_func(*args, **kwargs)
            # Discard the key for the non-cache version
            return func

        # Register the wrapper with the non-cache dispatcher
        numba_funcify.register(op_type)(dispatch_func_wrapper)

        return dispatch_func_wrapper

    return decorator
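

# Illustrative sketch (hypothetical names, not from the original module): an implementation
# that bakes a constant input's value into the compiled code should provide its own key
# (or None to opt out of caching), because numba's key does not include constant values.
# `MyConstPad` is a made-up Op; the second input is assumed to be a constant.
#
# @register_funcify_and_cache_key(MyConstPad)
# def numba_funcify_MyConstPad(op, node=None, **kwargs):
#     pad_value = node.inputs[1].data
#
#     @numba_njit
#     def const_pad(x, pad):
#         # `pad` is ignored at runtime; its constant value is baked into the closure
#         return x + pad_value
#
#     return const_pad, cache_node_key(node, extra_key=str(pad_value))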


def numba_njit_and_cache(op, node, **kwargs):
    jitable_func, key = numba_funcify_and_cache_key(op, node=node, **kwargs)

    if key is not None:
        # To make numba use our cache key, compile the wrapper function from source,
        # so that any closed-over variables become globals of the generated module.
        op_name = op.__class__.__name__
        cached_func = compile_and_cache_numba_function_src(
            src=f"def {op_name}(*args): return jitable_func(*args)",
            function_name=op_name,
            global_env=globals() | dict(jitable_func=jitable_func),
            cache_key=key,
        )
        return numba_njit(cached_func, final_function=True, cache=True)
    else:
        return numba_njit(
            lambda *args: jitable_func(*args), final_function=True, cache=False
        )
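

# Usage sketch (illustrative, not part of the original module): for an Apply node whose op
# has a registered cache key, this yields a disk-cacheable dispatch function. `node` and
# `input_values` are assumed to exist.
#
# compiled = numba_njit_and_cache(node.op, node)
# result = compiled(*input_values)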


def compile_and_cache_numba_function_src(
    src: str,
    function_name: str,
    global_env: dict[Any, Any] | None = None,
    local_env: dict[Any, Any] | None = None,
    store_to_disk: bool = False,
    cache_key: str | None = None,
) -> Callable:
    if store_to_disk:
        with NamedTemporaryFile(delete=False) as f:
            filename = f.name
            f.write(src.encode())
    else:
        filename = "<string>"

    if global_env is None:
        global_env = {}

    if local_env is None:
        local_env = {}

    mod_code = compile(src, filename, mode="exec")
    exec(mod_code, global_env, local_env)

    res = local_env[function_name]
    res.__source__ = src  # type: ignore

    if cache_key is not None:
        CACHED_SRC_FUNCTIONS[res] = cache_key
    return res
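

# Usage sketch (illustrative): compile a trivial function from source. The returned function
# carries its source on `__source__`, and because a cache key is given it is recorded in
# `CACHED_SRC_FUNCTIONS` (defined elsewhere in this module), presumably so the custom numba
# cache locator can recover the key later. The key below is a made-up placeholder.
#
# double = compile_and_cache_numba_function_src(
#     src="def double(x): return 2 * x",
#     function_name="double",
#     cache_key="0" * 64,
# )
# assert double(21) == 42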