Skip to content

Commit 6cf5842

Browse files
committed
Gemm optimizer spends too much time creating constants of the wrong type and then casting them
1 parent da983bc commit 6cf5842

File tree

1 file changed

+21
-18
lines changed

1 file changed

+21
-18
lines changed

pytensor/tensor/rewriting/blas.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,13 @@
8383
)
8484
from pytensor.graph.rewriting.db import SequenceDB
8585
from pytensor.graph.utils import InconsistencyError
86-
from pytensor.tensor import basic as ptb
86+
from pytensor.tensor import as_tensor_variable
87+
from pytensor.tensor.basic import (
88+
AllocEmpty,
89+
cast,
90+
get_underlying_scalar_constant_value,
91+
zeros,
92+
)
8793
from pytensor.tensor.blas import (
8894
Dot22,
8995
_batched_dot,
@@ -143,7 +149,7 @@ def _as_scalar(res, dtype=None):
143149
# as the cast of the scalar can be done before or after the dot22
144150
# and this will give the same result.
145151
if pytensor.scalar.upcast(res.dtype, dtype) == dtype:
146-
return ptb.cast(rval, dtype)
152+
return cast(rval, dtype)
147153
else:
148154
return None
149155

@@ -358,13 +364,13 @@ def _gemm_from_factored_list(fgraph, lst):
358364
# sM can be a tuple of 2 elements or an PyTensor variable.
359365
if isinstance(sM, tuple):
360366
sm0, sm1 = sM
361-
sm0 = ptb.as_tensor_variable(sm0)
362-
sm0_dtype = sm0.type.dtype
363367
sm1_dtype = sm1.type.dtype
368+
sm0 = as_tensor_variable(sm0, dtype=sm1_dtype)
369+
sm0_dtype = sm0.type.dtype
364370
if sm0_dtype == sm1_dtype:
365371
lst2.append((sm0, sm1))
366372
elif upcast(sm0_dtype, sm1_dtype) == sm1_dtype:
367-
lst2.append((ptb.cast(sm0, sm1_dtype), sm1))
373+
lst2.append((cast(sm0, sm1_dtype), sm1))
368374

369375
lst = lst2
370376

@@ -654,7 +660,7 @@ def local_gemm_to_ger(fgraph, node):
654660
xv = x.dimshuffle(0)
655661
yv = y.dimshuffle(1)
656662
try:
657-
bval = ptb.get_underlying_scalar_constant_value(b)
663+
bval = get_underlying_scalar_constant_value(b)
658664
except NotScalarConstantError:
659665
# b isn't a constant, GEMM is doing useful pre-scaling
660666
return
@@ -663,8 +669,7 @@ def local_gemm_to_ger(fgraph, node):
663669
rval = ger(z, a, xv, yv)
664670
new_out = [rval]
665671
elif bval == 0: # GER on zeros_like should be faster than GEMM
666-
zeros = ptb.zeros([x.shape[0], y.shape[1]], x.dtype)
667-
rval = ger(zeros, a, xv, yv)
672+
rval = ger(zeros([x.shape[0], y.shape[1]], x.dtype), a, xv, yv)
668673
new_out = [rval]
669674
else:
670675
# if bval is another constant, then z is being usefully
@@ -681,32 +686,32 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
681686
x, y = node.inputs
682687
xb = x.broadcastable
683688
yb = y.broadcastable
684-
one = ptb.as_tensor_variable(np.asarray(1, dtype=x.dtype))
685-
zero = ptb.as_tensor_variable(np.asarray(0, dtype=x.dtype))
689+
one = as_tensor_variable(np.asarray(1, dtype=x.dtype))
690+
zero = as_tensor_variable(np.asarray(0, dtype=x.dtype))
686691
if xb[1] and yb[0]:
687692
# x and y are both vectors so this might qualifies for a GER
688693
xv = x.dimshuffle(0)
689694
yv = y.dimshuffle(1)
690-
zeros = ptb.zeros([x.shape[0], y.shape[1]], dtype=x.dtype)
695+
zeros = zeros([x.shape[0], y.shape[1]], dtype=x.dtype)
691696
rval = ger(zeros, one, xv, yv)
692697
new_out = [rval]
693698
elif xb[0] and yb[1]:
694699
# x and y are both vectors so this qualifies for a sdot / ddot
695700
# PyTensor's CGemv will call sdot/ddot at runtime, the Scipy Gemv may not
696701
xv = x.dimshuffle(1)
697-
zeros = ptb.AllocEmpty(x.dtype)(1)
702+
zeros = AllocEmpty(x.dtype)(1)
698703
rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
699704
new_out = [rval.dimshuffle("x", 0)]
700705
elif xb[0] and not yb[0] and not yb[1]:
701706
# x is vector, y is matrix so try gemv
702707
xv = x.dimshuffle(1)
703-
zeros = ptb.AllocEmpty(x.dtype)(y.shape[1])
708+
zeros = AllocEmpty(x.dtype)(y.shape[1])
704709
rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
705710
new_out = [rval.dimshuffle("x", 0)]
706711
elif not xb[0] and not xb[1] and yb[1]:
707712
# x is matrix, y is vector, try gemv
708713
yv = y.dimshuffle(0)
709-
zeros = ptb.AllocEmpty(x.dtype)(x.shape[0])
714+
zeros = AllocEmpty(x.dtype)(x.shape[0])
710715
rval = gemv_no_inplace(zeros, one, x, yv, zero)
711716
new_out = [rval.dimshuffle(0, "x")]
712717
else:
@@ -841,9 +846,7 @@ def local_dot22_to_dot22scalar(fgraph, node):
841846
" matrix type"
842847
)
843848
return False
844-
a = ptb.cast(
845-
_as_scalar(m.owner.inputs[scalar_idx], dtype=d.dtype), d.type.dtype
846-
)
849+
a = cast(_as_scalar(m.owner.inputs[scalar_idx], dtype=d.dtype), d.type.dtype)
847850
assert not a.type.ndim
848851
dot = _dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)
849852

@@ -881,7 +884,7 @@ def local_dot22_to_dot22scalar(fgraph, node):
881884
o.remove(d)
882885
o.remove(s)
883886

884-
a = ptb.cast(i_scalar[scalar_idx], d.type.dtype)
887+
a = cast(i_scalar[scalar_idx], d.type.dtype)
885888
assert not a.type.ndim
886889
if len(o) == 0:
887890
return [_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)]

0 commit comments

Comments
 (0)