
Commit f8d4ea6

Move compile ops to their own dispatch file
1 parent 02caad5 commit f8d4ea6

File tree: 7 files changed, +274 -260 lines


pytensor/link/numba/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 
 # Load dispatch specializations
 import pytensor.link.numba.dispatch.blockwise
+import pytensor.link.numba.dispatch.compile_ops
 import pytensor.link.numba.dispatch.elemwise
 import pytensor.link.numba.dispatch.extra_ops
 import pytensor.link.numba.dispatch.nlinalg
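Note (not part of the commit): the new line is a side-effect import. numba_funcify is a dispatch registry, so merely importing compile_ops runs its @numba_funcify.register(...) decorators and makes the backend aware of the moved Ops. A minimal, self-contained sketch of that pattern, using stand-in names rather than PyTensor's actual objects:

# Sketch of the import-for-registration pattern (stand-in names only).
from functools import singledispatch


@singledispatch
def funcify(op, node=None, **kwargs):
    raise NotImplementedError(f"No Numba implementation for {op}")


class DeepCopyOp:  # stand-in for pytensor.compile.ops.DeepCopyOp
    pass


# This is what importing compile_ops effectively does for each Op it handles:
# the registration happens as a side effect at import time.
@funcify.register(DeepCopyOp)
def funcify_deepcopy(op, node=None, **kwargs):
    return lambda x: x


assert funcify(DeepCopyOp())(42) == 42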

pytensor/link/numba/dispatch/basic.py

Lines changed: 1 addition & 93 deletions
@@ -6,15 +6,10 @@
 from numba.core.errors import NumbaWarning
 from numba.cpython.unsafe.tuple import tuple_setitem  # noqa: F401
 
-from pytensor import In, config
-from pytensor.compile import NUMBA
-from pytensor.compile.builders import OpFromGraph
-from pytensor.compile.function.types import add_supervisor_to_fgraph
-from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
+from pytensor import config
 from pytensor.graph.basic import Apply
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.type import Type
-from pytensor.ifelse import IfElse
 from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
 from pytensor.link.utils import (
     fgraph_to_python,
@@ -280,90 +275,3 @@ def numba_funcify_FunctionGraph(
         fgraph_name=fgraph_name,
         **kwargs,
     )
-
-
-@numba_funcify.register(OpFromGraph)
-def numba_funcify_OpFromGraph(op, node=None, **kwargs):
-    _ = kwargs.pop("storage_map", None)
-
-    # Apply inner rewrites
-    # TODO: Not sure this is the right place to do this, should we have a rewrite that
-    # explicitly triggers the optimization of the inner graphs of OpFromGraph?
-    # The C-code defers it to the make_thunk phase
-    fgraph = op.fgraph
-    add_supervisor_to_fgraph(
-        fgraph=fgraph,
-        input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
-        accept_inplace=True,
-    )
-    NUMBA.optimizer(fgraph)
-    fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
-
-    if len(op.fgraph.outputs) == 1:
-
-        @numba_njit
-        def opfromgraph(*inputs):
-            return fgraph_fn(*inputs)[0]
-
-    else:
-
-        @numba_njit
-        def opfromgraph(*inputs):
-            return fgraph_fn(*inputs)
-
-    return opfromgraph
-
-
-@numba_funcify.register(TypeCastingOp)
-def numba_funcify_type_casting(op, **kwargs):
-    @numba_njit
-    def identity(x):
-        return x
-
-    return identity
-
-
-@numba_funcify.register(DeepCopyOp)
-def numba_funcify_DeepCopyOp(op, node, **kwargs):
-    if isinstance(node.inputs[0].type, TensorType):
-
-        @numba_njit
-        def deepcopy(x):
-            return np.copy(x)
-
-    else:
-
-        @numba_njit
-        def deepcopy(x):
-            return x
-
-    return deepcopy
-
-
-@numba_funcify.register(IfElse)
-def numba_funcify_IfElse(op, **kwargs):
-    n_outs = op.n_outs
-
-    if n_outs > 1:
-
-        @numba_njit
-        def ifelse(cond, *args):
-            if cond:
-                res = args[:n_outs]
-            else:
-                res = args[n_outs:]
-
-            return res
-
-    else:
-
-        @numba_njit
-        def ifelse(cond, *args):
-            if cond:
-                res = args[:n_outs]
-            else:
-                res = args[n_outs:]
-
-            return res[0]
-
-    return ifelse
pytensor/link/numba/dispatch/compile_ops.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
+import numpy as np
+
+from pytensor.compile.builders import OpFromGraph
+from pytensor.compile.function.types import add_supervisor_to_fgraph
+from pytensor.compile.io import In
+from pytensor.compile.mode import NUMBA
+from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
+from pytensor.ifelse import IfElse
+from pytensor.link.numba.dispatch import basic as numba_basic
+from pytensor.link.numba.dispatch.basic import (
+    numba_funcify,
+    numba_njit,
+)
+from pytensor.raise_op import CheckAndRaise
+from pytensor.tensor.type import TensorType
+
+
+@numba_funcify.register(OpFromGraph)
+def numba_funcify_OpFromGraph(op, node=None, **kwargs):
+    _ = kwargs.pop("storage_map", None)
+
+    # Apply inner rewrites
+    # TODO: Not sure this is the right place to do this, should we have a rewrite that
+    # explicitly triggers the optimization of the inner graphs of OpFromGraph?
+    # The C-code defers it to the make_thunk phase
+    fgraph = op.fgraph
+    add_supervisor_to_fgraph(
+        fgraph=fgraph,
+        input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
+        accept_inplace=True,
+    )
+    NUMBA.optimizer(fgraph)
+    fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
+
+    if len(op.fgraph.outputs) == 1:
+
+        @numba_basic.numba_njit
+        def opfromgraph(*inputs):
+            return fgraph_fn(*inputs)[0]
+
+    else:
+
+        @numba_basic.numba_njit
+        def opfromgraph(*inputs):
+            return fgraph_fn(*inputs)
+
+    return opfromgraph
+
+
+@numba_funcify.register(TypeCastingOp)
+def numba_funcify_type_casting(op, **kwargs):
+    @numba_basic.numba_njit
+    def identity(x):
+        return x
+
+    return identity
+
+
+@numba_funcify.register(DeepCopyOp)
+def numba_funcify_DeepCopyOp(op, node, **kwargs):
+    if isinstance(node.inputs[0].type, TensorType):
+
+        @numba_basic.numba_njit
+        def deepcopy(x):
+            return np.copy(x)
+
+    else:
+
+        @numba_basic.numba_njit
+        def deepcopy(x):
+            return x
+
+    return deepcopy
+
+
+@numba_funcify.register(IfElse)
+def numba_funcify_IfElse(op, **kwargs):
+    n_outs = op.n_outs
+
+    if n_outs > 1:
+
+        @numba_basic.numba_njit
+        def ifelse(cond, *args):
+            if cond:
+                res = args[:n_outs]
+            else:
+                res = args[n_outs:]
+
+            return res
+
+    else:
+
+        @numba_basic.numba_njit
+        def ifelse(cond, *args):
+            if cond:
+                res = args[:n_outs]
+            else:
+                res = args[n_outs:]
+
+            return res[0]
+
+    return ifelse
+
+
+@numba_funcify.register(CheckAndRaise)
+def numba_funcify_CheckAndRaise(op, node, **kwargs):
+    error = op.exc_type
+    msg = op.msg
+
+    @numba_basic.numba_njit
+    def check_and_raise(x, *conditions):
+        for cond in conditions:
+            if not cond:
+                raise error(msg)
+        return x
+
+    return check_and_raise
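Note (not part of the commit): a quick, hedged smoke test for the moved registrations. It assumes a working Numba install and that the NUMBA compile mode is available, and exercises the IfElse path that now lives in this file:

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.ifelse import ifelse

# IfElse dispatch is now registered by compile_ops.py, imported via dispatch/__init__.py.
x = pt.dscalar("x")
out = ifelse(x > 0, 2 * x, x - 1)
fn = pytensor.function([x], out, mode="NUMBA")

assert np.isclose(fn(3.0), 6.0)
assert np.isclose(fn(-1.0), -2.0)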

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 0 additions & 16 deletions
@@ -11,7 +11,6 @@
     get_numba_type,
     numba_funcify,
 )
-from pytensor.raise_op import CheckAndRaise
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.extra_ops import (
     Bartlett,
@@ -325,18 +324,3 @@ def searchsorted(a, v):
         return np.searchsorted(a, v, side)
 
     return searchsorted
-
-
-@numba_funcify.register(CheckAndRaise)
-def numba_funcify_CheckAndRaise(op, node, **kwargs):
-    error = op.exc_type
-    msg = op.msg
-
-    @numba_basic.numba_njit
-    def check_and_raise(x, *conditions):
-        for cond in conditions:
-            if not cond:
-                raise error(msg)
-        return x
-
-    return check_and_raise
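Note (not part of the commit): the CheckAndRaise implementation only changed files, from extra_ops.py to compile_ops.py. A hedged way to confirm it is still picked up, assuming a working Numba backend; assert_op builds an Assert node, and Assert subclasses CheckAndRaise:

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.raise_op import assert_op  # Assert subclasses CheckAndRaise

x = pt.dvector("x")
checked = assert_op(x, pt.all(x > 0))  # raises at run time if any element is <= 0
fn = pytensor.function([x], 2 * checked, mode="NUMBA")

assert np.allclose(fn(np.array([1.0, 2.0])), [2.0, 4.0])
# fn(np.array([-1.0, 2.0])) would raise an AssertionError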
