11from abc import ABCMeta , abstractmethod
22from dataclasses import fields
3- from typing import Any
3+ from typing import Any , List , Tuple , Union , Dict
4+ import numpy as np
5+ from numpy .typing import NDArray
46
57from scipy .stats import rv_continuous
68from scipy .stats .distributions import rv_frozen
@@ -15,7 +17,6 @@ class AbstractMixtures(metaclass=ABCMeta):
1517 @abstractmethod
1618 def __init__ (self , mixture_form : str , ** kwargs : Any ) -> None :
1719 """
18-
1920 Args:
2021 mixture_form: Form of Mixture classical or Canonical
2122 **kwargs: Parameters of Mixture
@@ -28,40 +29,103 @@ def __init__(self, mixture_form: str, **kwargs: Any) -> None:
2829 raise AssertionError (f"Unknown mixture form: { mixture_form } " )
2930
3031 @abstractmethod
31- def compute_moment (self , n : int , params : dict ) -> tuple [float , float ]: ...
32+ def _compute_moment (self , n : int , params : Dict ) -> Tuple [float , float ]:
33+ ...
34+
35+ def compute_moment (
36+ self , x : Union [List [int ], int , NDArray [np .float64 ]], params : Dict
37+ ) -> Union [List [Tuple [float , float ]], Tuple [float , float ], NDArray [Any ]]:
38+ if isinstance (x , np .ndarray ):
39+ return np .array ([self ._compute_moment (xp , params ) for xp in x ], dtype = object )
40+ elif isinstance (x , list ):
41+ return [self ._compute_moment (xp , params ) for xp in x ]
42+ elif isinstance (x , int ):
43+ return self ._compute_moment (x , params )
44+ else :
45+ raise TypeError (f"Unsupported type for x: { type (x )} " )
3246
3347 @abstractmethod
34- def compute_cdf (self , x : float , params : dict ) -> tuple [float , float ]: ...
48+ def _compute_pdf (self , x : float , params : Dict ) -> Tuple [float , float ]:
49+ ...
50+
51+ def compute_pdf (
52+ self , x : Union [List [float ], float , NDArray [np .float64 ]], params : Dict
53+ ) -> Union [List [Tuple [float , float ]], Tuple [float , float ], NDArray [Any ]]:
54+ if isinstance (x , np .ndarray ):
55+ return np .array ([self ._compute_pdf (xp , params ) for xp in x ], dtype = object )
56+ elif isinstance (x , list ):
57+ return [self ._compute_pdf (xp , params ) for xp in x ]
58+ elif isinstance (x , float ):
59+ return self ._compute_pdf (x , params )
60+ else :
61+ raise TypeError (f"Unsupported type for x: { type (x )} " )
3562
3663 @abstractmethod
37- def compute_pdf (self , x : float , params : dict ) -> tuple [float , float ]: ...
64+ def _compute_logpdf (self , x : float , params : Dict ) -> Tuple [float , float ]:
65+ ...
66+
67+ def compute_logpdf (
68+ self , x : Union [List [float ], float , NDArray [np .float64 ]], params : Dict
69+ ) -> Union [List [Tuple [float , float ]], Tuple [float , float ], NDArray [Any ]]:
70+ if isinstance (x , np .ndarray ):
71+ return np .array ([self ._compute_logpdf (xp , params ) for xp in x ], dtype = object )
72+ elif isinstance (x , list ):
73+ return [self ._compute_logpdf (xp , params ) for xp in x ]
74+ elif isinstance (x , float ):
75+ return self ._compute_logpdf (x , params )
76+ else :
77+ raise TypeError (f"Unsupported type for x: { type (x )} " )
3878
3979 @abstractmethod
40- def compute_logpdf (self , x : float , params : dict ) -> tuple [float , float ]: ...
80+ def _compute_cdf (self , x : float , rqmc_params : Dict [str , Any ]) -> Tuple [float , float ]:
81+ ...
82+
83+ def compute_cdf (
84+ self , x : Union [List [float ], float , NDArray [np .float64 ]], params : Dict
85+ ) -> Union [List [Tuple [float , float ]], Tuple [float , float ], NDArray [Any ]]:
86+ if isinstance (x , np .ndarray ):
87+ return np .array ([self ._compute_cdf (xp , params ) for xp in x ], dtype = object )
88+ elif isinstance (x , list ):
89+ return [self ._compute_cdf (xp , params ) for xp in x ]
90+ elif isinstance (x , float ):
91+ return self ._compute_cdf (x , params )
92+ else :
93+ raise TypeError (f"Unsupported type for x: { type (x )} " )
4194
42- def _params_validation (self , data_collector : Any , params : dict [str , float | rv_continuous | rv_frozen ]) -> Any :
43- """Mixture Parameters Validation
95+ def _params_validation (
96+ self ,
97+ data_collector : Any ,
98+ params : Dict [str , Union [float , rv_continuous , rv_frozen ]],
99+ ) -> Any :
100+ """Mixture Parameters Validation"""
44101
45- Args:
46- data_collector: Dataclass that collect parameters of Mixture
47- params: Input parameters
102+ dataclass_fields = fields ( data_collector )
103+ if len ( params ) != len ( dataclass_fields ):
104+ raise ValueError ( f"Expected { len ( dataclass_fields ) } arguments, got { len ( params ) } " )
48105
49- Returns: Instance of dataclass
106+ names_and_types = { field . name : field . type for field in dataclass_fields }
50107
51- Raises:
52- ValueError: If given parameters is unexpected
53- ValueError: If parameter type is invalid
54- ValueError: If parameters age not given
108+ for key , value in params .items ():
109+ if key not in names_and_types :
110+ raise ValueError (f"Unexpected parameter { key } " )
55111
56- """
112+ # Проверка, что значение не список или массив (т.е. скаляр)
113+ if not np .isscalar (value ):
114+ raise ValueError (f"Parameter '{ key } ' must be a scalar value, got { type (value )} " )
57115
58- dataclass_fields = fields (data_collector )
59- if len (params ) != len (dataclass_fields ):
60- raise ValueError (f"Expected { len (dataclass_fields )} arguments, got { len (params )} " )
61- names_and_types = dict ((field .name , field .type ) for field in dataclass_fields )
62- for pair in params .items ():
63- if pair [0 ] not in names_and_types :
64- raise ValueError (f"Unexpected parameter { pair [0 ]} " )
65- if not isinstance (pair [1 ], names_and_types [pair [0 ]]):
66- raise ValueError (f"Type missmatch: { pair [0 ]} should be { names_and_types [pair [0 ]]} , not { type (pair [1 ])} " )
67- return data_collector (** params )
116+ # Проверка типа — либо число (int или float), либо распределение (rv_continuous/rv_frozen)
117+ expected_type = names_and_types [key ]
118+
119+ if expected_type in [float , int ]:
120+ if not isinstance (value , (int , float )):
121+ raise ValueError (f"Parameter '{ key } ' must be int or float, got { type (value )} " )
122+
123+ elif expected_type in [rv_continuous , rv_frozen ]:
124+ # Проверяем, что это объект распределения scipy.stats
125+ if not (hasattr (value , "pdf" ) and callable (value .pdf )):
126+ raise ValueError (f"Parameter '{ key } ' must be a scipy.stats distribution object" )
127+
128+ else :
129+ pass
130+
131+ return data_collector (** params )
0 commit comments