1- from typing import Literal , TypeAlias , overload
2- from typing_extensions import TypeVar
1+ from typing import Literal , TypeAlias , TypeVar , overload
32
43import numpy as np
54import optype as op
@@ -8,87 +7,90 @@ import optype.numpy.compat as npc
87
98__all__ = ["diagsvd" , "null_space" , "orth" , "subspace_angles" , "svd" , "svdvals" ]
109
11- _T = TypeVar ("_T" )
12- _Tuple3 : TypeAlias = tuple [_T , _T , _T ]
13-
14- _Float : TypeAlias = np .float32 | np .float64
15- _FloatND : TypeAlias = onp .ArrayND [_Float ]
16-
17- _Complex : TypeAlias = np .complex64 | np .complex128
18-
19- _LapackDriver : TypeAlias = Literal ["gesdd" , "gesvd" ]
20-
2110_RealT = TypeVar ("_RealT" , bound = np .bool_ | npc .integer | npc .floating )
2211_InexactT = TypeVar ("_InexactT" , bound = _Float | _Complex )
12+ _ScalarT = TypeVar ("_ScalarT" , bound = np .generic )
13+ _ScalarT1 = TypeVar ("_ScalarT1" , bound = np .generic )
14+
15+ _SVD_ND : TypeAlias = tuple [onp .ArrayND [_ScalarT ], onp .ArrayND [_ScalarT1 ], onp .ArrayND [_ScalarT ]]
16+
17+ _Float : TypeAlias = np .float64 | np .float32
18+ _Complex : TypeAlias = np .complex128 | np .complex64
2319
2420_as_f32 : TypeAlias = np .float32 | np .float16 # noqa: PYI042
2521_as_f64 : TypeAlias = np .longdouble | np .float64 | npc .integer | np .bool_ # noqa: PYI042
22+ _as_c128 : TypeAlias = np .complex128 | np .clongdouble # noqa: PYI042
23+
24+ _ToSafeFloat64ND : TypeAlias = onp .ToArrayND [float , np .float64 | npc .integer | np .bool_ ]
25+ _ToArrayND : TypeAlias = onp .CanArrayND [_ScalarT ] | onp .SequenceND [_ScalarT ]
26+
27+ _LapackDriver : TypeAlias = Literal ["gesdd" , "gesvd" ]
2628
2729###
2830
2931@overload # nd float64
3032def svd (
3133 a : onp .ToArrayND [float , _as_f64 ],
32- full_matrices : onp . ToBool = True ,
33- compute_uv : onp . ToTrue = True ,
34- overwrite_a : onp . ToBool = False ,
35- check_finite : onp . ToBool = True ,
34+ full_matrices : bool = True ,
35+ compute_uv : Literal [ True ] = True ,
36+ overwrite_a : bool = False ,
37+ check_finite : bool = True ,
3638 lapack_driver : _LapackDriver = "gesdd" ,
37- ) -> _Tuple3 [ onp . ArrayND [ np .float64 ] ]: ...
39+ ) -> _SVD_ND [ np . float64 , np .float64 ]: ...
3840@overload # nd float32
3941def svd (
40- a : onp .ToArrayND [ _as_f32 , _as_f32 ],
41- full_matrices : onp . ToBool = True ,
42- compute_uv : onp . ToTrue = True ,
43- overwrite_a : onp . ToBool = False ,
44- check_finite : onp . ToBool = True ,
42+ a : onp .CanArrayND [ _as_f32 ],
43+ full_matrices : bool = True ,
44+ compute_uv : Literal [ True ] = True ,
45+ overwrite_a : bool = False ,
46+ check_finite : bool = True ,
4547 lapack_driver : _LapackDriver = "gesdd" ,
46- ) -> _Tuple3 [ onp . ArrayND [ np .float32 ] ]: ...
48+ ) -> _SVD_ND [ np . float32 , np .float32 ]: ...
4749@overload # nd complex128
4850def svd (
49- a : onp .ToArrayND [op .JustComplex , np . complex128 | np . clongdouble ],
50- full_matrices : onp . ToBool = True ,
51- compute_uv : onp . ToTrue = True ,
52- overwrite_a : onp . ToBool = False ,
53- check_finite : onp . ToBool = True ,
51+ a : onp .ToArrayND [op .JustComplex , _as_c128 ],
52+ full_matrices : bool = True ,
53+ compute_uv : Literal [ True ] = True ,
54+ overwrite_a : bool = False ,
55+ check_finite : bool = True ,
5456 lapack_driver : _LapackDriver = "gesdd" ,
55- ) -> tuple [ onp . ArrayND [ np .complex128 ], onp . ArrayND [ np .float64 ], onp . ArrayND [ np . complex128 ] ]: ...
57+ ) -> _SVD_ND [ np .complex128 , np .float64 ]: ...
5658@overload # nd complex64
5759def svd (
58- a : onp .ToArrayND [ np . complex64 , np .complex64 ],
59- full_matrices : onp . ToBool = True ,
60- compute_uv : onp . ToTrue = True ,
61- overwrite_a : onp . ToBool = False ,
62- check_finite : onp . ToBool = True ,
60+ a : onp .CanArrayND [ np .complex64 ],
61+ full_matrices : bool = True ,
62+ compute_uv : Literal [ True ] = True ,
63+ overwrite_a : bool = False ,
64+ check_finite : bool = True ,
6365 lapack_driver : _LapackDriver = "gesdd" ,
64- ) -> tuple [ onp . ArrayND [ np .complex64 ], onp . ArrayND [ np .float32 ], onp . ArrayND [ np . complex64 ] ]: ...
66+ ) -> _SVD_ND [ np .complex64 , np .float32 ]: ...
6567@overload # nd float64 | complex128, compute_uv=False (keyword)
6668def svd (
67- a : onp .ToArrayND [complex , _as_f64 | np . complex128 | np . clongdouble ],
68- full_matrices : onp . ToBool = True ,
69+ a : onp .ToArrayND [complex , _as_f64 | _as_c128 ],
70+ full_matrices : bool = True ,
6971 * ,
70- compute_uv : onp . ToFalse ,
71- overwrite_a : onp . ToBool = False ,
72- check_finite : onp . ToBool = True ,
72+ compute_uv : Literal [ False ] ,
73+ overwrite_a : bool = False ,
74+ check_finite : bool = True ,
7375 lapack_driver : _LapackDriver = "gesdd" ,
7476) -> onp .ArrayND [np .float64 ]: ...
7577@overload # nd float32 | complex64, compute_uv=False (keyword)
7678def svd (
77- a : onp .ToArrayND [ _as_f32 , _as_f32 | np .complex64 ],
78- full_matrices : onp . ToBool = True ,
79+ a : onp .CanArrayND [ _as_f32 | np .complex64 ],
80+ full_matrices : bool = True ,
7981 * ,
80- compute_uv : onp . ToFalse ,
81- overwrite_a : onp . ToBool = False ,
82- check_finite : onp . ToBool = True ,
82+ compute_uv : Literal [ False ] ,
83+ overwrite_a : bool = False ,
84+ check_finite : bool = True ,
8385 lapack_driver : _LapackDriver = "gesdd" ,
8486) -> onp .ArrayND [np .float32 ]: ...
8587
8688#
87- def svdvals (a : onp .ToComplexND , overwrite_a : onp . ToBool = False , check_finite : onp . ToBool = True ) -> _FloatND : ...
89+ def svdvals (a : onp .ToComplexND , overwrite_a : bool = False , check_finite : bool = True ) -> onp . ArrayND [ np . float64 | np . float32 ] : ...
8890
8991#
9092@overload
91- def diagsvd (s : onp . SequenceND [ _RealT ] | onp . CanArrayND [_RealT ], M : op .CanIndex , N : op .CanIndex ) -> onp .ArrayND [_RealT ]: ...
93+ def diagsvd (s : _ToArrayND [_RealT ], M : op .CanIndex , N : op .CanIndex ) -> onp .ArrayND [_RealT ]: ...
9294@overload
9395def diagsvd (s : onp .SequenceND [bool ], M : op .CanIndex , N : op .CanIndex ) -> onp .ArrayND [np .bool_ ]: ...
9496@overload
@@ -98,42 +100,40 @@ def diagsvd(s: onp.SequenceND[op.JustFloat], M: op.CanIndex, N: op.CanIndex) ->
98100
99101#
100102@overload
101- def orth (A : onp . ToIntND | onp . ToJustFloat64_ND , rcond : onp . ToFloat | None = None ) -> onp .ArrayND [np .float64 ]: ...
103+ def orth (A : _ToSafeFloat64ND , rcond : float | None = None ) -> onp .ArrayND [np .float64 ]: ...
102104@overload
103- def orth (A : onp .ToJustComplex128_ND , rcond : onp . ToFloat | None = None ) -> onp .ArrayND [np .complex128 ]: ...
105+ def orth (A : onp .ToJustComplex128_ND , rcond : float | None = None ) -> onp .ArrayND [np .complex128 ]: ...
104106@overload
105- def orth (
106- A : onp .SequenceND [_InexactT ] | onp .CanArrayND [_InexactT ], rcond : onp .ToFloat | None = None
107- ) -> onp .ArrayND [_InexactT ]: ...
107+ def orth (A : _ToArrayND [_InexactT ], rcond : float | None = None ) -> onp .ArrayND [_InexactT ]: ...
108108
109109#
110110@overload
111111def null_space (
112- A : onp . ToIntND | onp . ToJustFloat64_ND ,
113- rcond : onp . ToFloat | None = None ,
112+ A : _ToSafeFloat64ND ,
113+ rcond : float | None = None ,
114114 * ,
115- overwrite_a : onp . ToBool = False ,
116- check_finite : onp . ToBool = True ,
115+ overwrite_a : bool = False ,
116+ check_finite : bool = True ,
117117 lapack_driver : _LapackDriver = "gesdd" ,
118118) -> onp .ArrayND [np .float64 ]: ...
119119@overload
120120def null_space (
121121 A : onp .ToJustComplex128_ND ,
122- rcond : onp . ToFloat | None = None ,
122+ rcond : float | None = None ,
123123 * ,
124- overwrite_a : onp . ToBool = False ,
125- check_finite : onp . ToBool = True ,
124+ overwrite_a : bool = False ,
125+ check_finite : bool = True ,
126126 lapack_driver : _LapackDriver = "gesdd" ,
127127) -> onp .ArrayND [np .complex128 ]: ...
128128@overload
129129def null_space (
130- A : onp . SequenceND [ _InexactT ] | onp . CanArrayND [_InexactT ],
131- rcond : onp . ToFloat | None = None ,
130+ A : _ToArrayND [_InexactT ],
131+ rcond : float | None = None ,
132132 * ,
133- overwrite_a : onp . ToBool = False ,
134- check_finite : onp . ToBool = True ,
133+ overwrite_a : bool = False ,
134+ check_finite : bool = True ,
135135 lapack_driver : _LapackDriver = "gesdd" ,
136136) -> onp .ArrayND [_InexactT ]: ...
137137
138138#
139- def subspace_angles (A : onp .ToComplexND , B : onp .ToComplexND ) -> _FloatND : ...
139+ def subspace_angles (A : onp .ToComplexND , B : onp .ToComplexND ) -> onp . ArrayND [ np . float64 | np . float32 ] : ...
0 commit comments