11from __future__ import annotations
22
33import math
4- from typing import Literal , NamedTuple , Optional , Tuple , Union
4+ from typing import Literal , NamedTuple , cast
55
66import numpy as np
7+
78if np .__version__ [0 ] == "2" :
89 from numpy .lib .array_utils import normalize_axis_tuple
910else :
1011 from numpy .core .numeric import normalize_axis_tuple
1112
12- from ._aliases import matmul , matrix_transpose , tensordot , vecdot , isdtype
1313from .._internal import get_xp
14- from ._typing import Array , Namespace
14+ from ._aliases import isdtype , matmul , matrix_transpose , tensordot , vecdot
15+ from ._typing import Array , DType , Namespace
16+
1517
1618# These are in the main NumPy namespace but not in numpy.linalg
17- def cross (x1 : Array , x2 : Array , / , xp : Namespace , * , axis : int = - 1 , ** kwargs ) -> Array :
19+ def cross (
20+ x1 : Array ,
21+ x2 : Array ,
22+ / ,
23+ xp : Namespace ,
24+ * ,
25+ axis : int = - 1 ,
26+ ** kwargs : object ,
27+ ) -> Array :
1828 return xp .cross (x1 , x2 , axis = axis , ** kwargs )
1929
20- def outer (x1 : Array , x2 : Array , / , xp : Namespace , ** kwargs ) -> Array :
30+ def outer (x1 : Array , x2 : Array , / , xp : Namespace , ** kwargs : object ) -> Array :
2131 return xp .outer (x1 , x2 , ** kwargs )
2232
2333class EighResult (NamedTuple ):
@@ -39,46 +49,66 @@ class SVDResult(NamedTuple):
3949
4050# These functions are the same as their NumPy counterparts except they return
4151# a namedtuple.
42- def eigh (x : Array , / , xp : Namespace , ** kwargs ) -> EighResult :
52+ def eigh (x : Array , / , xp : Namespace , ** kwargs : object ) -> EighResult :
4353 return EighResult (* xp .linalg .eigh (x , ** kwargs ))
4454
45- def qr (x : Array , / , xp : Namespace , * , mode : Literal ['reduced' , 'complete' ] = 'reduced' ,
46- ** kwargs ) -> QRResult :
55+ def qr (
56+ x : Array ,
57+ / ,
58+ xp : Namespace ,
59+ * ,
60+ mode : Literal ["reduced" , "complete" ] = "reduced" ,
61+ ** kwargs : object ,
62+ ) -> QRResult :
4763 return QRResult (* xp .linalg .qr (x , mode = mode , ** kwargs ))
4864
49- def slogdet (x : Array , / , xp : Namespace , ** kwargs ) -> SlogdetResult :
65+ def slogdet (x : Array , / , xp : Namespace , ** kwargs : object ) -> SlogdetResult :
5066 return SlogdetResult (* xp .linalg .slogdet (x , ** kwargs ))
5167
5268def svd (
53- x : Array , / , xp : Namespace , * , full_matrices : bool = True , ** kwargs
69+ x : Array ,
70+ / ,
71+ xp : Namespace ,
72+ * ,
73+ full_matrices : bool = True ,
74+ ** kwargs : object ,
5475) -> SVDResult :
5576 return SVDResult (* xp .linalg .svd (x , full_matrices = full_matrices , ** kwargs ))
5677
5778# These functions have additional keyword arguments
5879
5980# The upper keyword argument is new from NumPy
60- def cholesky (x : Array , / , xp : Namespace , * , upper : bool = False , ** kwargs ) -> Array :
81+ def cholesky (
82+ x : Array ,
83+ / ,
84+ xp : Namespace ,
85+ * ,
86+ upper : bool = False ,
87+ ** kwargs : object ,
88+ ) -> Array :
6189 L = xp .linalg .cholesky (x , ** kwargs )
6290 if upper :
6391 U = get_xp (xp )(matrix_transpose )(L )
6492 if get_xp (xp )(isdtype )(U .dtype , 'complex floating' ):
65- U = xp .conj (U )
93+ U = xp .conj (U ) # pyright: ignore[reportConstantRedefinition]
6694 return U
6795 return L
6896
6997# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.
7098# Note that it has a different semantic meaning from tol and rcond.
71- def matrix_rank (x : Array ,
72- / ,
73- xp : Namespace ,
74- * ,
75- rtol : Optional [Union [float , Array ]] = None ,
76- ** kwargs ) -> Array :
99+ def matrix_rank (
100+ x : Array ,
101+ / ,
102+ xp : Namespace ,
103+ * ,
104+ rtol : float | Array | None = None ,
105+ ** kwargs : object ,
106+ ) -> Array :
77107 # this is different from xp.linalg.matrix_rank, which supports 1
78108 # dimensional arrays.
79109 if x .ndim < 2 :
80110 raise xp .linalg .LinAlgError ("1-dimensional array given. Array must be at least two-dimensional" )
81- S = get_xp (xp )(svdvals )(x , ** kwargs )
111+ S : Array = get_xp (xp )(svdvals )(x , ** kwargs )
82112 if rtol is None :
83113 tol = S .max (axis = - 1 , keepdims = True ) * max (x .shape [- 2 :]) * xp .finfo (S .dtype ).eps
84114 else :
@@ -88,7 +118,12 @@ def matrix_rank(x: Array,
88118 return xp .count_nonzero (S > tol , axis = - 1 )
89119
90120def pinv (
91- x : Array , / , xp : Namespace , * , rtol : Optional [Union [float , Array ]] = None , ** kwargs
121+ x : Array ,
122+ / ,
123+ xp : Namespace ,
124+ * ,
125+ rtol : float | Array | None = None ,
126+ ** kwargs : object ,
92127) -> Array :
93128 # this is different from xp.linalg.pinv, which does not multiply the
94129 # default tolerance by max(M, N).
@@ -104,23 +139,23 @@ def matrix_norm(
104139 xp : Namespace ,
105140 * ,
106141 keepdims : bool = False ,
107- ord : Optional [ Union [ int , float , Literal [' fro' , ' nuc' ]]] = ' fro' ,
142+ ord : float | Literal [" fro" , " nuc" ] | None = " fro" ,
108143) -> Array :
109144 return xp .linalg .norm (x , axis = (- 2 , - 1 ), keepdims = keepdims , ord = ord )
110145
111146# svdvals is not in NumPy (but it is in SciPy). It is equivalent to
112147# xp.linalg.svd(compute_uv=False).
113- def svdvals (x : Array , / , xp : Namespace ) -> Union [ Array , Tuple [Array , ...] ]:
148+ def svdvals (x : Array , / , xp : Namespace ) -> Array | tuple [Array , ...]:
114149 return xp .linalg .svd (x , compute_uv = False )
115150
116151def vector_norm (
117152 x : Array ,
118153 / ,
119154 xp : Namespace ,
120155 * ,
121- axis : Optional [ Union [ int , Tuple [int , ...]]] = None ,
156+ axis : int | tuple [int , ...] | None = None ,
122157 keepdims : bool = False ,
123- ord : Optional [ Union [ int , float ]] = 2 ,
158+ ord : float = 2 ,
124159) -> Array :
125160 # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
126161 # when axis=None and the input is 2-D, so to force a vector norm, we make
@@ -133,7 +168,10 @@ def vector_norm(
133168 elif isinstance (axis , tuple ):
134169 # Note: The axis argument supports any number of axes, whereas
135170 # xp.linalg.norm() only supports a single axis for vector norm.
136- normalized_axis = normalize_axis_tuple (axis , x .ndim )
171+ normalized_axis = cast (
172+ "tuple[int, ...]" ,
173+ normalize_axis_tuple (axis , x .ndim ), # pyright: ignore[reportCallIssue]
174+ )
137175 rest = tuple (i for i in range (x .ndim ) if i not in normalized_axis )
138176 newshape = axis + rest
139177 _x = xp .transpose (x , newshape ).reshape (
@@ -149,7 +187,13 @@ def vector_norm(
149187 # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
150188 # above to avoid matrix norm logic.
151189 shape = list (x .shape )
152- _axis = normalize_axis_tuple (range (x .ndim ) if axis is None else axis , x .ndim )
190+ _axis = cast (
191+ "tuple[int, ...]" ,
192+ normalize_axis_tuple ( # pyright: ignore[reportCallIssue]
193+ range (x .ndim ) if axis is None else axis ,
194+ x .ndim ,
195+ ),
196+ )
153197 for i in _axis :
154198 shape [i ] = 1
155199 res = xp .reshape (res , tuple (shape ))
@@ -159,11 +203,17 @@ def vector_norm(
159203# xp.diagonal and xp.trace operate on the first two axes whereas these
160204# operates on the last two
161205
162- def diagonal (x : Array , / , xp : Namespace , * , offset : int = 0 , ** kwargs ) -> Array :
206+ def diagonal (x : Array , / , xp : Namespace , * , offset : int = 0 , ** kwargs : object ) -> Array :
163207 return xp .diagonal (x , offset = offset , axis1 = - 2 , axis2 = - 1 , ** kwargs )
164208
165209def trace (
166- x : Array , / , xp : Namespace , * , offset : int = 0 , dtype = None , ** kwargs
210+ x : Array ,
211+ / ,
212+ xp : Namespace ,
213+ * ,
214+ offset : int = 0 ,
215+ dtype : DType | None = None ,
216+ ** kwargs : object ,
167217) -> Array :
168218 return xp .asarray (
169219 xp .trace (x , offset = offset , dtype = dtype , axis1 = - 2 , axis2 = - 1 , ** kwargs )
0 commit comments