8383)
8484from pytensor .graph .rewriting .db import SequenceDB
8585from pytensor .graph .utils import InconsistencyError
86- from pytensor .tensor import basic as ptb
86+ from pytensor .tensor import as_tensor_variable
87+ from pytensor .tensor .basic import (
88+ AllocEmpty ,
89+ cast ,
90+ get_underlying_scalar_constant_value ,
91+ zeros ,
92+ )
8793from pytensor .tensor .blas import (
8894 Dot22 ,
8995 _batched_dot ,
@@ -143,7 +149,7 @@ def _as_scalar(res, dtype=None):
143149 # as the cast of the scalar can be done before or after the dot22
144150 # and this will give the same result.
145151 if pytensor .scalar .upcast (res .dtype , dtype ) == dtype :
146- return ptb . cast (rval , dtype )
152+ return cast (rval , dtype )
147153 else :
148154 return None
149155
@@ -358,13 +364,13 @@ def _gemm_from_factored_list(fgraph, lst):
358364 # sM can be a tuple of 2 elements or an PyTensor variable.
359365 if isinstance (sM , tuple ):
360366 sm0 , sm1 = sM
361- sm0 = ptb .as_tensor_variable (sm0 )
362- sm0_dtype = sm0 .type .dtype
363367 sm1_dtype = sm1 .type .dtype
368+ sm0 = as_tensor_variable (sm0 , dtype = sm1_dtype )
369+ sm0_dtype = sm0 .type .dtype
364370 if sm0_dtype == sm1_dtype :
365371 lst2 .append ((sm0 , sm1 ))
366372 elif upcast (sm0_dtype , sm1_dtype ) == sm1_dtype :
367- lst2 .append ((ptb . cast (sm0 , sm1_dtype ), sm1 ))
373+ lst2 .append ((cast (sm0 , sm1_dtype ), sm1 ))
368374
369375 lst = lst2
370376
@@ -654,7 +660,7 @@ def local_gemm_to_ger(fgraph, node):
654660 xv = x .dimshuffle (0 )
655661 yv = y .dimshuffle (1 )
656662 try :
657- bval = ptb . get_underlying_scalar_constant_value (b )
663+ bval = get_underlying_scalar_constant_value (b )
658664 except NotScalarConstantError :
659665 # b isn't a constant, GEMM is doing useful pre-scaling
660666 return
@@ -663,8 +669,7 @@ def local_gemm_to_ger(fgraph, node):
663669 rval = ger (z , a , xv , yv )
664670 new_out = [rval ]
665671 elif bval == 0 : # GER on zeros_like should be faster than GEMM
666- zeros = ptb .zeros ([x .shape [0 ], y .shape [1 ]], x .dtype )
667- rval = ger (zeros , a , xv , yv )
672+ rval = ger (zeros ([x .shape [0 ], y .shape [1 ]], x .dtype ), a , xv , yv )
668673 new_out = [rval ]
669674 else :
670675 # if bval is another constant, then z is being usefully
@@ -681,32 +686,32 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
681686 x , y = node .inputs
682687 xb = x .broadcastable
683688 yb = y .broadcastable
684- one = ptb . as_tensor_variable (np .asarray (1 , dtype = x .dtype ))
685- zero = ptb . as_tensor_variable (np .asarray (0 , dtype = x .dtype ))
689+ one = as_tensor_variable (np .asarray (1 , dtype = x .dtype ))
690+ zero = as_tensor_variable (np .asarray (0 , dtype = x .dtype ))
686691 if xb [1 ] and yb [0 ]:
687692 # x and y are both vectors so this might qualifies for a GER
688693 xv = x .dimshuffle (0 )
689694 yv = y .dimshuffle (1 )
690- zeros = ptb . zeros ([x .shape [0 ], y .shape [1 ]], dtype = x .dtype )
695+ zeros = zeros ([x .shape [0 ], y .shape [1 ]], dtype = x .dtype )
691696 rval = ger (zeros , one , xv , yv )
692697 new_out = [rval ]
693698 elif xb [0 ] and yb [1 ]:
694699 # x and y are both vectors so this qualifies for a sdot / ddot
695700 # PyTensor's CGemv will call sdot/ddot at runtime, the Scipy Gemv may not
696701 xv = x .dimshuffle (1 )
697- zeros = ptb . AllocEmpty (x .dtype )(1 )
702+ zeros = AllocEmpty (x .dtype )(1 )
698703 rval = gemv_no_inplace (zeros , one , y .T , xv , zero )
699704 new_out = [rval .dimshuffle ("x" , 0 )]
700705 elif xb [0 ] and not yb [0 ] and not yb [1 ]:
701706 # x is vector, y is matrix so try gemv
702707 xv = x .dimshuffle (1 )
703- zeros = ptb . AllocEmpty (x .dtype )(y .shape [1 ])
708+ zeros = AllocEmpty (x .dtype )(y .shape [1 ])
704709 rval = gemv_no_inplace (zeros , one , y .T , xv , zero )
705710 new_out = [rval .dimshuffle ("x" , 0 )]
706711 elif not xb [0 ] and not xb [1 ] and yb [1 ]:
707712 # x is matrix, y is vector, try gemv
708713 yv = y .dimshuffle (0 )
709- zeros = ptb . AllocEmpty (x .dtype )(x .shape [0 ])
714+ zeros = AllocEmpty (x .dtype )(x .shape [0 ])
710715 rval = gemv_no_inplace (zeros , one , x , yv , zero )
711716 new_out = [rval .dimshuffle (0 , "x" )]
712717 else :
@@ -841,9 +846,7 @@ def local_dot22_to_dot22scalar(fgraph, node):
841846 " matrix type"
842847 )
843848 return False
844- a = ptb .cast (
845- _as_scalar (m .owner .inputs [scalar_idx ], dtype = d .dtype ), d .type .dtype
846- )
849+ a = cast (_as_scalar (m .owner .inputs [scalar_idx ], dtype = d .dtype ), d .type .dtype )
847850 assert not a .type .ndim
848851 dot = _dot22scalar (d .owner .inputs [0 ], d .owner .inputs [1 ], a )
849852
@@ -881,7 +884,7 @@ def local_dot22_to_dot22scalar(fgraph, node):
881884 o .remove (d )
882885 o .remove (s )
883886
884- a = ptb . cast (i_scalar [scalar_idx ], d .type .dtype )
887+ a = cast (i_scalar [scalar_idx ], d .type .dtype )
885888 assert not a .type .ndim
886889 if len (o ) == 0 :
887890 return [_dot22scalar (d .owner .inputs [0 ], d .owner .inputs [1 ], a )]
0 commit comments