@@ -5,7 +5,8 @@
 import torch
 array = torch.Tensor
 from torch import dtype as Dtype
-from typing import Optional
+from typing import Optional, Union, Tuple, Literal
+inf = float('inf')
 
 from ._aliases import _fix_promotion, sum
 
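The module-level `inf = float('inf')` added in this hunk exists so that the `Literal[inf, -inf]` annotation on the new `vector_norm` further down can be spelled at all. A minimal standalone check (the `OrdType` alias is mine, not part of the commit) that this spelling is constructible at runtime; `typing.Literal` does not validate its parameters, even though static type checkers only accept literal ints, strings, bools, enums, and None:

from typing import Literal, Union

inf = float('inf')
# Same spelling as the annotation in the diff. This runs, even though it is
# outside what PEP 586 allows static type checkers to accept.
OrdType = Union[int, float, Literal[inf, -inf]]
print(OrdType)  # typing.Union[int, float, typing.Literal[inf, -inf]]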
@@ -23,28 +24,35 @@
 
 # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
 # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
+
+# torch.cross also does not support broadcasting when it would add new
+# dimensions https://github.com/pytorch/pytorch/issues/39656
 def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+    if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
+        raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
+    if not (x1.shape[axis] == x2.shape[axis] == 3):
+        raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
+    x1, x2 = torch.broadcast_tensors(x1, x2)
     return torch_linalg.cross(x1, x2, dim=axis)
 
 def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
     from ._aliases import isdtype
 
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
 
+    # torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
+    if x1.shape[axis] != x2.shape[axis]:
+        raise ValueError("x1 and x2 must have the same size along the given axis")
+
     # torch.linalg.vecdot doesn't support integer dtypes
     if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
         if kwargs:
             raise RuntimeError("vecdot kwargs not supported for integral dtypes")
-        ndim = max(x1.ndim, x2.ndim)
-        x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
-        x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
-        if x1_shape[axis] != x2_shape[axis]:
-            raise ValueError("x1 and x2 must have the same size along the given axis")
 
-        x1_, x2_ = torch.broadcast_tensors(x1, x2)
-        x1_ = torch.moveaxis(x1_, axis, -1)
-        x2_ = torch.moveaxis(x2_, axis, -1)
+        x1_ = torch.moveaxis(x1, axis, -1)
+        x2_ = torch.moveaxis(x2, axis, -1)
+        x1_, x2_ = torch.broadcast_tensors(x1_, x2_)
 
         res = x1_[..., None, :] @ x2_[..., None]
         return res[..., 0, 0]
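A quick standalone sanity check of the integer-dtype fallback above (the name `vecdot_int_sketch` and the test tensors are mine, not part of the commit). It exercises the moveaxis-then-broadcast ordering this hunk switches to; broadcasting first, as the removed lines did, can change which dimension a nonnegative `axis` refers to once broadcasting prepends new leading dimensions:

import torch

def vecdot_int_sketch(x1, x2, axis=-1):
    # Mirror of the integral-dtype branch: move the contracted axis last
    # *before* broadcasting, then express the dot product as a batched
    # (..., 1, n) @ (..., n, 1) matmul, which torch supports for integers.
    x1_ = torch.moveaxis(x1, axis, -1)
    x2_ = torch.moveaxis(x2, axis, -1)
    x1_, x2_ = torch.broadcast_tensors(x1_, x2_)
    res = x1_[..., None, :] @ x2_[..., None]
    return res[..., 0, 0]

x1 = torch.arange(6).reshape(2, 3)   # int64, shape (2, 3)
x2 = torch.arange(3)                 # broadcasts to (2, 3)
assert torch.equal(vecdot_int_sketch(x1, x2), (x1 * x2).sum(-1))  # tensor([ 5, 14])

Hoisting the size check out of the integral branch also means the same ValueError now fires for floating-point inputs, which torch.linalg.vecdot would otherwise silently broadcast along the contracted dimension.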
@@ -59,8 +67,22 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
     # Use our wrapped sum to make sure it does upcasting correctly
     return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
 
-__all__ = linalg_all + ['outer', 'trace', 'matmul', 'matrix_transpose', 'tensordot',
-    'vecdot', 'solve']
+def vector_norm(
+    x: array,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+    ord: Union[int, float, Literal[inf, -inf]] = 2,
+    **kwargs,
+) -> array:
+    # torch.vector_norm incorrectly treats axis=() the same as axis=None
+    if axis == ():
+        keepdims = True
+    return torch.linalg.vector_norm(x, ord=ord, dim=axis, keepdim=keepdims, **kwargs)
+
+__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
+    'cross', 'vecdot', 'solve', 'trace', 'vector_norm']
 
 _all_ignore = ['torch_linalg', 'sum']
 
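For reference, a small illustration of the two torch quirks the new wrappers work around (the tensor values are mine; the broadcast-then-cross call is what the wrapped `cross` now does internally):

import torch

# cross: torch does not broadcast when new dimensions would have to be added
# (pytorch#39656), so the wrapper broadcasts explicitly before calling it.
a = torch.tensor([1.0, 0.0, 0.0])        # shape (3,)
b = torch.ones(2, 3)                     # shape (2, 3)
a_, b_ = torch.broadcast_tensors(a, b)   # both (2, 3)
print(torch.linalg.cross(a_, b_, dim=-1).shape)  # torch.Size([2, 3])

# vector_norm: per the comment in the diff, torch treats dim=() the same as
# dim=None, i.e. it reduces over every axis rather than over no axes.
x = torch.tensor([[3.0, -4.0]])
print(torch.linalg.vector_norm(x, dim=()))  # tensor(5.), fully reduced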