@@ -60,6 +60,22 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
6060
6161def solve (x1 : array , x2 : array , / , ** kwargs ) -> array :
6262 x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
63+ # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
64+ # whenever
65+ # 1. x1.ndim - 1 == x2.ndim
66+ # 2. x1.shape[:-1] == x2.shape
67+ #
68+ # See linalg_solve_is_vector_rhs in
69+ # aten/src/ATen/native/LinearAlgebraUtils.h and
70+ # TORCH_META_FUNC(_linalg_solve_ex) in
71+ # aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code.
72+ #
73+ # The easiest way to work around this is to prepend a size 1 dimension to
74+ # x2, since x2 is already one dimension less than x1.
75+ #
76+ # See https://github.com/pytorch/pytorch/issues/52915
77+ if x2 .ndim != 1 and x1 .ndim - 1 == x2 .ndim and x1 .shape [:- 1 ] == x2 .shape :
78+ x2 = x2 [None ]
6379 return torch .linalg .solve (x1 , x2 , ** kwargs )
6480
6581# torch.trace doesn't support the offset argument and doesn't support stacking
@@ -78,7 +94,23 @@ def vector_norm(
7894) -> array :
7995 # torch.vector_norm incorrectly treats axis=() the same as axis=None
8096 if axis == ():
81- keepdims = True
97+ out = kwargs .get ('out' )
98+ if out is None :
99+ dtype = None
100+ if x .dtype == torch .complex64 :
101+ dtype = torch .float32
102+ elif x .dtype == torch .complex128 :
103+ dtype = torch .float64
104+
105+ out = torch .zeros_like (x , dtype = dtype )
106+
107+ # The norm of a single scalar works out to abs(x) in every case except
108+ # for ord=0, which is x != 0.
109+ if ord == 0 :
110+ out [:] = (x != 0 )
111+ else :
112+ out [:] = torch .abs (x )
113+ return out
82114 return torch .linalg .vector_norm (x , ord = ord , axis = axis , keepdim = keepdims , ** kwargs )
83115
84116__all__ = linalg_all + ['outer' , 'matmul' , 'matrix_transpose' , 'tensordot' ,
0 commit comments