11from abc import ABCMeta , abstractmethod
22from dataclasses import fields
3- from typing import Any
3+ from typing import Any , List , Tuple , Union , Dict
4+ import numpy as np
5+ from numpy .typing import NDArray
46
57from scipy .stats import rv_continuous
68from scipy .stats .distributions import rv_frozen
@@ -15,7 +17,6 @@ class AbstractMixtures(metaclass=ABCMeta):
1517 @abstractmethod
1618 def __init__ (self , mixture_form : str , ** kwargs : Any ) -> None :
1719 """
18-
1920 Args:
2021 mixture_form: Form of Mixture classical or Canonical
2122 **kwargs: Parameters of Mixture
@@ -28,40 +29,113 @@ def __init__(self, mixture_form: str, **kwargs: Any) -> None:
2829 raise AssertionError (f"Unknown mixture form: { mixture_form } " )
2930
3031 @abstractmethod
31- def compute_moment (self , n : int , params : dict ) -> tuple [float , float ]: ...
32+ def _compute_moment (self , n : int , params : Dict ) -> Tuple [float , float ]:
33+ ...
34+
35+ def compute_moment (
36+ self , x : Union [List [int ], int , NDArray [np .float64 ]], params : Dict
37+ ) -> Union [List [Tuple [float , float ]], Tuple [float , float ], NDArray [Any ]]:
38+ if isinstance (x , np .ndarray ):
39+ return np .array ([self ._compute_moment (xp , params ) for xp in x ], dtype = object )
40+ elif isinstance (x , list ):
41+ return [self ._compute_moment (xp , params ) for xp in x ]
42+ elif isinstance (x , int ):
43+ return self ._compute_moment (x , params )
44+ else :
45+ raise TypeError (f"Unsupported type for x: { type (x )} " )
3246
3347 @abstractmethod
34- def compute_cdf (self , x : float , params : dict ) -> tuple [float , float ]: ...
48+ def _compute_pdf (self , x : float , params : Dict ) -> Tuple [float , float ]:
49+ ...
50+
51+ def compute_pdf (
52+ self , x : Union [List [float ], float , NDArray [np .float64 ]], params : Dict
53+ ) -> Union [List [Tuple [float , float ]], Tuple [float , float ], NDArray [Any ]]:
54+ if isinstance (x , np .ndarray ):
55+ return np .array ([self ._compute_pdf (xp , params ) for xp in x ], dtype = object )
56+ elif isinstance (x , list ):
57+ return [self ._compute_pdf (xp , params ) for xp in x ]
58+ elif isinstance (x , float ):
59+ return self ._compute_pdf (x , params )
60+ else :
61+ raise TypeError (f"Unsupported type for x: { type (x )} " )
3562
3663 @abstractmethod
37- def compute_pdf (self , x : float , params : dict ) -> tuple [float , float ]: ...
64+ def _compute_logpdf (self , x : float , params : Dict ) -> Tuple [float , float ]:
65+ ...
66+
67+ def compute_logpdf (
68+ self , x : Union [List [float ], float , NDArray [np .float64 ]], params : Dict
69+ ) -> Union [List [Tuple [float , float ]], Tuple [float , float ], NDArray [Any ]]:
70+ if isinstance (x , np .ndarray ):
71+ return np .array ([self ._compute_logpdf (xp , params ) for xp in x ], dtype = object )
72+ elif isinstance (x , list ):
73+ return [self ._compute_logpdf (xp , params ) for xp in x ]
74+ elif isinstance (x , float ):
75+ return self ._compute_logpdf (x , params )
76+ else :
77+ raise TypeError (f"Unsupported type for x: { type (x )} " )
3878
3979 @abstractmethod
40- def compute_logpdf (self , x : float , params : dict ) -> tuple [float , float ]: ...
41-
42- def _params_validation (self , data_collector : Any , params : dict [str , float | rv_continuous | rv_frozen ]) -> Any :
43- """Mixture Parameters Validation
44-
45- Args:
46- data_collector: Dataclass that collect parameters of Mixture
47- params: Input parameters
48-
49- Returns: Instance of dataclass
50-
51- Raises:
52- ValueError: If given parameters is unexpected
53- ValueError: If parameter type is invalid
54- ValueError: If parameters age not given
80+ def _compute_cdf (self , x : float , rqmc_params : Dict [str , Any ]) -> Tuple [float , float ]:
81+ ...
82+
83+ def compute_cdf (
84+ self , x : Union [List [float ], float , NDArray [np .float64 ]], params : Dict
85+ ) -> Union [List [Tuple [float , float ]], Tuple [float , float ], NDArray [Any ]]:
86+ if isinstance (x , np .ndarray ):
87+ return np .array ([self ._compute_cdf (xp , params ) for xp in x ], dtype = object )
88+ elif isinstance (x , list ):
89+ return [self ._compute_cdf (xp , params ) for xp in x ]
90+ elif isinstance (x , float ):
91+ return self ._compute_cdf (x , params )
92+ else :
93+ raise TypeError (f"Unsupported type for x: { type (x )} " )
5594
95+ def _params_validation (mixture_form : str , params : Dict [str , Any ], data_collector : Any ) -> Any :
96+ """
97+ Валидация параметров для NormalMeanMixtures.
98+ Проверяет состав параметров, типы, скалярность, знаки.
5699 """
57100
58- dataclass_fields = fields (data_collector )
59- if len (params ) != len (dataclass_fields ):
60- raise ValueError (f"Expected { len (dataclass_fields )} arguments, got { len (params )} " )
61- names_and_types = dict ((field .name , field .type ) for field in dataclass_fields )
62- for pair in params .items ():
63- if pair [0 ] not in names_and_types :
64- raise ValueError (f"Unexpected parameter { pair [0 ]} " )
65- if not isinstance (pair [1 ], names_and_types [pair [0 ]]):
66- raise ValueError (f"Type missmatch: { pair [0 ]} should be { names_and_types [pair [0 ]]} , not { type (pair [1 ])} " )
67- return data_collector (** params )
101+ # Проверка, что mixture_form корректен
102+ if mixture_form not in ("classical" , "canonical" ):
103+ raise ValueError (f"Invalid mixture_form '{ mixture_form } ', expected 'classical' or 'canonical'" )
104+
105+ # Проверяем наличие и отсутствие лишних параметров
106+ dataclass_fields = {field .name for field in fields (data_collector )}
107+ params_keys = set (params .keys ())
108+
109+ if params_keys != dataclass_fields :
110+ extra = params_keys - dataclass_fields
111+ missing = dataclass_fields - params_keys
112+ msgs = []
113+ if extra :
114+ msgs .append (f"Unexpected parameters: { extra } " )
115+ if missing :
116+ msgs .append (f"Missing parameters: { missing } " )
117+ raise ValueError (", " .join (msgs ))
118+
119+ # Проверяем параметры
120+ for key , val in params .items ():
121+ if key == "distribution" :
122+ if not (isinstance (val , rv_frozen ) or (isinstance (val , type ) and issubclass (val , rv_continuous ))):
123+ raise ValueError (
124+ f"Parameter 'distribution' must be a scipy.stats distribution class or frozen instance, got { type (val )} "
125+ )
126+ else :
127+ if not np .isscalar (val ):
128+ raise ValueError (f"Parameter '{ key } ' must be a scalar, got { type (val )} " )
129+ if not isinstance (val , (int , float )):
130+ raise ValueError (f"Parameter '{ key } ' must be int or float, got { type (val )} " )
131+
132+ # Проверяем специальные условия для параметров:
133+ if mixture_form == "classical" :
134+ if key in {"beta" , "gamma" } and val <= 0 :
135+ raise ValueError (f"Parameter '{ key } ' must be positive for classical form, got { val } " )
136+ elif mixture_form == "canonical" :
137+ if key == "sigma" and val <= 0 :
138+ raise ValueError (f"Parameter 'sigma' must be positive for canonical form, got { val } " )
139+
140+ # Если всё успешно, создаём dataclass
141+ return data_collector (** params )
0 commit comments