11from abc import ABCMeta , abstractmethod
22from dataclasses import fields
3- from typing import Any
3+ from typing import Any , List , Tuple , Union , Dict , Type
4+ import numpy as np
5+ from numpy .typing import NDArray
46
57from scipy .stats import rv_continuous
68from scipy .stats .distributions import rv_frozen
79
10+ from src .algorithms .support_algorithms .integrator import Integrator
11+ from src .algorithms .support_algorithms .rqmc import RQMCIntegrator # default integrator
812
913class AbstractMixtures (metaclass = ABCMeta ):
1014 """Base class for Mixtures"""
1115
1216 _classical_collector : Any
1317 _canonical_collector : Any
1418
15- @abstractmethod
16- def __init__ (self , mixture_form : str , ** kwargs : Any ) -> None :
19+ def __init__ (
20+ self ,
21+ mixture_form : str ,
22+ integrator_cls : Type [Integrator ] = RQMCIntegrator ,
23+ integrator_params : Dict [str , Any ] = None ,
24+ ** kwargs : Any
25+ ) -> None :
1726 """
18-
1927 Args:
20- mixture_form: Form of Mixture classical or Canonical
21- **kwargs: Parameters of Mixture
28+ mixture_form: Form of Mixture classical or canonical
29+ integrator_cls: Class implementing Integrator protocol (default: RQMCIntegrator)
30+ integrator_params: Parameters for integrator constructor (default: {{}})
31+ **kwargs: Parameters of Mixture (alpha, gamma, etc.)
2232 """
33+ self .mixture_form = mixture_form
34+ self .integrator_cls = integrator_cls
35+ self .integrator_params = integrator_params or {}
36+
2337 if mixture_form == "classical" :
2438 self .params = self ._params_validation (self ._classical_collector , kwargs )
2539 elif mixture_form == "canonical" :
@@ -28,40 +42,88 @@ def __init__(self, mixture_form: str, **kwargs: Any) -> None:
2842 raise AssertionError (f"Unknown mixture form: { mixture_form } " )
2943
3044 @abstractmethod
31- def compute_moment (self , n : int , params : dict ) -> tuple [float , float ]: ...
45+ def _compute_moment (self , n : int ) -> Tuple [float , float ]:
46+ ...
47+
48+ def compute_moment (
49+ self ,
50+ x : Union [List [int ], int , NDArray [np .float64 ]]
51+ ) -> Union [List [Tuple [float , float ]], Tuple [float , float ], NDArray [Any ]]:
52+ if isinstance (x , np .ndarray ):
53+ return np .array ([self ._compute_moment (xp ) for xp in x ], dtype = object )
54+ elif isinstance (x , list ):
55+ return [self ._compute_moment (xp ) for xp in x ]
56+ elif isinstance (x , int ):
57+ return self ._compute_moment (x )
58+ else :
59+ raise TypeError (f"Unsupported type for x: { type (x )} " )
3260
3361 @abstractmethod
34- def compute_cdf (self , x : float , params : dict ) -> tuple [float , float ]: ...
62+ def _compute_pdf (self , x : float ) -> Tuple [float , float ]:
63+ ...
64+
65+ def compute_pdf (
66+ self ,
67+ x : Union [List [float ], float , NDArray [np .float64 ]]
68+ ) -> Union [List [Tuple [float , float ]], Tuple [float , float ], NDArray [Any ]]:
69+ if isinstance (x , np .ndarray ):
70+ return np .array ([self ._compute_pdf (xp ) for xp in x ], dtype = object )
71+ elif isinstance (x , list ):
72+ return [self ._compute_pdf (xp ) for xp in x ]
73+ elif isinstance (x , float ):
74+ return self ._compute_pdf (x )
75+ else :
76+ raise TypeError (f"Unsupported type for x: { type (x )} " )
3577
3678 @abstractmethod
37- def compute_pdf (self , x : float , params : dict ) -> tuple [float , float ]: ...
79+ def _compute_logpdf (self , x : float ) -> Tuple [float , float ]:
80+ ...
81+
82+ def compute_logpdf (
83+ self ,
84+ x : Union [List [float ], float , NDArray [np .float64 ]]
85+ ) -> Union [List [Tuple [float , float ]], Tuple [float , float ], NDArray [Any ]]:
86+ if isinstance (x , np .ndarray ):
87+ return np .array ([self ._compute_logpdf (xp ) for xp in x ], dtype = object )
88+ elif isinstance (x , list ):
89+ return [self ._compute_logpdf (xp ) for xp in x ]
90+ elif isinstance (x , float ):
91+ return self ._compute_logpdf (x )
92+ else :
93+ raise TypeError (f"Unsupported type for x: { type (x )} " )
3894
3995 @abstractmethod
40- def compute_logpdf (self , x : float , params : dict ) -> tuple [float , float ]: ...
41-
42- def _params_validation (self , data_collector : Any , params : dict [str , float | rv_continuous | rv_frozen ]) -> Any :
43- """Mixture Parameters Validation
44-
45- Args:
46- data_collector: Dataclass that collect parameters of Mixture
47- params: Input parameters
48-
49- Returns: Instance of dataclass
50-
51- Raises:
52- ValueError: If given parameters is unexpected
53- ValueError: If parameter type is invalid
54- ValueError: If parameters age not given
55-
56- """
57-
96+ def _compute_cdf (self , x : float ) -> Tuple [float , float ]:
97+ ...
98+
99+ def compute_cdf (
100+ self ,
101+ x : Union [List [float ], float , NDArray [np .float64 ]]
102+ ) -> Union [List [Tuple [float , float ]], Tuple [float , float ], NDArray [Any ]]:
103+ if isinstance (x , np .ndarray ):
104+ return np .array ([self ._compute_cdf (xp ) for xp in x ], dtype = object )
105+ elif isinstance (x , list ):
106+ return [self ._compute_cdf (xp ) for xp in x ]
107+ elif isinstance (x , float ):
108+ return self ._compute_cdf (x )
109+ else :
110+ raise TypeError (f"Unsupported type for x: { type (x )} " )
111+
112+ def _params_validation (
113+ self ,
114+ data_collector : Any ,
115+ params : dict [str , float | rv_continuous | rv_frozen ]
116+ ) -> Any :
117+ """Mixture Parameters Validation"""
58118 dataclass_fields = fields (data_collector )
59119 if len (params ) != len (dataclass_fields ):
60120 raise ValueError (f"Expected { len (dataclass_fields )} arguments, got { len (params )} " )
61- names_and_types = dict ((field .name , field .type ) for field in dataclass_fields )
62- for pair in params .items ():
63- if pair [0 ] not in names_and_types :
64- raise ValueError (f"Unexpected parameter { pair [0 ]} " )
65- if not isinstance (pair [1 ], names_and_types [pair [0 ]]):
66- raise ValueError (f"Type missmatch: { pair [0 ]} should be { names_and_types [pair [0 ]]} , not { type (pair [1 ])} " )
121+ names_and_types = {field .name : field .type for field in dataclass_fields }
122+ for name , value in params .items ():
123+ if name not in names_and_types :
124+ raise ValueError (f"Unexpected parameter { name } " )
125+ if not isinstance (value , names_and_types [name ]):
126+ raise ValueError (
127+ f"Type mismatch: { name } should be { names_and_types [name ]} , not { type (value )} "
128+ )
67129 return data_collector (** params )
0 commit comments