11from abc import ABCMeta , abstractmethod
22from dataclasses import fields
3- from typing import Any
3+ from typing import Any , List , Tuple , Union , Dict , Type
4+ import numpy as np
5+ from numpy .typing import NDArray
46
57from scipy .stats import rv_continuous
68from scipy .stats .distributions import rv_frozen
79
810from src .algorithms .support_algorithms .integrator import Integrator
11+ from src .algorithms .support_algorithms .rqmc import RQMCIntegrator # default integrator
912
1013class AbstractMixtures (metaclass = ABCMeta ):
1114 """Base class for Mixtures"""
1215
1316 _classical_collector : Any
1417 _canonical_collector : Any
1518
16- @abstractmethod
17- def __init__ (self , mixture_form : str , ** kwargs : Any ) -> None :
19+ def __init__ (
20+ self ,
21+ mixture_form : str ,
22+ integrator_cls : Type [Integrator ] = RQMCIntegrator ,
23+ integrator_params : Dict [str , Any ] = None ,
24+ ** kwargs : Any
25+ ) -> None :
1826 """
19-
2027 Args:
21- mixture_form: Form of Mixture classical or Canonical
22- **kwargs: Parameters of Mixture
28+ mixture_form: Form of Mixture classical or canonical
29+ integrator_cls: Class implementing Integrator protocol (default: RQMCIntegrator)
30+ integrator_params: Parameters for integrator constructor (default: {{}})
31+ **kwargs: Parameters of Mixture (alpha, gamma, etc.)
2332 """
2433 self .mixture_form = mixture_form
34+ self .integrator_cls = integrator_cls
35+ self .integrator_params = integrator_params or {}
36+
2537 if mixture_form == "classical" :
2638 self .params = self ._params_validation (self ._classical_collector , kwargs )
2739 elif mixture_form == "canonical" :
@@ -30,40 +42,88 @@ def __init__(self, mixture_form: str, **kwargs: Any) -> None:
3042 raise AssertionError (f"Unknown mixture form: { mixture_form } " )
3143
3244 @abstractmethod
33- def compute_moment (self , n : int , integrator : Integrator ) -> tuple [float , float ]: ...
45+ def _compute_moment (self , n : int ) -> Tuple [float , float ]:
46+ ...
47+
48+ def compute_moment (
49+ self ,
50+ x : Union [List [int ], int , NDArray [np .float64 ]]
51+ ) -> Union [List [Tuple [float , float ]], Tuple [float , float ], NDArray [Any ]]:
52+ if isinstance (x , np .ndarray ):
53+ return np .array ([self ._compute_moment (xp ) for xp in x ], dtype = object )
54+ elif isinstance (x , list ):
55+ return [self ._compute_moment (xp ) for xp in x ]
56+ elif isinstance (x , int ):
57+ return self ._compute_moment (x )
58+ else :
59+ raise TypeError (f"Unsupported type for x: { type (x )} " )
3460
3561 @abstractmethod
36- def compute_cdf (self , x : float , integrator : Integrator ) -> tuple [float , float ]: ...
62+ def _compute_pdf (self , x : float ) -> Tuple [float , float ]:
63+ ...
64+
65+ def compute_pdf (
66+ self ,
67+ x : Union [List [float ], float , NDArray [np .float64 ]]
68+ ) -> Union [List [Tuple [float , float ]], Tuple [float , float ], NDArray [Any ]]:
69+ if isinstance (x , np .ndarray ):
70+ return np .array ([self ._compute_pdf (xp ) for xp in x ], dtype = object )
71+ elif isinstance (x , list ):
72+ return [self ._compute_pdf (xp ) for xp in x ]
73+ elif isinstance (x , float ):
74+ return self ._compute_pdf (x )
75+ else :
76+ raise TypeError (f"Unsupported type for x: { type (x )} " )
3777
3878 @abstractmethod
39- def compute_pdf (self , x : float , integrator : Integrator ) -> tuple [float , float ]: ...
79+ def _compute_logpdf (self , x : float ) -> Tuple [float , float ]:
80+ ...
81+
82+ def compute_logpdf (
83+ self ,
84+ x : Union [List [float ], float , NDArray [np .float64 ]]
85+ ) -> Union [List [Tuple [float , float ]], Tuple [float , float ], NDArray [Any ]]:
86+ if isinstance (x , np .ndarray ):
87+ return np .array ([self ._compute_logpdf (xp ) for xp in x ], dtype = object )
88+ elif isinstance (x , list ):
89+ return [self ._compute_logpdf (xp ) for xp in x ]
90+ elif isinstance (x , float ):
91+ return self ._compute_logpdf (x )
92+ else :
93+ raise TypeError (f"Unsupported type for x: { type (x )} " )
4094
4195 @abstractmethod
42- def compute_logpdf (self , x : float , params : dict ) -> tuple [float , float ]: ...
43-
44- def _params_validation (self , data_collector : Any , params : dict [str , float | rv_continuous | rv_frozen ]) -> Any :
45- """Mixture Parameters Validation
46-
47- Args:
48- data_collector: Dataclass that collect parameters of Mixture
49- params: Input parameters
50-
51- Returns: Instance of dataclass
52-
53- Raises:
54- ValueError: If given parameters is unexpected
55- ValueError: If parameter type is invalid
56- ValueError: If parameters age not given
57-
58- """
59-
96+ def _compute_cdf (self , x : float ) -> Tuple [float , float ]:
97+ ...
98+
99+ def compute_cdf (
100+ self ,
101+ x : Union [List [float ], float , NDArray [np .float64 ]]
102+ ) -> Union [List [Tuple [float , float ]], Tuple [float , float ], NDArray [Any ]]:
103+ if isinstance (x , np .ndarray ):
104+ return np .array ([self ._compute_cdf (xp ) for xp in x ], dtype = object )
105+ elif isinstance (x , list ):
106+ return [self ._compute_cdf (xp ) for xp in x ]
107+ elif isinstance (x , float ):
108+ return self ._compute_cdf (x )
109+ else :
110+ raise TypeError (f"Unsupported type for x: { type (x )} " )
111+
112+ def _params_validation (
113+ self ,
114+ data_collector : Any ,
115+ params : dict [str , float | rv_continuous | rv_frozen ]
116+ ) -> Any :
117+ """Mixture Parameters Validation"""
60118 dataclass_fields = fields (data_collector )
61119 if len (params ) != len (dataclass_fields ):
62120 raise ValueError (f"Expected { len (dataclass_fields )} arguments, got { len (params )} " )
63- names_and_types = dict ((field .name , field .type ) for field in dataclass_fields )
64- for pair in params .items ():
65- if pair [0 ] not in names_and_types :
66- raise ValueError (f"Unexpected parameter { pair [0 ]} " )
67- if not isinstance (pair [1 ], names_and_types [pair [0 ]]):
68- raise ValueError (f"Type missmatch: { pair [0 ]} should be { names_and_types [pair [0 ]]} , not { type (pair [1 ])} " )
121+ names_and_types = {field .name : field .type for field in dataclass_fields }
122+ for name , value in params .items ():
123+ if name not in names_and_types :
124+ raise ValueError (f"Unexpected parameter { name } " )
125+ if not isinstance (value , names_and_types [name ]):
126+ raise ValueError (
127+ f"Type mismatch: { name } should be { names_and_types [name ]} , not { type (value )} "
128+ )
69129 return data_collector (** params )
0 commit comments