@@ -22,13 +22,15 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
2222 x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
2323 return torch_linalg .cross (x1 , x2 , dim = axis )
2424
25- def vecdot (x1 : array , x2 : array , / , * , axis : int = - 1 ) -> array :
25+ def vecdot (x1 : array , x2 : array , / , * , axis : int = - 1 , ** kwargs ) -> array :
2626 from ._aliases import isdtype
2727
2828 x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
2929
3030 # torch.linalg.vecdot doesn't support integer dtypes
3131 if isdtype (x1 .dtype , 'integral' ) or isdtype (x2 .dtype , 'integral' ):
32+ if kwargs :
33+ raise RuntimeError ("vecdot kwargs not supported for integral dtypes" )
3234 ndim = max (x1 .ndim , x2 .ndim )
3335 x1_shape = (1 ,)* (ndim - x1 .ndim ) + tuple (x1 .shape )
3436 x2_shape = (1 ,)* (ndim - x2 .ndim ) + tuple (x2 .shape )
@@ -41,7 +43,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
4143
4244 res = x1_ [..., None , :] @ x2_ [..., None ]
4345 return res [..., 0 , 0 ]
44- return torch .linalg .vecdot (x1 , x2 , axis = axis )
46+ return torch .linalg .vecdot (x1 , x2 , dim = axis , ** kwargs )
4547
4648__all__ = linalg_all + ['outer' , 'trace' , 'matrix_transpose' , 'tensordot' , 'vecdot' ]
4749
0 commit comments