Skip to content

Commit 56b36ed

Browse files
committed
Cache keys for numba Op dispatches
1 parent 32a2258 commit 56b36ed

File tree

16 files changed: +482 additions, -297 deletions
Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
1-
import sys
1+
from hashlib import sha256
22
from typing import cast
33

44
from numba.core.extending import overload
55
from numba.np.unsafe.ndarray import to_fixed_tuple
66

7+
from pytensor.link.numba.cache import compile_numba_function_src
78
from pytensor.link.numba.dispatch import basic as numba_basic
8-
from pytensor.link.numba.dispatch.basic import numba_funcify
9+
from pytensor.link.numba.dispatch.basic import (
10+
numba_funcify_and_cache_key,
11+
register_funcify_and_cache_key,
12+
)
913
from pytensor.link.numba.dispatch.vectorize_codegen import (
1014
_jit_options,
1115
_vectorized,
1216
encode_literals,
1317
store_core_outputs,
1418
)
15-
from pytensor.link.utils import compile_function_src
1619
from pytensor.tensor import TensorVariable, get_vector_length
1720
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
1821

1922

20-
@numba_funcify.register(BlockwiseWithCoreShape)
23+
@register_funcify_and_cache_key(BlockwiseWithCoreShape)
2124
def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
2225
[blockwise_node] = op.fgraph.apply_nodes
2326
blockwise_op: Blockwise = blockwise_node.op
@@ -30,7 +33,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
3033
cast(tuple[TensorVariable], node.inputs[:nin]),
3134
propagate_unbatched_core_inputs=True,
3235
)
33-
core_op_fn = numba_funcify(
36+
core_op_fn, core_op_key = numba_funcify_and_cache_key(
3437
core_op,
3538
node=core_node,
3639
parent_node=node,
@@ -58,36 +61,56 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
5861
src += ")"
5962

6063
to_tuple = numba_basic.numba_njit(
61-
compile_function_src(
64+
compile_numba_function_src(
6265
src,
6366
"to_tuple",
6467
global_env={"to_fixed_tuple": to_fixed_tuple},
65-
),
66-
# cache=True leads to a numba.cloudpickle dump failure in Python 3.10
67-
# May be fine in Python 3.11, but I didn't test. It was fine in 3.12
68-
cache=sys.version_info >= (3, 12),
69-
)
70-
71-
def blockwise_wrapper(*inputs_and_core_shapes):
72-
inputs, core_shapes = inputs_and_core_shapes[:nin], inputs_and_core_shapes[nin:]
73-
tuple_core_shapes = to_tuple(core_shapes)
74-
return _vectorized(
75-
core_op_fn,
76-
input_bc_patterns,
77-
output_bc_patterns,
78-
output_dtypes,
79-
inplace_pattern,
80-
(), # constant_inputs
81-
inputs,
82-
tuple_core_shapes,
83-
None, # size
8468
)
69+
)
8570

8671
def blockwise(*inputs_and_core_shapes):
87-
raise NotImplementedError("Non-jitted BlockwiseWithCoreShape not implemented")
72+
raise NotImplementedError(
73+
"Numba implementation of Blockwise cannot be evaluated in Python (non-JIT) mode."
74+
)
8875

8976
@overload(blockwise, jit_options=_jit_options)
9077
def ov_blockwise(*inputs_and_core_shapes):
91-
return blockwise_wrapper
78+
def impl(*inputs_and_core_shapes):
79+
inputs, core_shapes = (
80+
inputs_and_core_shapes[:nin],
81+
inputs_and_core_shapes[nin:],
82+
)
83+
tuple_core_shapes = to_tuple(core_shapes)
84+
return _vectorized(
85+
core_op_fn,
86+
input_bc_patterns,
87+
output_bc_patterns,
88+
output_dtypes,
89+
inplace_pattern,
90+
(), # constant_inputs
91+
inputs,
92+
tuple_core_shapes,
93+
None, # size
94+
)
95+
96+
return impl
9297

93-
return blockwise
98+
if core_op_key is None:
99+
# We were told the core op cannot be cached
100+
blockwise_key = None
101+
else:
102+
blockwise_key = "_".join(
103+
map(
104+
str,
105+
(
106+
type(op),
107+
type(blockwise_op),
108+
tuple(blockwise_op.destroy_map.items()),
109+
blockwise_op.signature,
110+
input_bc_patterns,
111+
core_op_key,
112+
),
113+
)
114+
)
115+
blockwise_key = sha256(blockwise_key.encode()).hexdigest()
116+
return blockwise, blockwise_key

pytensor/link/numba/dispatch/compile_ops.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from hashlib import sha256
2+
13
import numpy as np
24

35
from pytensor.compile.builders import OpFromGraph
@@ -8,14 +10,15 @@
810
from pytensor.ifelse import IfElse
911
from pytensor.link.numba.dispatch import basic as numba_basic
1012
from pytensor.link.numba.dispatch.basic import (
11-
numba_funcify,
12-
numba_njit,
13+
numba_funcify_and_cache_key,
14+
register_funcify_and_cache_key,
15+
register_funcify_default_op_cache_key,
1316
)
1417
from pytensor.raise_op import CheckAndRaise
1518
from pytensor.tensor.type import TensorType
1619

1720

18-
@numba_funcify.register(OpFromGraph)
21+
@register_funcify_and_cache_key(OpFromGraph)
1922
def numba_funcify_OpFromGraph(op, node=None, **kwargs):
2023
_ = kwargs.pop("storage_map", None)
2124

@@ -30,10 +33,27 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
3033
accept_inplace=True,
3134
)
3235
NUMBA.optimizer(fgraph)
33-
return numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
36+
fgraph_fn, fgraph_cache_key = numba_funcify_and_cache_key(
37+
op.fgraph, squeeze_output=True, **kwargs
38+
)
39+
40+
if fgraph_cache_key is None:
41+
# Can't cache the inner graph
42+
ofg_cache_key = None
43+
else:
44+
ofg_cache_key = sha256(
45+
str(
46+
(
47+
type(op),
48+
fgraph_cache_key,
49+
)
50+
).encode()
51+
).hexdigest()
52+
53+
return fgraph_fn, ofg_cache_key
3454

3555

36-
@numba_funcify.register(TypeCastingOp)
56+
@register_funcify_default_op_cache_key(TypeCastingOp)
3757
def numba_funcify_type_casting(op, **kwargs):
3858
@numba_basic.numba_njit
3959
def identity(x):
@@ -42,7 +62,7 @@ def identity(x):
4262
return identity
4363

4464

45-
@numba_funcify.register(DeepCopyOp)
65+
@register_funcify_default_op_cache_key(DeepCopyOp)
4666
def numba_funcify_DeepCopyOp(op, node, **kwargs):
4767
if isinstance(node.inputs[0].type, TensorType):
4868

@@ -59,7 +79,7 @@ def deepcopy(x):
5979
return deepcopy
6080

6181

62-
@numba_funcify.register(IfElse)
82+
@register_funcify_default_op_cache_key(IfElse)
6383
def numba_funcify_IfElse(op, **kwargs):
6484
n_outs = op.n_outs
6585

@@ -88,7 +108,7 @@ def ifelse(cond, *args):
88108
return ifelse
89109

90110

91-
@numba_funcify.register(CheckAndRaise)
111+
@register_funcify_and_cache_key(CheckAndRaise)
92112
def numba_funcify_CheckAndRaise(op, node, **kwargs):
93113
error = op.exc_type
94114
msg = op.msg
@@ -100,4 +120,5 @@ def check_and_raise(x, *conditions):
100120
raise error(msg)
101121
return x
102122

103-
return check_and_raise
123+
cache_key = sha256(str((type(op), error, msg)).encode()).hexdigest()
124+
return check_and_raise, cache_key

Commit comments: 0