Skip to content

Commit f523fd4

Browse files
committed
Move shape Ops dispatchers to their own file
1 parent 0609870 commit f523fd4

File tree

5 files changed

+179
-165
lines changed

5 files changed

+179
-165
lines changed

pytensor/link/numba/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pytensor.link.numba.dispatch.random
1010
import pytensor.link.numba.dispatch.scan
1111
import pytensor.link.numba.dispatch.scalar
12+
import pytensor.link.numba.dispatch.shape
1213
import pytensor.link.numba.dispatch.signal
1314
import pytensor.link.numba.dispatch.slinalg
1415
import pytensor.link.numba.dispatch.sparse

pytensor/link/numba/dispatch/basic.py

Lines changed: 0 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import warnings
22
from copy import copy
33
from functools import singledispatch
4-
from textwrap import dedent
54

65
import numba
7-
import numba.np.unsafe.ndarray as numba_ndarray
86
import numpy as np
97
from numba import types
108
from numba.core.errors import NumbaWarning, TypingError
@@ -22,18 +20,15 @@
2220
from pytensor.ifelse import IfElse
2321
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
2422
from pytensor.link.utils import (
25-
compile_function_src,
2623
fgraph_to_python,
2724
)
2825
from pytensor.scalar.basic import ScalarType
2926
from pytensor.sparse import SparseTensorType
3027
from pytensor.tensor.basic import Nonzero
3128
from pytensor.tensor.blas import BatchedDot
3229
from pytensor.tensor.math import Dot
33-
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
3430
from pytensor.tensor.sort import ArgSortOp, SortOp
3531
from pytensor.tensor.type import TensorType
36-
from pytensor.tensor.type_other import NoneConst
3732

3833

3934
def numba_njit(*args, fastmath=None, **kwargs):
@@ -322,26 +317,6 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
322317
return deepcopyop
323318

324319

325-
@numba_funcify.register(Shape)
326-
def numba_funcify_Shape(op, **kwargs):
327-
@numba_njit
328-
def shape(x):
329-
return np.asarray(np.shape(x))
330-
331-
return shape
332-
333-
334-
@numba_funcify.register(Shape_i)
335-
def numba_funcify_Shape_i(op, **kwargs):
336-
i = op.i
337-
338-
@numba_njit
339-
def shape_i(x):
340-
return np.asarray(np.shape(x)[i])
341-
342-
return shape_i
343-
344-
345320
@numba_funcify.register(SortOp)
346321
def numba_funcify_SortOp(op, node, **kwargs):
347322
@numba_njit
@@ -423,54 +398,6 @@ def codegen(context, builder, signature, args):
423398
return sig, codegen
424399

425400

426-
@numba_funcify.register(Reshape)
427-
def numba_funcify_Reshape(op, **kwargs):
428-
ndim = op.ndim
429-
430-
if ndim == 0:
431-
432-
@numba_njit
433-
def reshape(x, shape):
434-
return np.asarray(x.item())
435-
436-
else:
437-
438-
@numba_njit
439-
def reshape(x, shape):
440-
# TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
441-
return np.reshape(
442-
np.ascontiguousarray(np.asarray(x)),
443-
numba_ndarray.to_fixed_tuple(shape, ndim),
444-
)
445-
446-
return reshape
447-
448-
449-
@numba_funcify.register(SpecifyShape)
450-
def numba_funcify_SpecifyShape(op, node, **kwargs):
451-
shape_inputs = node.inputs[1:]
452-
shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]
453-
454-
func_conditions = [
455-
f"assert x.shape[{i}] == {shape_input_names}"
456-
for i, (shape_input, shape_input_names) in enumerate(
457-
zip(shape_inputs, shape_input_names, strict=True)
458-
)
459-
if shape_input is not NoneConst
460-
]
461-
462-
func = dedent(
463-
f"""
464-
def specify_shape(x, {create_arg_string(shape_input_names)}):
465-
{"; ".join(func_conditions)}
466-
return x
467-
"""
468-
)
469-
470-
specify_shape = compile_function_src(func, "specify_shape", globals())
471-
return numba_njit(specify_shape)
472-
473-
474401
def int_to_float_fn(inputs, out_dtype):
475402
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
476403

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from textwrap import dedent
2+
3+
import numpy as np
4+
from numba.np.unsafe import ndarray as numba_ndarray
5+
6+
from pytensor.link.numba.dispatch import numba_funcify
7+
from pytensor.link.numba.dispatch.basic import create_arg_string, numba_njit
8+
from pytensor.link.utils import compile_function_src
9+
from pytensor.tensor import NoneConst
10+
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
11+
12+
13+
@numba_funcify.register(Shape)
14+
def numba_funcify_Shape(op, **kwargs):
15+
@numba_njit
16+
def shape(x):
17+
return np.asarray(np.shape(x))
18+
19+
return shape
20+
21+
22+
@numba_funcify.register(Shape_i)
23+
def numba_funcify_Shape_i(op, **kwargs):
24+
i = op.i
25+
26+
@numba_njit
27+
def shape_i(x):
28+
return np.asarray(np.shape(x)[i])
29+
30+
return shape_i
31+
32+
33+
@numba_funcify.register(SpecifyShape)
34+
def numba_funcify_SpecifyShape(op, node, **kwargs):
35+
shape_inputs = node.inputs[1:]
36+
shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]
37+
38+
func_conditions = [
39+
f"assert x.shape[{i}] == {shape_input_names}"
40+
for i, (shape_input, shape_input_names) in enumerate(
41+
zip(shape_inputs, shape_input_names, strict=True)
42+
)
43+
if shape_input is not NoneConst
44+
]
45+
46+
func = dedent(
47+
f"""
48+
def specify_shape(x, {create_arg_string(shape_input_names)}):
49+
{"; ".join(func_conditions)}
50+
return x
51+
"""
52+
)
53+
54+
specify_shape = compile_function_src(func, "specify_shape", globals())
55+
return numba_njit(specify_shape)
56+
57+
58+
@numba_funcify.register(Reshape)
59+
def numba_funcify_Reshape(op, **kwargs):
60+
ndim = op.ndim
61+
62+
if ndim == 0:
63+
64+
@numba_njit
65+
def reshape(x, shape):
66+
return np.asarray(x.item())
67+
68+
else:
69+
70+
@numba_njit
71+
def reshape(x, shape):
72+
# TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
73+
return np.reshape(
74+
np.ascontiguousarray(np.asarray(x)),
75+
numba_ndarray.to_fixed_tuple(shape, ndim),
76+
)
77+
78+
return reshape

tests/link/numba/test_basic.py

Lines changed: 0 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from pytensor.scalar.basic import ScalarOp, as_scalar
3232
from pytensor.tensor import blas, tensor
3333
from pytensor.tensor.elemwise import Elemwise
34-
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
3534
from pytensor.tensor.sort import ArgSortOp, SortOp
3635

3736

@@ -332,22 +331,6 @@ def test_create_numba_signature(v, expected, force_scalar):
332331
assert res == expected
333332

334333

335-
@pytest.mark.parametrize(
336-
"x, i",
337-
[
338-
(np.zeros((20, 3)), 1),
339-
],
340-
)
341-
def test_Shape(x, i):
342-
g = Shape()(pt.as_tensor_variable(x))
343-
344-
compare_numba_and_py([], [g], [])
345-
346-
g = Shape_i(i)(pt.as_tensor_variable(x))
347-
348-
compare_numba_and_py([], [g], [])
349-
350-
351334
@pytest.mark.parametrize(
352335
"x",
353336
[
@@ -412,81 +395,6 @@ def test_ArgSort(x, axis, kind, exc):
412395
compare_numba_and_py([], [g], [])
413396

414397

415-
@pytest.mark.parametrize(
416-
"v, shape, ndim",
417-
[
418-
((pt.vector(), np.array([4], dtype=config.floatX)), ((), None), 0),
419-
((pt.vector(), np.arange(4, dtype=config.floatX)), ((2, 2), None), 2),
420-
(
421-
(pt.vector(), np.arange(4, dtype=config.floatX)),
422-
(pt.lvector(), np.array([2, 2], dtype="int64")),
423-
2,
424-
),
425-
],
426-
)
427-
def test_Reshape(v, shape, ndim):
428-
v, v_test_value = v
429-
shape, shape_test_value = shape
430-
431-
g = Reshape(ndim)(v, shape)
432-
inputs = [v] if not isinstance(shape, Variable) else [v, shape]
433-
test_values = (
434-
[v_test_value]
435-
if not isinstance(shape, Variable)
436-
else [v_test_value, shape_test_value]
437-
)
438-
compare_numba_and_py(
439-
inputs,
440-
[g],
441-
test_values,
442-
)
443-
444-
445-
def test_Reshape_scalar():
446-
v = pt.vector()
447-
v_test_value = np.array([1.0], dtype=config.floatX)
448-
g = Reshape(1)(v[0], (1,))
449-
450-
compare_numba_and_py(
451-
[v],
452-
g,
453-
[v_test_value],
454-
)
455-
456-
457-
@pytest.mark.parametrize(
458-
"v, shape, fails",
459-
[
460-
(
461-
(pt.matrix(), np.array([[1.0]], dtype=config.floatX)),
462-
(1, 1),
463-
False,
464-
),
465-
(
466-
(pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
467-
(1, 1),
468-
True,
469-
),
470-
(
471-
(pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
472-
(1, None),
473-
False,
474-
),
475-
],
476-
)
477-
def test_SpecifyShape(v, shape, fails):
478-
v, v_test_value = v
479-
g = SpecifyShape()(v, *shape)
480-
cm = contextlib.suppress() if not fails else pytest.raises(AssertionError)
481-
482-
with cm:
483-
compare_numba_and_py(
484-
[v],
485-
[g],
486-
[v_test_value],
487-
)
488-
489-
490398
def test_ViewOp():
491399
v = pt.vector()
492400
v_test_value = np.arange(4, dtype=config.floatX)

0 commit comments

Comments
 (0)