33import warnings
44from copy import copy
55from functools import singledispatch
6+ from hashlib import sha256
67from textwrap import dedent
78
89import numba
1112import scipy
1213import scipy .special
1314from llvmlite import ir
15+ from numba import njit as _njit
1416from numba import types
1517from numba .core .errors import NumbaWarning , TypingError
1618from numba .cpython .unsafe .tuple import tuple_setitem # noqa: F401
17- from numba .extending import box , overload
18- from numba .extending import register_jitable as _register_jitable
19+ from numba .extending import box , overload , register_jitable
1920
2021from pytensor import In , config
2122from pytensor .compile import NUMBA
2627from pytensor .graph .fg import FunctionGraph
2728from pytensor .graph .type import Type
2829from pytensor .ifelse import IfElse
29- from pytensor .link .numba .cache import compile_and_cache_numba_function_src
30+ from pytensor .link .numba .cache import (
31+ compile_and_cache_numba_function_src ,
32+ )
3033from pytensor .link .numba .dispatch .sparse import CSCMatrixType , CSRMatrixType
3134from pytensor .link .utils import fgraph_to_python
3235from pytensor .scalar .basic import ScalarType
@@ -50,11 +53,7 @@ def global_numba_func(func):
5053 return func
5154
5255
53- def numba_njit (* args , fastmath = None , register_jitable : bool = True , ** kwargs ):
54- kwargs .setdefault ("cache" , True )
55- kwargs .setdefault ("no_cpython_wrapper" , False )
56- kwargs .setdefault ("no_cfunc_wrapper" , False )
57-
56+ def numba_njit (* args , fastmath = None , final_function : bool = False , ** kwargs ):
5857 if fastmath is None :
5958 if config .numba__fastmath :
6059 # Opinionated default on fastmath flags
@@ -69,6 +68,12 @@ def numba_njit(*args, fastmath=None, register_jitable: bool = True, **kwargs):
6968 else :
7069 fastmath = False
7170
71+ if final_function :
72+ kwargs .setdefault ("cache" , True )
73+ # else:
74+ # kwargs.setdefault("no_cpython_wrapper", True)
75+ # kwargs.setdefault("no_cfunc_wrapper", True)
76+
7277 # Suppress cache warning for internal functions
7378 # We have to add an ansi escape code for optional bold text by numba
7479 warnings .filterwarnings (
@@ -82,7 +87,7 @@ def numba_njit(*args, fastmath=None, register_jitable: bool = True, **kwargs):
8287 category = NumbaWarning ,
8388 )
8489
85- func = _register_jitable if register_jitable else numba . njit
90+ func = register_jitable if final_function else _njit
8691 if len (args ) > 0 and callable (args [0 ]):
8792 return func (* args [1 :], fastmath = fastmath , ** kwargs )(args [0 ])
8893 else :
@@ -384,8 +389,43 @@ def numba_funcify_FunctionGraph(
384389 ** kwargs ,
385390):
386391 def numba_funcify_njit (op , node , ** kwargs ):
387- jitable_func = numba_funcify (op , node = node , ** kwargs )
388- return numba_njit (lambda * args : jitable_func (* args ), register_jitable = False )
392+ jitable_func_and_key = numba_funcify (op , node = node , ** kwargs )
393+ from collections .abc import Callable
394+
395+ match jitable_func_and_key :
396+ case (Callable (), str ()):
397+ jitable_func , key = jitable_func_and_key
398+ case (Callable (), int ()):
399+ # Default key for Ops that return an integer
400+ jitable_func , int_key = jitable_func_and_key
401+ key = sha256 (
402+ str ((type (op ), op ._props_dict (), int_key )).encode ()
403+ ).hexdigest ()
404+ case Callable ():
405+ jitable_func , key = jitable_func_and_key , None
406+ warnings .warn (
407+ f"No cache key returned by numba_funcify of op { op } . This function won't be cached by Numba"
408+ )
409+ case _:
410+ raise TypeError (
411+ f"numpy_funcify should return a callable or a callable, key pair, got { jitable_func_and_key } "
412+ )
413+
414+ if 0 and key is not None :
415+ # To force numba to use our cache, we must compile the function so that any closure
416+ # becomes a global variable...
417+ op_name = op .__class__ .__name__
418+ cached_func = compile_and_cache_numba_function_src (
419+ src = f"def { op_name } (*args): return jitable_func(*args)" ,
420+ function_name = op_name ,
421+ global_env = globals () | dict (jitable_func = jitable_func ),
422+ key = key ,
423+ )
424+ return numba_njit (cached_func , final_function = True , cache = True )
425+ else :
426+ return numba_njit (
427+ lambda * args : jitable_func (* args ), final_function = True , cache = False
428+ )
389429
390430 return fgraph_to_python (
391431 fgraph ,
@@ -410,7 +450,7 @@ def dispatch_deepcopyop(x):
410450
411451@numba_funcify .register (DeepCopyOp )
412452def numba_funcify_DeepCopyOp (op , node , ** kwargs ):
413- return deepcopyop
453+ return deepcopyop , 0
414454
415455
416456@numba_funcify .register (MakeSlice )
@@ -439,7 +479,7 @@ def numba_funcify_Shape_i(op, **kwargs):
439479 def shape_i (x ):
440480 return np .asarray (np .shape (x )[i ])
441481
442- return shape_i
482+ return shape_i , 0
443483
444484
445485@numba_funcify .register (SortOp )
@@ -543,7 +583,7 @@ def reshape(x, shape):
543583 numba_ndarray .to_fixed_tuple (shape , ndim ),
544584 )
545585
546- return reshape
586+ return reshape , 0
547587
548588
549589@numba_funcify .register (SpecifyShape )
@@ -571,9 +611,8 @@ def specify_shape(x, {create_arg_string(shape_input_names)}):
571611 func ,
572612 "specify_shape" ,
573613 globals (),
574- key = hash_from_code (func ),
575614 )
576- return numba_njit (specify_shape )
615+ return numba_njit (specify_shape ), hash_from_code ( func )
577616
578617
579618def int_to_float_fn (inputs , out_dtype ):
0 commit comments