1- from collections .abc import Sequence
2- from typing import Final , Generic , Protocol , TypeAlias , overload , type_check_only
1+ from typing import Final , Generic , Protocol , overload , type_check_only
32from typing_extensions import TypeVar
43
54import numpy as np
@@ -8,46 +7,49 @@ import optype.numpy.compat as npc
87
98__all__ = ["Covariance" ]
109
11- # `float16` and `longdouble` aren't supported in `scipy.linalg`, and neither is `bool_`
12- _Scalar_uif : TypeAlias = np . float32 | np . float64 | npc .integer
10+ _ScalarT = TypeVar ( "_ScalarT" , bound = npc . floating | npc . integer )
11+ _ScalarT_co = TypeVar ( "_ScalarT_co" , bound = npc . floating | npc .integer , default = np . float64 , covariant = True )
1312
14- _SCT = TypeVar ("_SCT" , bound = _Scalar_uif )
15- _SCT_co = TypeVar ("_SCT_co" , bound = _Scalar_uif , covariant = True , default = np .float64 )
16-
17- class Covariance (Generic [_SCT_co ]):
13+ class Covariance (Generic [_ScalarT_co ]):
1814 @staticmethod
1915 @overload
20- def from_diagonal (diagonal : Sequence [ int ] ) -> CovViaDiagonal [np .int_ ]: ...
16+ def from_diagonal (diagonal : onp . ToJustFloat64_1D ) -> CovViaDiagonal [np .float64 ]: ...
2117 @staticmethod
2218 @overload
23- def from_diagonal (diagonal : Sequence [ float ] ) -> CovViaDiagonal [np .int_ | np . float64 ]: ...
19+ def from_diagonal (diagonal : onp . ToJustInt64_1D ) -> CovViaDiagonal [np .int_ ]: ...
2420 @staticmethod
2521 @overload
26- def from_diagonal (diagonal : Sequence [_SCT ] | onp .CanArrayND [_SCT ]) -> CovViaDiagonal [_SCT ]: ...
22+ def from_diagonal (diagonal : onp .ToArray1D [_ScalarT , _ScalarT ]) -> CovViaDiagonal [_ScalarT ]: ...
23+
24+ #
2725 @staticmethod
2826 def from_precision (precision : onp .ToFloat2D , covariance : onp .ToFloat2D | None = None ) -> CovViaPrecision : ...
2927 @staticmethod
3028 def from_cholesky (cholesky : onp .ToFloat2D ) -> CovViaCholesky : ...
3129 @staticmethod
3230 def from_eigendecomposition (eigendecomposition : tuple [onp .ToFloat1D , onp .ToFloat2D ]) -> CovViaEigendecomposition : ...
33- def whiten ( self , / , x : onp . AnyIntegerArray | onp . AnyFloatingArray ) -> onp . ArrayND [ npc . floating ]: ...
34- def colorize ( self , / , x : onp . AnyIntegerArray | onp . AnyFloatingArray ) -> onp . ArrayND [ npc . floating ]: ...
31+
32+ #
3533 @property
3634 def log_pdet (self , / ) -> np .float64 : ...
3735 @property
3836 def rank (self , / ) -> np .int_ : ...
3937 @property
40- def covariance (self , / ) -> onp .Array2D [_SCT_co ]: ...
38+ def covariance (self , / ) -> onp .Array2D [_ScalarT_co ]: ...
4139 @property
4240 def shape (self , / ) -> tuple [int , int ]: ...
4341
44- class CovViaDiagonal (Covariance [_SCT_co ], Generic [_SCT_co ]):
42+ #
43+ def whiten (self , / , x : onp .ToFloatND ) -> onp .ArrayND [npc .floating ]: ...
44+ def colorize (self , / , x : onp .ToFloatND ) -> onp .ArrayND [npc .floating ]: ...
45+
46+ class CovViaDiagonal (Covariance [_ScalarT_co ], Generic [_ScalarT_co ]):
4547 @overload
46- def __init__ (self : CovViaDiagonal [np .int_ ], / , diagonal : Sequence [ int ] ) -> None : ...
48+ def __init__ (self : CovViaDiagonal [np .float64 ], / , diagonal : onp . ToJustFloat64_1D ) -> None : ...
4749 @overload
48- def __init__ (self : CovViaDiagonal [np .int_ | np . float64 ], / , diagonal : Sequence [ float ] ) -> None : ...
50+ def __init__ (self : CovViaDiagonal [np .int_ ], / , diagonal : onp . ToJustInt64_1D ) -> None : ...
4951 @overload
50- def __init__ (self , / , diagonal : Sequence [ float | _SCT_co ] | onp .CanArrayND [ _SCT_co ]) -> None : ...
52+ def __init__ (self , / , diagonal : onp .ToArray1D [ _ScalarT_co , _ScalarT_co ]) -> None : ...
5153
5254class CovViaPrecision (Covariance [np .float64 ]):
5355 def __init__ (self , / , precision : onp .ToFloat2D , covariance : onp .ToFloat2D | None = None ) -> None : ...
@@ -63,17 +65,17 @@ class _PSD(Protocol):
6365 _M : onp .ArrayND [np .float64 ]
6466 V : onp .ArrayND [np .float64 ]
6567 U : onp .ArrayND [np .float64 ]
66- eps : np . float64 | float
67- log_pdet : np . float64 | float
68- cond : np . float64 | float
68+ eps : float
69+ log_pdet : float
70+ cond : float
6971 rank : int
7072
7173 @property
7274 def pinv (self , / ) -> onp .ArrayND [npc .floating ]: ...
7375
7476class CovViaPSD (Covariance [np .float64 ]):
7577 _LP : Final [onp .ArrayND [np .float64 ]]
76- _log_pdet : Final [np . float64 | float ]
78+ _log_pdet : Final [float ]
7779 _rank : Final [int ]
7880 _covariance : Final [onp .ArrayND [np .float64 ]]
7981 _shape : tuple [int , int ]
0 commit comments