
Commit 3fda09c

Move dot Op dispatchers to Elemwise
The dot Ops are actually defined in `tensor/math.py`, so `elemwise.py` is not their natural home either, but it is a better fit than `basic.py`.
1 parent f3fcbf6 commit 3fda09c
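
For orientation (not part of this commit), the moved dispatchers are exercised whenever a graph containing `pt.dot` is compiled with the Numba backend. A minimal sketch, assuming a PyTensor installation with Numba available:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    # Build a small matrix-vector dot graph and compile it with the Numba backend,
    # which routes the Dot Op through numba_funcify_Dot.
    A = pt.matrix("A")
    x = pt.vector("x")
    fn = pytensor.function([A, x], pt.dot(A, x), mode="NUMBA")

    rng = np.random.default_rng(0)
    A_val = rng.standard_normal((3, 2)).astype(A.type.dtype)
    x_val = rng.standard_normal(2).astype(x.type.dtype)

    # The Numba-compiled result should match NumPy's dot.
    np.testing.assert_allclose(fn(A_val, x_val), A_val @ x_val)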

File tree

4 files changed (+163 additions, -165 deletions)


pytensor/link/numba/dispatch/basic.py

Lines changed: 0 additions & 67 deletions
@@ -24,8 +24,6 @@
 )
 from pytensor.scalar.basic import ScalarType
 from pytensor.sparse import SparseTensorType
-from pytensor.tensor.blas import BatchedDot
-from pytensor.tensor.math import Dot
 from pytensor.tensor.type import TensorType



@@ -364,71 +362,6 @@ def inputs_cast(x):
     return inputs_cast


-@numba_funcify.register(Dot)
-def numba_funcify_Dot(op, node, **kwargs):
-    # Numba's `np.dot` does not support integer dtypes, so we need to cast to float.
-    x, y = node.inputs
-    [out] = node.outputs
-
-    x_dtype = x.type.dtype
-    y_dtype = y.type.dtype
-    dot_dtype = f"float{max((32, out.type.numpy_dtype.itemsize * 8))}"
-    out_dtype = out.type.dtype
-
-    if x_dtype == dot_dtype and y_dtype == dot_dtype:
-
-        @numba_njit
-        def dot(x, y):
-            return np.asarray(np.dot(x, y))
-
-    elif x_dtype == dot_dtype and y_dtype != dot_dtype:
-
-        @numba_njit
-        def dot(x, y):
-            return np.asarray(np.dot(x, y.astype(dot_dtype)))
-
-    elif x_dtype != dot_dtype and y_dtype == dot_dtype:
-
-        @numba_njit
-        def dot(x, y):
-            return np.asarray(np.dot(x.astype(dot_dtype), y))
-
-    else:
-
-        @numba_njit()
-        def dot(x, y):
-            return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype)))
-
-    if out_dtype == dot_dtype:
-        return dot
-
-    else:
-
-        @numba_njit
-        def dot_with_cast(x, y):
-            return dot(x, y).astype(out_dtype)
-
-        return dot_with_cast
-
-
-@numba_funcify.register(BatchedDot)
-def numba_funcify_BatchedDot(op, node, **kwargs):
-    dtype = node.outputs[0].type.numpy_dtype
-
-    @numba_njit
-    def batched_dot(x, y):
-        # Numba does not support 3D matmul
-        # https://github.com/numba/numba/issues/3804
-        shape = x.shape[:-1] + y.shape[2:]
-        z0 = np.empty(shape, dtype=dtype)
-        for i in range(z0.shape[0]):
-            z0[i] = np.dot(x[i], y[i])
-
-        return z0
-
-    return batched_dot
-
-
 @numba_funcify.register(IfElse)
 def numba_funcify_IfElse(op, **kwargs):
     n_outs = op.n_outs
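
For reference, the `dot_dtype` line in `numba_funcify_Dot` picks a float accumulation type from the output's itemsize (never narrower than float32), and inputs/outputs are only cast when they differ from it. A standalone NumPy sketch of that rule (the helper name `pick_dot_dtype` is ours, not part of the codebase):

    import numpy as np

    def pick_dot_dtype(out_dtype: str) -> str:
        # Same rule as in numba_funcify_Dot: a float type at least 32 bits wide,
        # widened to match the output's itemsize.
        return f"float{max((32, np.dtype(out_dtype).itemsize * 8))}"

    assert pick_dot_dtype("float32") == "float32"  # no casting needed
    assert pick_dot_dtype("int16") == "float32"    # small ints still compute in float32
    assert pick_dot_dtype("int64") == "float64"    # computed in float64, cast back to int64
    assert pick_dot_dtype("float64") == "float64"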

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 67 additions & 1 deletion
@@ -35,8 +35,9 @@
     scalar_maximum,
 )
 from pytensor.scalar.basic import add as add_as
+from pytensor.tensor.blas import BatchedDot
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum
+from pytensor.tensor.math import Argmax, Dot, MulWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad



@@ -599,3 +600,68 @@ def argmax(x):
         return max_idx_res

     return argmax
+
+
+@numba_funcify.register(Dot)
+def numba_funcify_Dot(op, node, **kwargs):
+    # Numba's `np.dot` does not support integer dtypes, so we need to cast to float.
+    x, y = node.inputs
+    [out] = node.outputs
+
+    x_dtype = x.type.dtype
+    y_dtype = y.type.dtype
+    dot_dtype = f"float{max((32, out.type.numpy_dtype.itemsize * 8))}"
+    out_dtype = out.type.dtype
+
+    if x_dtype == dot_dtype and y_dtype == dot_dtype:
+
+        @numba_njit
+        def dot(x, y):
+            return np.asarray(np.dot(x, y))
+
+    elif x_dtype == dot_dtype and y_dtype != dot_dtype:
+
+        @numba_njit
+        def dot(x, y):
+            return np.asarray(np.dot(x, y.astype(dot_dtype)))
+
+    elif x_dtype != dot_dtype and y_dtype == dot_dtype:
+
+        @numba_njit
+        def dot(x, y):
+            return np.asarray(np.dot(x.astype(dot_dtype), y))
+
+    else:
+
+        @numba_njit()
+        def dot(x, y):
+            return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype)))
+
+    if out_dtype == dot_dtype:
+        return dot
+
+    else:
+
+        @numba_njit
+        def dot_with_cast(x, y):
+            return dot(x, y).astype(out_dtype)
+
+        return dot_with_cast
+
+
+@numba_funcify.register(BatchedDot)
+def numba_funcify_BatchedDot(op, node, **kwargs):
+    dtype = node.outputs[0].type.numpy_dtype
+
+    @numba_njit
+    def batched_dot(x, y):
+        # Numba does not support 3D matmul
+        # https://github.com/numba/numba/issues/3804
+        shape = x.shape[:-1] + y.shape[2:]
+        z0 = np.empty(shape, dtype=dtype)
+        for i in range(z0.shape[0]):
+            z0[i] = np.dot(x[i], y[i])
+
+        return z0
+
+    return batched_dot
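
The explicit loop in `batched_dot` works around Numba's missing 3D `matmul` support (numba/numba#3804); for 3D inputs it computes the same thing as `np.matmul` over the leading axis. A quick NumPy-only check of that equivalence (illustrative, not part of the commit):

    import numpy as np

    rng = np.random.default_rng(42)
    x = rng.standard_normal((2, 3, 4))
    y = rng.standard_normal((2, 4, 5))

    # The same loop the dispatcher compiles with Numba...
    z_loop = np.empty(x.shape[:-1] + y.shape[2:], dtype=x.dtype)
    for i in range(z_loop.shape[0]):
        z_loop[i] = np.dot(x[i], y[i])

    # ...agrees with NumPy's batched matmul.
    np.testing.assert_allclose(z_loop, np.matmul(x, y))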

tests/link/numba/test_basic.py

Lines changed: 0 additions & 97 deletions
@@ -15,7 +15,6 @@

 import pytensor.scalar as ps
 import pytensor.tensor as pt
-import pytensor.tensor.math as ptm
 from pytensor import config, shared
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function

@@ -30,7 +29,6 @@
 from pytensor.link.numba.linker import NumbaLinker
 from pytensor.raise_op import assert_op
 from pytensor.scalar.basic import ScalarOp, as_scalar
-from pytensor.tensor import blas, tensor
 from pytensor.tensor.elemwise import Elemwise



@@ -432,86 +430,6 @@ def test_perform_type_convert():
     compare_numba_and_py([x], out, [x_test_value])


-@pytest.mark.parametrize(
-    "x, y",
-    [
-        (
-            (pt.matrix(), rng.random(size=(3, 2)).astype(config.floatX)),
-            (pt.vector(), rng.random(size=(2,)).astype(config.floatX)),
-        ),
-        (
-            (pt.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64")),
-            (pt.vector(dtype="float32"), rng.random(size=(2,)).astype("float32")),
-        ),
-        (
-            (pt.lmatrix(), rng.poisson(size=(3, 2))),
-            (pt.fvector(), rng.random(size=(2,)).astype("float32")),
-        ),
-        (
-            (pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
-            (pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
-        ),
-        (
-            (pt.vector(dtype="int16"), rng.random(size=(2,)).astype(np.int16)),
-            (pt.vector(dtype="uint8"), rng.random(size=(2,)).astype(np.uint8)),
-        ),
-    ],
-)
-def test_Dot(x, y):
-    x, x_test_value = x
-    y, y_test_value = y
-
-    g = ptm.dot(x, y)
-
-    compare_numba_and_py(
-        [x, y],
-        [g],
-        [x_test_value, y_test_value],
-    )
-
-
-@pytest.mark.parametrize(
-    "x, y, exc",
-    [
-        (
-            (
-                pt.dtensor3(),
-                rng.random(size=(2, 3, 3)).astype("float64"),
-            ),
-            (
-                pt.dtensor3(),
-                rng.random(size=(2, 3, 3)).astype("float64"),
-            ),
-            None,
-        ),
-        (
-            (
-                pt.dtensor3(),
-                rng.random(size=(2, 3, 3)).astype("float64"),
-            ),
-            (
-                pt.ltensor3(),
-                rng.poisson(size=(2, 3, 3)).astype("int64"),
-            ),
-            None,
-        ),
-    ],
-)
-def test_BatchedDot(x, y, exc):
-    x, x_test_value = x
-    y, y_test_value = y
-
-    g = blas.BatchedDot()(x, y)
-
-    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
-    with cm:
-        compare_numba_and_py(
-            [x, y],
-            g,
-            [x_test_value, y_test_value],
-        )
-
-
 def test_shared():
     a = shared(np.array([1, 2, 3], dtype=config.floatX))

@@ -751,18 +669,3 @@ def test_function_overhead(mode, benchmark):
     assert np.sum(fn(test_x)) == 1000

     benchmark(fn, test_x)
-
-
-@pytest.mark.parametrize("dtype", ("float64", "float32", "mixed"))
-def test_mat_vec_dot_performance(dtype, benchmark):
-    A = tensor("A", shape=(512, 512), dtype="float64" if dtype == "mixed" else dtype)
-    x = tensor("x", shape=(512,), dtype="float32" if dtype == "mixed" else dtype)
-    out = ptm.dot(A, x)
-
-    fn = function([A, x], out, mode="NUMBA", trust_input=True)
-
-    rng = np.random.default_rng(948)
-    A_test = rng.standard_normal(size=A.type.shape, dtype=A.type.dtype)
-    x_test = rng.standard_normal(size=x.type.shape, dtype=x.type.dtype)
-    np.testing.assert_allclose(fn(A_test, x_test), np.dot(A_test, x_test), atol=1e-4)
-    benchmark(fn, A_test, x_test)

tests/link/numba/test_elemwise.py

Lines changed: 96 additions & 0 deletions
@@ -13,6 +13,7 @@
 from pytensor.compile.ops import deep_copy_op
 from pytensor.gradient import grad
 from pytensor.scalar import Composite, float64
+from pytensor.tensor import blas, tensor
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad

@@ -670,3 +671,98 @@ def test_numba_careduce_benchmark(self, axis, c_contiguous, benchmark):
     @pytest.mark.parametrize("c_contiguous", (True, False))
     def test_dimshuffle(self, c_contiguous, benchmark):
         dimshuffle_benchmark("NUMBA", c_contiguous, benchmark)
+
+
+@pytest.mark.parametrize(
+    "x, y",
+    [
+        (
+            (pt.matrix(), rng.random(size=(3, 2)).astype(config.floatX)),
+            (pt.vector(), rng.random(size=(2,)).astype(config.floatX)),
+        ),
+        (
+            (pt.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64")),
+            (pt.vector(dtype="float32"), rng.random(size=(2,)).astype("float32")),
+        ),
+        (
+            (pt.lmatrix(), rng.poisson(size=(3, 2))),
+            (pt.fvector(), rng.random(size=(2,)).astype("float32")),
+        ),
+        (
+            (pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
+            (pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
+        ),
+        (
+            (pt.vector(dtype="int16"), rng.random(size=(2,)).astype(np.int16)),
+            (pt.vector(dtype="uint8"), rng.random(size=(2,)).astype(np.uint8)),
+        ),
+    ],
+)
+def test_Dot(x, y):
+    x, x_test_value = x
+    y, y_test_value = y
+
+    g = ptm.dot(x, y)
+
+    compare_numba_and_py(
+        [x, y],
+        [g],
+        [x_test_value, y_test_value],
+    )
+
+
+@pytest.mark.parametrize(
+    "x, y, exc",
+    [
+        (
+            (
+                pt.dtensor3(),
+                rng.random(size=(2, 3, 3)).astype("float64"),
+            ),
+            (
+                pt.dtensor3(),
+                rng.random(size=(2, 3, 3)).astype("float64"),
+            ),
+            None,
+        ),
+        (
+            (
+                pt.dtensor3(),
+                rng.random(size=(2, 3, 3)).astype("float64"),
+            ),
+            (
+                pt.ltensor3(),
+                rng.poisson(size=(2, 3, 3)).astype("int64"),
+            ),
+            None,
+        ),
+    ],
+)
+def test_BatchedDot(x, y, exc):
+    x, x_test_value = x
+    y, y_test_value = y
+
+    g = blas.BatchedDot()(x, y)
+
+    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
+    with cm:
+        compare_numba_and_py(
+            [x, y],
+            g,
+            [x_test_value, y_test_value],
+        )
+
+
+@pytest.mark.parametrize("dtype", ("float64", "float32", "mixed"))
+def test_mat_vec_dot_performance(dtype, benchmark):
+    A = tensor("A", shape=(512, 512), dtype="float64" if dtype == "mixed" else dtype)
+    x = tensor("x", shape=(512,), dtype="float32" if dtype == "mixed" else dtype)
+    out = ptm.dot(A, x)
+
+    fn = function([A, x], out, mode="NUMBA", trust_input=True)
+
+    rng = np.random.default_rng(948)
+    A_test = rng.standard_normal(size=A.type.shape, dtype=A.type.dtype)
+    x_test = rng.standard_normal(size=x.type.shape, dtype=x.type.dtype)
+    np.testing.assert_allclose(fn(A_test, x_test), np.dot(A_test, x_test), atol=1e-4)
+    benchmark(fn, A_test, x_test)
