
Commit 91cb2c0

Move dot Op dispatchers to Elemwise
The dot Ops are actually defined in `tensor/math.py`, but the Elemwise dispatch module is a better home for their Numba dispatchers than `basic.py`.
Parent: 8066720 · Commit: 91cb2c0

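The move works because the Numba backend resolves Op-specific compilers through a registration-based dispatcher, so `numba_funcify_Dot` and `numba_funcify_BatchedDot` behave the same when registered from `elemwise.py` as from `basic.py`, provided the module is imported so its `register` calls run. A minimal, self-contained sketch of that pattern, assuming `functools.singledispatch`-style semantics (the class names below are illustrative stand-ins, not pytensor code):

from functools import singledispatch

import numpy as np


class Op: ...          # stand-in for an Op base class
class Dot(Op): ...     # stand-in for pytensor.tensor.math.Dot


@singledispatch
def funcify(op, **kwargs):
    raise NotImplementedError(f"No Numba implementation for {type(op).__name__}")


@funcify.register(Dot)
def funcify_Dot(op, **kwargs):
    # The real dispatcher returns a numba-jitted closure; a plain
    # Python callable is enough to show the dispatch mechanics.
    return lambda x, y: np.asarray(np.dot(x, y))


dot_fn = funcify(Dot())  # resolves to funcify_Dot, regardless of which module registered it
print(dot_fn(np.eye(2), np.ones(2)))  # [1. 1.]
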
File tree

4 files changed: 163 additions, 165 deletions

pytensor/link/numba/dispatch/basic.py

Lines changed: 0 additions & 67 deletions
@@ -24,8 +24,6 @@
 )
 from pytensor.scalar.basic import ScalarType
 from pytensor.sparse import SparseTensorType
-from pytensor.tensor.blas import BatchedDot
-from pytensor.tensor.math import Dot
 from pytensor.tensor.type import TensorType
 
 
@@ -364,71 +362,6 @@ def inputs_cast(x):
     return inputs_cast
 
 
-@numba_funcify.register(Dot)
-def numba_funcify_Dot(op, node, **kwargs):
-    # Numba's `np.dot` does not support integer dtypes, so we need to cast to float.
-    x, y = node.inputs
-    [out] = node.outputs
-
-    x_dtype = x.type.dtype
-    y_dtype = y.type.dtype
-    dot_dtype = f"float{max((32, out.type.numpy_dtype.itemsize * 8))}"
-    out_dtype = out.type.dtype
-
-    if x_dtype == dot_dtype and y_dtype == dot_dtype:
-
-        @numba_njit
-        def dot(x, y):
-            return np.asarray(np.dot(x, y))
-
-    elif x_dtype == dot_dtype and y_dtype != dot_dtype:
-
-        @numba_njit
-        def dot(x, y):
-            return np.asarray(np.dot(x, y.astype(dot_dtype)))
-
-    elif x_dtype != dot_dtype and y_dtype == dot_dtype:
-
-        @numba_njit
-        def dot(x, y):
-            return np.asarray(np.dot(x.astype(dot_dtype), y))
-
-    else:
-
-        @numba_njit()
-        def dot(x, y):
-            return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype)))
-
-    if out_dtype == dot_dtype:
-        return dot
-
-    else:
-
-        @numba_njit
-        def dot_with_cast(x, y):
-            return dot(x, y).astype(out_dtype)
-
-        return dot_with_cast
-
-
-@numba_funcify.register(BatchedDot)
-def numba_funcify_BatchedDot(op, node, **kwargs):
-    dtype = node.outputs[0].type.numpy_dtype
-
-    @numba_njit
-    def batched_dot(x, y):
-        # Numba does not support 3D matmul
-        # https://github.com/numba/numba/issues/3804
-        shape = x.shape[:-1] + y.shape[2:]
-        z0 = np.empty(shape, dtype=dtype)
-        for i in range(z0.shape[0]):
-            z0[i] = np.dot(x[i], y[i])
-
-        return z0
-
-    return batched_dot
-
-
 @numba_funcify.register(IfElse)
 def numba_funcify_IfElse(op, **kwargs):
     n_outs = op.n_outs

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 67 additions & 1 deletion
@@ -35,8 +35,9 @@
     scalar_maximum,
 )
 from pytensor.scalar.basic import add as add_as
+from pytensor.tensor.blas import BatchedDot
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum
+from pytensor.tensor.math import Argmax, Dot, MulWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 
 
@@ -599,3 +600,68 @@ def argmax(x):
         return max_idx_res
 
     return argmax
+
+
+@numba_funcify.register(Dot)
+def numba_funcify_Dot(op, node, **kwargs):
+    # Numba's `np.dot` does not support integer dtypes, so we need to cast to float.
+    x, y = node.inputs
+    [out] = node.outputs
+
+    x_dtype = x.type.dtype
+    y_dtype = y.type.dtype
+    dot_dtype = f"float{max((32, out.type.numpy_dtype.itemsize * 8))}"
+    out_dtype = out.type.dtype
+
+    if x_dtype == dot_dtype and y_dtype == dot_dtype:
+
+        @numba_njit
+        def dot(x, y):
+            return np.asarray(np.dot(x, y))
+
+    elif x_dtype == dot_dtype and y_dtype != dot_dtype:
+
+        @numba_njit
+        def dot(x, y):
+            return np.asarray(np.dot(x, y.astype(dot_dtype)))
+
+    elif x_dtype != dot_dtype and y_dtype == dot_dtype:
+
+        @numba_njit
+        def dot(x, y):
+            return np.asarray(np.dot(x.astype(dot_dtype), y))
+
+    else:
+
+        @numba_njit()
+        def dot(x, y):
+            return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype)))
+
+    if out_dtype == dot_dtype:
+        return dot
+
+    else:
+
+        @numba_njit
+        def dot_with_cast(x, y):
+            return dot(x, y).astype(out_dtype)
+
+        return dot_with_cast
+
+
+@numba_funcify.register(BatchedDot)
+def numba_funcify_BatchedDot(op, node, **kwargs):
+    dtype = node.outputs[0].type.numpy_dtype
+
+    @numba_njit
+    def batched_dot(x, y):
+        # Numba does not support 3D matmul
+        # https://github.com/numba/numba/issues/3804
+        shape = x.shape[:-1] + y.shape[2:]
+        z0 = np.empty(shape, dtype=dtype)
+        for i in range(z0.shape[0]):
+            z0[i] = np.dot(x[i], y[i])
+
+        return z0
+
+    return batched_dot

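As a side note on the relocated `numba_funcify_Dot`, the `dot_dtype` expression picks a float type at least as wide as the output, and never narrower than float32, since Numba's `np.dot` rejects integer dtypes. A small standalone check of that formula (the dtype list here is just illustrative):

import numpy as np

# Reproduce the promotion rule from numba_funcify_Dot:
#   dot_dtype = f"float{max((32, out.type.numpy_dtype.itemsize * 8))}"
for out_dtype in ("int16", "float32", "int64", "float64"):
    itemsize = np.dtype(out_dtype).itemsize
    dot_dtype = f"float{max((32, itemsize * 8))}"
    print(f"{out_dtype:>8} -> {dot_dtype}")

# Expected mapping:
#   int16   -> float32  (clamped up to float32)
#   float32 -> float32
#   int64   -> float64
#   float64 -> float64
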
tests/link/numba/test_basic.py

Lines changed: 0 additions & 97 deletions
@@ -14,7 +14,6 @@
 
 import pytensor.scalar as ps
 import pytensor.tensor as pt
-import pytensor.tensor.math as ptm
 from pytensor import config, shared
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function
@@ -29,7 +28,6 @@
 from pytensor.link.numba.linker import NumbaLinker
 from pytensor.raise_op import assert_op
 from pytensor.scalar.basic import ScalarOp, as_scalar
-from pytensor.tensor import blas, tensor
 from pytensor.tensor.elemwise import Elemwise
 
 
@@ -407,86 +405,6 @@ def test_perform_type_convert():
     compare_numba_and_py([x], out, [x_test_value])
 
 
-@pytest.mark.parametrize(
-    "x, y",
-    [
-        (
-            (pt.matrix(), rng.random(size=(3, 2)).astype(config.floatX)),
-            (pt.vector(), rng.random(size=(2,)).astype(config.floatX)),
-        ),
-        (
-            (pt.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64")),
-            (pt.vector(dtype="float32"), rng.random(size=(2,)).astype("float32")),
-        ),
-        (
-            (pt.lmatrix(), rng.poisson(size=(3, 2))),
-            (pt.fvector(), rng.random(size=(2,)).astype("float32")),
-        ),
-        (
-            (pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
-            (pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
-        ),
-        (
-            (pt.vector(dtype="int16"), rng.random(size=(2,)).astype(np.int16)),
-            (pt.vector(dtype="uint8"), rng.random(size=(2,)).astype(np.uint8)),
-        ),
-    ],
-)
-def test_Dot(x, y):
-    x, x_test_value = x
-    y, y_test_value = y
-
-    g = ptm.dot(x, y)
-
-    compare_numba_and_py(
-        [x, y],
-        [g],
-        [x_test_value, y_test_value],
-    )
-
-
-@pytest.mark.parametrize(
-    "x, y, exc",
-    [
-        (
-            (
-                pt.dtensor3(),
-                rng.random(size=(2, 3, 3)).astype("float64"),
-            ),
-            (
-                pt.dtensor3(),
-                rng.random(size=(2, 3, 3)).astype("float64"),
-            ),
-            None,
-        ),
-        (
-            (
-                pt.dtensor3(),
-                rng.random(size=(2, 3, 3)).astype("float64"),
-            ),
-            (
-                pt.ltensor3(),
-                rng.poisson(size=(2, 3, 3)).astype("int64"),
-            ),
-            None,
-        ),
-    ],
-)
-def test_BatchedDot(x, y, exc):
-    x, x_test_value = x
-    y, y_test_value = y
-
-    g = blas.BatchedDot()(x, y)
-
-    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
-    with cm:
-        compare_numba_and_py(
-            [x, y],
-            g,
-            [x_test_value, y_test_value],
-        )
-
-
 def test_shared():
     a = shared(np.array([1, 2, 3], dtype=config.floatX))
 
@@ -716,18 +634,3 @@ def test_function_overhead(mode, benchmark):
     assert np.sum(fn(test_x)) == 1000
 
     benchmark(fn, test_x)
-
-
-@pytest.mark.parametrize("dtype", ("float64", "float32", "mixed"))
-def test_mat_vec_dot_performance(dtype, benchmark):
-    A = tensor("A", shape=(512, 512), dtype="float64" if dtype == "mixed" else dtype)
-    x = tensor("x", shape=(512,), dtype="float32" if dtype == "mixed" else dtype)
-    out = ptm.dot(A, x)
-
-    fn = function([A, x], out, mode="NUMBA", trust_input=True)
-
-    rng = np.random.default_rng(948)
-    A_test = rng.standard_normal(size=A.type.shape, dtype=A.type.dtype)
-    x_test = rng.standard_normal(size=x.type.shape, dtype=x.type.dtype)
-    np.testing.assert_allclose(fn(A_test, x_test), np.dot(A_test, x_test), atol=1e-4)
-    benchmark(fn, A_test, x_test)

tests/link/numba/test_elemwise.py

Lines changed: 96 additions & 0 deletions
@@ -13,6 +13,7 @@
 from pytensor.compile.ops import deep_copy_op
 from pytensor.gradient import grad
 from pytensor.scalar import Composite, float64
+from pytensor.tensor import blas, tensor
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@@ -670,3 +671,98 @@ def test_numba_careduce_benchmark(self, axis, c_contiguous, benchmark):
     @pytest.mark.parametrize("c_contiguous", (True, False))
     def test_dimshuffle(self, c_contiguous, benchmark):
         dimshuffle_benchmark("NUMBA", c_contiguous, benchmark)
+
+
+@pytest.mark.parametrize(
+    "x, y",
+    [
+        (
+            (pt.matrix(), rng.random(size=(3, 2)).astype(config.floatX)),
+            (pt.vector(), rng.random(size=(2,)).astype(config.floatX)),
+        ),
+        (
+            (pt.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64")),
+            (pt.vector(dtype="float32"), rng.random(size=(2,)).astype("float32")),
+        ),
+        (
+            (pt.lmatrix(), rng.poisson(size=(3, 2))),
+            (pt.fvector(), rng.random(size=(2,)).astype("float32")),
+        ),
+        (
+            (pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
+            (pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
+        ),
+        (
+            (pt.vector(dtype="int16"), rng.random(size=(2,)).astype(np.int16)),
+            (pt.vector(dtype="uint8"), rng.random(size=(2,)).astype(np.uint8)),
+        ),
+    ],
+)
+def test_Dot(x, y):
+    x, x_test_value = x
+    y, y_test_value = y
+
+    g = ptm.dot(x, y)
+
+    compare_numba_and_py(
+        [x, y],
+        [g],
+        [x_test_value, y_test_value],
+    )
+
+
+@pytest.mark.parametrize(
+    "x, y, exc",
+    [
+        (
+            (
+                pt.dtensor3(),
+                rng.random(size=(2, 3, 3)).astype("float64"),
+            ),
+            (
+                pt.dtensor3(),
+                rng.random(size=(2, 3, 3)).astype("float64"),
+            ),
+            None,
+        ),
+        (
+            (
+                pt.dtensor3(),
+                rng.random(size=(2, 3, 3)).astype("float64"),
+            ),
+            (
+                pt.ltensor3(),
+                rng.poisson(size=(2, 3, 3)).astype("int64"),
+            ),
+            None,
+        ),
+    ],
+)
+def test_BatchedDot(x, y, exc):
+    x, x_test_value = x
+    y, y_test_value = y
+
+    g = blas.BatchedDot()(x, y)
+
+    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
+    with cm:
+        compare_numba_and_py(
+            [x, y],
+            g,
+            [x_test_value, y_test_value],
+        )
+
+
+@pytest.mark.parametrize("dtype", ("float64", "float32", "mixed"))
+def test_mat_vec_dot_performance(dtype, benchmark):
+    A = tensor("A", shape=(512, 512), dtype="float64" if dtype == "mixed" else dtype)
+    x = tensor("x", shape=(512,), dtype="float32" if dtype == "mixed" else dtype)
+    out = ptm.dot(A, x)
+
+    fn = function([A, x], out, mode="NUMBA", trust_input=True)
+
+    rng = np.random.default_rng(948)
+    A_test = rng.standard_normal(size=A.type.shape, dtype=A.type.dtype)
+    x_test = rng.standard_normal(size=x.type.shape, dtype=x.type.dtype)
+    np.testing.assert_allclose(fn(A_test, x_test), np.dot(A_test, x_test), atol=1e-4)
+    benchmark(fn, A_test, x_test)

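For context, a minimal end-to-end check that exercises the relocated `Dot` dispatcher through the Numba backend, modeled on `test_mat_vec_dot_performance` above (a sketch only; the shapes, seed, and variable names are illustrative, and it assumes pytensor with numba installed):

import numpy as np

import pytensor.tensor as pt
from pytensor import function

A = pt.matrix("A", dtype="float64")
x = pt.vector("x", dtype="float32")  # mixed dtypes hit the casting branch of numba_funcify_Dot

# Compiling with the NUMBA mode routes the Dot Op through numba_funcify_Dot.
fn = function([A, x], pt.dot(A, x), mode="NUMBA", trust_input=True)

rng = np.random.default_rng(0)
A_val = rng.standard_normal((4, 4))
x_val = rng.standard_normal(4).astype("float32")
np.testing.assert_allclose(fn(A_val, x_val), A_val @ x_val, atol=1e-4)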