Skip to content

Commit 55a1ac8

Browse files
committed
Move shape Ops dispatchers to their own file
1 parent 7ae4d33 commit 55a1ac8

File tree

5 files changed

+179
-165
lines changed

5 files changed

+179
-165
lines changed

pytensor/link/numba/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pytensor.link.numba.dispatch.random
1010
import pytensor.link.numba.dispatch.scan
1111
import pytensor.link.numba.dispatch.scalar
12+
import pytensor.link.numba.dispatch.shape
1213
import pytensor.link.numba.dispatch.signal
1314
import pytensor.link.numba.dispatch.slinalg
1415
import pytensor.link.numba.dispatch.sparse

pytensor/link/numba/dispatch/basic.py

Lines changed: 0 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import warnings
22
from copy import copy
33
from functools import singledispatch
4-
from textwrap import dedent
54

65
import numba
7-
import numba.np.unsafe.ndarray as numba_ndarray
86
import numpy as np
97
from numba import types
108
from numba.core.errors import NumbaWarning, TypingError
@@ -22,18 +20,15 @@
2220
from pytensor.ifelse import IfElse
2321
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
2422
from pytensor.link.utils import (
25-
compile_function_src,
2623
fgraph_to_python,
2724
)
2825
from pytensor.scalar.basic import ScalarType
2926
from pytensor.sparse import SparseTensorType
3027
from pytensor.tensor.basic import Nonzero
3128
from pytensor.tensor.blas import BatchedDot
3229
from pytensor.tensor.math import Dot
33-
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
3430
from pytensor.tensor.sort import ArgSortOp, SortOp
3531
from pytensor.tensor.type import TensorType
36-
from pytensor.tensor.type_other import NoneConst
3732

3833

3934
def numba_njit(*args, fastmath=None, **kwargs):
@@ -322,26 +317,6 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
322317
return deepcopyop
323318

324319

325-
@numba_funcify.register(Shape)
326-
def numba_funcify_Shape(op, **kwargs):
327-
@numba_njit
328-
def shape(x):
329-
return np.asarray(np.shape(x))
330-
331-
return shape
332-
333-
334-
@numba_funcify.register(Shape_i)
335-
def numba_funcify_Shape_i(op, **kwargs):
336-
i = op.i
337-
338-
@numba_njit
339-
def shape_i(x):
340-
return np.asarray(np.shape(x)[i])
341-
342-
return shape_i
343-
344-
345320
@numba_funcify.register(SortOp)
346321
def numba_funcify_SortOp(op, node, **kwargs):
347322
@numba_njit
@@ -423,54 +398,6 @@ def codegen(context, builder, signature, args):
423398
return sig, codegen
424399

425400

426-
@numba_funcify.register(Reshape)
427-
def numba_funcify_Reshape(op, **kwargs):
428-
ndim = op.ndim
429-
430-
if ndim == 0:
431-
432-
@numba_njit
433-
def reshape(x, shape):
434-
return np.asarray(x.item())
435-
436-
else:
437-
438-
@numba_njit
439-
def reshape(x, shape):
440-
# TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
441-
return np.reshape(
442-
np.ascontiguousarray(np.asarray(x)),
443-
numba_ndarray.to_fixed_tuple(shape, ndim),
444-
)
445-
446-
return reshape
447-
448-
449-
@numba_funcify.register(SpecifyShape)
450-
def numba_funcify_SpecifyShape(op, node, **kwargs):
451-
shape_inputs = node.inputs[1:]
452-
shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]
453-
454-
func_conditions = [
455-
f"assert x.shape[{i}] == {shape_input_names}"
456-
for i, (shape_input, shape_input_names) in enumerate(
457-
zip(shape_inputs, shape_input_names, strict=True)
458-
)
459-
if shape_input is not NoneConst
460-
]
461-
462-
func = dedent(
463-
f"""
464-
def specify_shape(x, {create_arg_string(shape_input_names)}):
465-
{"; ".join(func_conditions)}
466-
return x
467-
"""
468-
)
469-
470-
specify_shape = compile_function_src(func, "specify_shape", globals())
471-
return numba_njit(specify_shape)
472-
473-
474401
def int_to_float_fn(inputs, out_dtype):
475402
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
476403

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from textwrap import dedent
2+
3+
import numpy as np
4+
from numba.np.unsafe import ndarray as numba_ndarray
5+
6+
from pytensor.link.numba.dispatch import numba_funcify
7+
from pytensor.link.numba.dispatch.basic import create_arg_string, numba_njit
8+
from pytensor.link.utils import compile_function_src
9+
from pytensor.tensor import NoneConst
10+
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
11+
12+
13+
@numba_funcify.register(Shape)
14+
def numba_funcify_Shape(op, **kwargs):
15+
@numba_njit
16+
def shape(x):
17+
return np.asarray(np.shape(x))
18+
19+
return shape
20+
21+
22+
@numba_funcify.register(Shape_i)
23+
def numba_funcify_Shape_i(op, **kwargs):
24+
i = op.i
25+
26+
@numba_njit
27+
def shape_i(x):
28+
return np.asarray(np.shape(x)[i])
29+
30+
return shape_i
31+
32+
33+
@numba_funcify.register(SpecifyShape)
34+
def numba_funcify_SpecifyShape(op, node, **kwargs):
35+
shape_inputs = node.inputs[1:]
36+
shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]
37+
38+
func_conditions = [
39+
f"assert x.shape[{i}] == {shape_input_names}"
40+
for i, (shape_input, shape_input_names) in enumerate(
41+
zip(shape_inputs, shape_input_names, strict=True)
42+
)
43+
if shape_input is not NoneConst
44+
]
45+
46+
func = dedent(
47+
f"""
48+
def specify_shape(x, {create_arg_string(shape_input_names)}):
49+
{"; ".join(func_conditions)}
50+
return x
51+
"""
52+
)
53+
54+
specify_shape = compile_function_src(func, "specify_shape", globals())
55+
return numba_njit(specify_shape)
56+
57+
58+
@numba_funcify.register(Reshape)
59+
def numba_funcify_Reshape(op, **kwargs):
60+
ndim = op.ndim
61+
62+
if ndim == 0:
63+
64+
@numba_njit
65+
def reshape(x, shape):
66+
return np.asarray(x.item())
67+
68+
else:
69+
70+
@numba_njit
71+
def reshape(x, shape):
72+
# TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
73+
return np.reshape(
74+
np.ascontiguousarray(np.asarray(x)),
75+
numba_ndarray.to_fixed_tuple(shape, ndim),
76+
)
77+
78+
return reshape

tests/link/numba/test_basic.py

Lines changed: 0 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from pytensor.scalar.basic import ScalarOp, as_scalar
3333
from pytensor.tensor import blas, tensor
3434
from pytensor.tensor.elemwise import Elemwise
35-
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
3635
from pytensor.tensor.sort import ArgSortOp, SortOp
3736

3837

@@ -357,22 +356,6 @@ def test_create_numba_signature(v, expected, force_scalar):
357356
assert res == expected
358357

359358

360-
@pytest.mark.parametrize(
361-
"x, i",
362-
[
363-
(np.zeros((20, 3)), 1),
364-
],
365-
)
366-
def test_Shape(x, i):
367-
g = Shape()(pt.as_tensor_variable(x))
368-
369-
compare_numba_and_py([], [g], [])
370-
371-
g = Shape_i(i)(pt.as_tensor_variable(x))
372-
373-
compare_numba_and_py([], [g], [])
374-
375-
376359
@pytest.mark.parametrize(
377360
"x",
378361
[
@@ -437,81 +420,6 @@ def test_ArgSort(x, axis, kind, exc):
437420
compare_numba_and_py([], [g], [])
438421

439422

440-
@pytest.mark.parametrize(
441-
"v, shape, ndim",
442-
[
443-
((pt.vector(), np.array([4], dtype=config.floatX)), ((), None), 0),
444-
((pt.vector(), np.arange(4, dtype=config.floatX)), ((2, 2), None), 2),
445-
(
446-
(pt.vector(), np.arange(4, dtype=config.floatX)),
447-
(pt.lvector(), np.array([2, 2], dtype="int64")),
448-
2,
449-
),
450-
],
451-
)
452-
def test_Reshape(v, shape, ndim):
453-
v, v_test_value = v
454-
shape, shape_test_value = shape
455-
456-
g = Reshape(ndim)(v, shape)
457-
inputs = [v] if not isinstance(shape, Variable) else [v, shape]
458-
test_values = (
459-
[v_test_value]
460-
if not isinstance(shape, Variable)
461-
else [v_test_value, shape_test_value]
462-
)
463-
compare_numba_and_py(
464-
inputs,
465-
[g],
466-
test_values,
467-
)
468-
469-
470-
def test_Reshape_scalar():
471-
v = pt.vector()
472-
v_test_value = np.array([1.0], dtype=config.floatX)
473-
g = Reshape(1)(v[0], (1,))
474-
475-
compare_numba_and_py(
476-
[v],
477-
g,
478-
[v_test_value],
479-
)
480-
481-
482-
@pytest.mark.parametrize(
483-
"v, shape, fails",
484-
[
485-
(
486-
(pt.matrix(), np.array([[1.0]], dtype=config.floatX)),
487-
(1, 1),
488-
False,
489-
),
490-
(
491-
(pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
492-
(1, 1),
493-
True,
494-
),
495-
(
496-
(pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
497-
(1, None),
498-
False,
499-
),
500-
],
501-
)
502-
def test_SpecifyShape(v, shape, fails):
503-
v, v_test_value = v
504-
g = SpecifyShape()(v, *shape)
505-
cm = contextlib.suppress() if not fails else pytest.raises(AssertionError)
506-
507-
with cm:
508-
compare_numba_and_py(
509-
[v],
510-
[g],
511-
[v_test_value],
512-
)
513-
514-
515423
def test_ViewOp():
516424
v = pt.vector()
517425
v_test_value = np.arange(4, dtype=config.floatX)

0 commit comments

Comments
 (0)