     get_normalized_batch_axes,
     scalar_elemwise,
 )
-from pytensor.tensor.shape import shape, specify_broadcastable
+from pytensor.tensor.shape import shape, specify_shape
 from pytensor.tensor.type import (
     DenseTensorType,
     complex_dtypes,
     continuous_dtypes,
     discrete_dtypes,
+    float_dtypes,
     int_dtypes,
     tensor,
     uint_dtypes,
@@ -2986,9 +2987,7 @@ def clip(x, min, max):
 
 class Dot(Op):
     """
-    Computes the dot product of two variables. For two matrices, this is
-    equivalent to matrix multiplication. For two vectors, this is the inner
-    product.
+    Computes the dot product of two matrix variables.
 
     Notes
     -----
@@ -3001,97 +3000,58 @@ class Dot(Op):
 
     """
 
+    gufunc_signature = "(m,n),(n,p)->(m,p)"
+    gufunc_spec = ("matmul", 2, 1)
     __props__ = ()
 
-    # the rationale for Dot22 is related to getting GEMM Ops into the
-    # graph. See Dot22 in tensor.blas for details.
-
-    def make_node(self, *inputs):
-        inputs = list(map(as_tensor_variable, inputs))
+    def make_node(self, x, y):
+        x = as_tensor_variable(x)
+        y = as_tensor_variable(y)
 
-        if len(inputs) != 2:
-            raise TypeError(f"Two arguments required, {len(inputs)} given ")
-        if inputs[0].ndim not in (1, 2):
+        if x.type.ndim != 2:
             raise TypeError(
-                "Input 0 (0-indexed) must have ndim of "
-                f"1 or 2, {int(inputs[0].ndim)} given. Consider calling "
-                "pytensor.tensor.dot instead."
+                f"Dot Op expects a 2D tensor as input 0, got {x} with {x.type.ndim} dimensions"
             )
-        if inputs[1].ndim not in (1, 2):
+        if y.type.ndim != 2:
             raise TypeError(
-                "Input 1 (0-indexed) must have ndim of "
-                f"1 or 2, {int(inputs[1].ndim)} given. Consider calling "
-                "pytensor.tensor.dot instead."
+                f"Dot Op expects a 2D tensor as input 1, got {y} with {y.type.ndim} dimensions"
             )
 
-        sx, sy = (input.type.shape for input in inputs)
+        sx, sy = x.type.shape, y.type.shape
         if sx[-1] is not None and sy[0] is not None and sx[-1] != sy[0]:
             raise ValueError(
                 f"Incompatible shared dimension for dot product: {sx}, {sy}"
             )
+        out_shape = (sx[0], sy[1])
+        out_dtype = ps.upcast(x.type.dtype, y.type.dtype)
+        outputs = [tensor(dtype=out_dtype, shape=out_shape)]
+        return Apply(self, [x, y], outputs)
 
-        if len(sy) == 2:
-            sz = sx[:-1] + sy[-1:]
-        elif len(sy) == 1:
-            sz = sx[:-1]
-
-        i_dtypes = [input.type.dtype for input in inputs]
-        outputs = [tensor(dtype=ps.upcast(*i_dtypes), shape=sz)]
-        return Apply(self, inputs, outputs)
-
-    def perform(self, node, inp, out):
-        x, y = inp
-        (z,) = out
-
-        # the asarray is here because dot between two vectors
-        # gives a numpy float object but we need to return a 0d
-        # ndarray
-        z[0] = np.asarray(np.dot(x, y))
+    def perform(self, node, inputs, output_storage):
+        output_storage[0][0] = np.matmul(*inputs)
 
     def grad(self, inp, grads):
         x, y = inp
         (gz,) = grads
-        xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
-
-        # grad is scalar, so x is vector and y is vector
-        if gdim == 0:
-            xgrad = gz * y
-            ygrad = gz * x
-
-        # x is vector, y is matrix, grad is vector
-        elif xdim == 1 and ydim == 2:
-            xgrad = dot(gz, y.T)
-            ygrad = outer(x.T, gz)
-
-        # x is matrix, y is vector, grad is vector
-        elif xdim == 2 and ydim == 1:
-            xgrad = outer(gz, y.T)
-            ygrad = dot(x.T, gz)
 
-        # x is matrix, y is matrix, grad is matrix
-        elif xdim == ydim == 2:
-            xgrad = dot(gz, y.T)
-            ygrad = dot(x.T, gz)
+        xgrad = self(gz, y.T)
+        ygrad = self(x.T, gz)
 
         # If x or y contain broadcastable dimensions but only one of
         # them knows that a matching dimension is broadcastable, the
         # above code doesn't always return the right broadcast pattern.
         # This causes problems down the road. See gh-1461.
-        if xgrad.broadcastable != x.broadcastable:
-            xgrad = specify_broadcastable(
-                xgrad, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b)
-            )
-        if ygrad.broadcastable != y.broadcastable:
-            ygrad = specify_broadcastable(
-                ygrad, *(ax for (ax, b) in enumerate(y.type.broadcastable) if b)
-            )
-
-        rval = xgrad, ygrad
+        if xgrad.type.shape != x.type.shape:
+            xgrad = specify_shape(xgrad, x.type.shape)
+        if ygrad.type.shape != y.type.shape:
+            ygrad = specify_shape(ygrad, y.type.shape)
 
-        for elem in rval:
-            assert elem.dtype.find("float") != -1
+        if xgrad.type.dtype not in float_dtypes:
+            raise TypeError("Dot grad x output must be a float type")
+        if ygrad.type.dtype not in float_dtypes:
+            raise TypeError("Dot grad y output must be a float type")
 
-        return rval
+        return xgrad, ygrad
 
     def R_op(self, inputs, eval_points):
         # R_op for a \dot b evaluated at c for a and d for b is
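
Note on the hunk above: with `Dot` restricted to 2D inputs, the gradient collapses to the single matrix-matrix case, `xgrad = gz @ y.T` and `ygrad = x.T @ gz`. A minimal NumPy sketch of that identity (the names below are illustrative, not part of this diff):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 4))
y = rng.normal(size=(4, 5))
gz = rng.normal(size=(3, 5))  # upstream gradient, same shape as x @ y

# Gradients of sum(gz * (x @ y)) with respect to x and y
xgrad = gz @ y.T  # shape (3, 4), matches x
ygrad = x.T @ gz  # shape (4, 5), matches y

assert xgrad.shape == x.shape
assert ygrad.shape == y.shape
```
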
@@ -3116,24 +3076,7 @@ def R_op(self, inputs, eval_points):
 
     def infer_shape(self, fgraph, node, shapes):
         xshp, yshp = shapes
-        x, y = node.inputs
-
-        # vector / vector
-        if x.ndim == 1 and y.ndim == 1:
-            return [()]
-        # matrix / vector
-        if x.ndim == 2 and y.ndim == 1:
-            return [xshp[:-1]]
-        # vector / matrix
-        if x.ndim == 1 and y.ndim == 2:
-            return [yshp[-1:]]
-        # matrix / matrix
-        if x.ndim == 2 and y.ndim == 2:
-            return [xshp[:-1] + yshp[-1:]]
-        raise NotImplementedError()
-
-    def __str__(self):
-        return "dot"
+        return [[xshp[0], yshp[1]]]
 
 
 _dot = Dot()
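
With the core Op strictly 2D, `infer_shape` can simply combine the leading dimension of `x` with the trailing dimension of `y`. A hedged check of the resulting static shapes (assumes the public `pytensor.tensor` API used elsewhere in this file):

```python
import pytensor.tensor as pt

x = pt.matrix("x", shape=(3, None))
y = pt.matrix("y", shape=(None, 5))
z = pt.dot(x, y)
print(z.type.shape)  # (3, 5): row count taken from x, column count from y
```
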
@@ -3215,7 +3158,24 @@ def dense_dot(a, b):
     elif a.ndim > 2 or b.ndim > 2:
         return tensordot(a, b, [[a.ndim - 1], [np.maximum(0, b.ndim - 2)]])
     else:
-        return _dot(a, b)
+        row_vector = a.ndim == 1
+        if row_vector:
+            # Promote to row matrix
+            a = a[None]
+
+        col_vector = b.ndim == 1
+        if col_vector:
+            # Promote to column matrix
+            b = b[:, None]
+
+        out = _dot(a, b)
+        if row_vector:
+            # If we promoted a to a row matrix, we need to squeeze the first dimension
+            out = out.squeeze(0)
+        if col_vector:
+            # If we promoted b to a column matrix, we need to squeeze the last dimension
+            out = out.squeeze(-1)
+        return out
 
 
 def tensordot(
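
The promote-then-squeeze handling of vector operands in `dense_dot` above mirrors NumPy's `matmul` semantics. A rough NumPy-only sketch of the same idea (the helper name is made up for illustration):

```python
import numpy as np

def dot_via_matmul(a, b):
    # Promote 1D operands to matrices, multiply, then drop the added axes,
    # mimicking the promote/squeeze logic in dense_dot above.
    a = np.asarray(a)
    b = np.asarray(b)
    row_vector = a.ndim == 1
    col_vector = b.ndim == 1
    a2 = a[None] if row_vector else a
    b2 = b[:, None] if col_vector else b
    out = a2 @ b2
    if row_vector:
        out = out.squeeze(0)
    if col_vector:
        out = out.squeeze(-1)
    return out

v = np.arange(3.0)
m = np.arange(12.0).reshape(3, 4)
assert dot_via_matmul(v, m).shape == (4,)       # vector @ matrix -> vector
assert np.isclose(dot_via_matmul(v, v), v @ v)  # vector @ vector -> scalar
```
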
@@ -3921,11 +3881,7 @@ def logsumexp(x, axis=None, keepdims=False):
     return log(sum(exp(x), axis=axis, keepdims=keepdims))
 
 
-_matmul = Blockwise(
-    _dot,
-    signature="(m,k),(k,n)->(m,n)",
-    gufunc_spec=("numpy.matmul", 2, 1),
-)
+_matmul = Blockwise(_dot, name="Matmul")
 
 
 def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
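
Since `Dot` now carries `gufunc_signature` and `gufunc_spec`, the `Blockwise` wrapper no longer repeats them at the call site. A hedged example of the batched behaviour this produces (tensor names are illustrative):

```python
import pytensor.tensor as pt

a = pt.tensor("a", shape=(7, 3, 4))  # batch of 7 (3, 4) matrices
b = pt.tensor("b", shape=(7, 4, 5))  # batch of 7 (4, 5) matrices
c = pt.matmul(a, b)                  # batched product via the Blockwise-wrapped Dot
print(c.type.shape)  # (7, 3, 5)
```
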
@@ -3975,7 +3931,7 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
     if x1.type.ndim == 0 or x2.type.ndim == 0:
         raise ValueError("matmul operand cannot be scalar")
     if x1.type.ndim == 1 and x2.type.ndim == 1:
-        out = _dot(x1, x2)
+        out = vecdot(x1, x2)
     elif x1.type.ndim == 1:
         out = vecmat(x1, x2)
     elif x2.type.ndim == 1:
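
Because the core `Dot` now rejects vectors, the 1D/1D branch of `matmul` routes through `vecdot` instead of `_dot`. A small usage sketch (variable names are illustrative):

```python
import pytensor.tensor as pt

v = pt.vector("v", shape=(3,))
w = pt.vector("w", shape=(3,))
out = pt.matmul(v, w)
print(out.type.ndim)  # 0: the 1D/1D case is an inner product, handled by vecdot
```
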
@@ -4139,23 +4095,7 @@ def vecmat(
 
 @_vectorize_node.register(Dot)
 def vectorize_node_dot(op, node, batched_x, batched_y):
-    old_x, old_y = node.inputs
-    old_x_ndim = old_x.type.ndim
-    old_y_ndim = old_y.type.ndim
-    match (old_x_ndim, old_y_ndim):
-        case (1, 1):
-            batch_fn = vecdot
-        case (2, 1):
-            batch_fn = matvec
-        case (1, 2):
-            batch_fn = vecmat
-        case (2, 2):
-            batch_fn = matmul
-        case _:
-            raise ValueError(
-                f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
-            )
-    return batch_fn(batched_x, batched_y).owner
+    return matmul(batched_x, batched_y).owner
 
 
 def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
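
Vectorizing the strictly-2D `Dot` is now just `matmul`: adding leading batch dimensions to a `(m,n),(n,p)->(m,p)` core is exactly what matmul broadcasting does. A quick NumPy illustration of that broadcasting rule:

```python
import numpy as np

x = np.ones((7, 3, 4))  # batch of 7 (3, 4) matrices
y = np.ones((4, 5))     # a single (4, 5) matrix, broadcast across the batch
assert np.matmul(x, y).shape == (7, 3, 5)
```
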