@@ -45,12 +45,44 @@ def slogdet(x: ndarray, /) -> SlogdetResult:
4545def svd (x : ndarray , / , * , full_matrices : bool = True ) -> SVDResult :
4646 return SVDResult (* np .linalg .svd (x , full_matrices = full_matrices ))
4747
48- # This function is not in NumPy.
48+ # These functions have additional keyword arguments
49+
50+ # The upper keyword argument is new from NumPy
51+ def cholesky (x : ndarray , / , * , upper : bool = False ) -> ndarray :
52+ L = np .linalg .cholesky (x )
53+ if upper :
54+ return matrix_transpose (L )
55+ return L
56+
57+ # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.
58+ # Note that it has a different semantic meaning from tol and rcond.
59+ def matrix_rank (x : ndarray , / , * , rtol : Optional [Union [float , ndarray ]] = None ) -> ndarray :
60+ # this is different from np.linalg.matrix_rank, which supports 1
61+ # dimensional arrays.
62+ if x .ndim < 2 :
63+ raise np .linalg .LinAlgError ("1-dimensional array given. Array must be at least two-dimensional" )
64+ S = np .linalg .svd (x , compute_uv = False )
65+ if rtol is None :
66+ tol = S .max (axis = - 1 , keepdims = True ) * max (x .shape [- 2 :]) * np .finfo (S .dtype ).eps
67+ else :
68+ # this is different from np.linalg.matrix_rank, which does not
69+ # multiply the tolerance by the largest singular value.
70+ tol = S .max (axis = - 1 , keepdims = True )* np .asarray (rtol )[..., np .newaxis ]
71+ return np .count_nonzero (S > tol , axis = - 1 )
72+
73+ def pinv (x : ndarray , / , * , rtol : Optional [Union [float , ndarray ]] = None ) -> ndarray :
74+ # this is different from np.linalg.pinv, which does not multiply the
75+ # default tolerance by max(M, N).
76+ if rtol is None :
77+ rtol = max (x .shape [- 2 :]) * np .finfo (x .dtype ).eps
78+ return np .linalg .pinv (x , rcond = rtol )
79+
80+ # These functions are new in the array API spec
81+
4982def matrix_norm (x : ndarray , / , * , keepdims : bool = False , ord : Optional [Union [int , float , Literal ['fro' , 'nuc' ]]] = 'fro' ) -> ndarray :
5083 return np .linalg .norm (x , axis = (- 2 , - 1 ), keepdims = keepdims , ord = ord )
5184
52- # This function is new in the array API spec. Unlike transpose, it only
53- # transposes the last two axes.
85+ # Unlike transpose, matrix_transpose only transposes the last two axes.
5486def matrix_transpose (x : ndarray , / ) -> ndarray :
5587 if x .ndim < 2 :
5688 raise ValueError ("x must be at least 2-dimensional for matrix_transpose" )
@@ -61,7 +93,6 @@ def matrix_transpose(x: ndarray, /) -> ndarray:
6193def svdvals (x : ndarray , / ) -> Union [ndarray , Tuple [ndarray , ...]]:
6294 return np .linalg .svd (x , compute_uv = False )
6395
64- # vecdot is not in NumPy
6596def vecdot (x1 : ndarray , x2 : ndarray , / , * , axis : int = - 1 ) -> ndarray :
6697 ndim = max (x1 .ndim , x2 .ndim )
6798 x1_shape = (1 ,)* (ndim - x1 .ndim ) + tuple (x1 .shape )
@@ -111,6 +142,7 @@ def vector_norm(x: ndarray, /, *, axis: Optional[Union[int, Tuple[int, ...]]] =
111142 return res
112143
113144__all__ = linalg_all .copy ()
114- __all__ += ['cross' , 'diagonal' , 'matmul' , 'matrix_norm' , 'matrix_transpose' ,
115- 'outer' , 'svdvals' , 'tensordot' , 'trace' , 'vecdot' , 'vector_norm' ,
116- 'EighResult' , 'QRResult' , 'SlogdetResult' , 'SVDResult' ]
145+ __all__ += ['cross' , 'diagonal' , 'matmul' , 'cholesky' , 'matrix_rank' , 'pinv' ,
146+ 'matrix_norm' , 'matrix_transpose' , 'outer' , 'svdvals' ,
147+ 'tensordot' , 'trace' , 'vecdot' , 'vector_norm' , 'EighResult' ,
148+ 'QRResult' , 'SlogdetResult' , 'SVDResult' ]
0 commit comments