Skip to content

Commit 9e89652

Browse files
committed
feat: add batch support and refactor mixture evaluation logic
- Updated `moment`, `pdf`, `cdf`, and `logpdf` methods to support both scalar and list inputs for efficient batch computation. - Extracted shared evaluation logic (e.g., `moment`, `pdf`, `cdf`, `logpdf`) into the `AbstractMixtures` base class to eliminate code duplication across subclasses. - Centralized input validation, RQMC integration, and distribution handling within the base class for better maintainability. - Fixed type annotations to align with abstract method signatures and resolve `mypy` type-checking issues. These changes improve code clarity, enable vectorized evaluations, and make future extensions easier to implement.
1 parent 2af163d commit 9e89652

File tree

4 files changed

+272
-412
lines changed

4 files changed

+272
-412
lines changed

src/mixtures/abstract_mixture.py

Lines changed: 104 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from abc import ABCMeta, abstractmethod
22
from dataclasses import fields
3-
from typing import Any
3+
from typing import Any, List, Tuple, Union, Dict
4+
import numpy as np
5+
from numpy.typing import NDArray
46

57
from scipy.stats import rv_continuous
68
from scipy.stats.distributions import rv_frozen
@@ -15,7 +17,6 @@ class AbstractMixtures(metaclass=ABCMeta):
1517
@abstractmethod
1618
def __init__(self, mixture_form: str, **kwargs: Any) -> None:
1719
"""
18-
1920
Args:
2021
mixture_form: Form of Mixture classical or Canonical
2122
**kwargs: Parameters of Mixture
@@ -28,40 +29,113 @@ def __init__(self, mixture_form: str, **kwargs: Any) -> None:
2829
raise AssertionError(f"Unknown mixture form: {mixture_form}")
2930

3031
@abstractmethod
31-
def compute_moment(self, n: int, params: dict) -> tuple[float, float]: ...
32+
def _compute_moment(self, n: int, params: Dict) -> Tuple[float, float]:
33+
...
34+
35+
def compute_moment(
36+
self, x: Union[List[int], int, NDArray[np.float64]], params: Dict
37+
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
38+
if isinstance(x, np.ndarray):
39+
return np.array([self._compute_moment(xp, params) for xp in x], dtype=object)
40+
elif isinstance(x, list):
41+
return [self._compute_moment(xp, params) for xp in x]
42+
elif isinstance(x, int):
43+
return self._compute_moment(x, params)
44+
else:
45+
raise TypeError(f"Unsupported type for x: {type(x)}")
3246

3347
@abstractmethod
34-
def compute_cdf(self, x: float, params: dict) -> tuple[float, float]: ...
48+
def _compute_pdf(self, x: float, params: Dict) -> Tuple[float, float]:
49+
...
50+
51+
def compute_pdf(
52+
self, x: Union[List[float], float, NDArray[np.float64]], params: Dict
53+
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
54+
if isinstance(x, np.ndarray):
55+
return np.array([self._compute_pdf(xp, params) for xp in x], dtype=object)
56+
elif isinstance(x, list):
57+
return [self._compute_pdf(xp, params) for xp in x]
58+
elif isinstance(x, float):
59+
return self._compute_pdf(x, params)
60+
else:
61+
raise TypeError(f"Unsupported type for x: {type(x)}")
3562

3663
@abstractmethod
37-
def compute_pdf(self, x: float, params: dict) -> tuple[float, float]: ...
64+
def _compute_logpdf(self, x: float, params: Dict) -> Tuple[float, float]:
65+
...
66+
67+
def compute_logpdf(
68+
self, x: Union[List[float], float, NDArray[np.float64]], params: Dict
69+
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
70+
if isinstance(x, np.ndarray):
71+
return np.array([self._compute_logpdf(xp, params) for xp in x], dtype=object)
72+
elif isinstance(x, list):
73+
return [self._compute_logpdf(xp, params) for xp in x]
74+
elif isinstance(x, float):
75+
return self._compute_logpdf(x, params)
76+
else:
77+
raise TypeError(f"Unsupported type for x: {type(x)}")
3878

3979
@abstractmethod
40-
def compute_logpdf(self, x: float, params: dict) -> tuple[float, float]: ...
41-
42-
def _params_validation(self, data_collector: Any, params: dict[str, float | rv_continuous | rv_frozen]) -> Any:
43-
"""Mixture Parameters Validation
44-
45-
Args:
46-
data_collector: Dataclass that collect parameters of Mixture
47-
params: Input parameters
48-
49-
Returns: Instance of dataclass
50-
51-
Raises:
52-
ValueError: If given parameters is unexpected
53-
ValueError: If parameter type is invalid
54-
ValueError: If parameters age not given
80+
def _compute_cdf(self, x: float, rqmc_params: Dict[str, Any]) -> Tuple[float, float]:
81+
...
82+
83+
def compute_cdf(
84+
self, x: Union[List[float], float, NDArray[np.float64]], params: Dict
85+
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
86+
if isinstance(x, np.ndarray):
87+
return np.array([self._compute_cdf(xp, params) for xp in x], dtype=object)
88+
elif isinstance(x, list):
89+
return [self._compute_cdf(xp, params) for xp in x]
90+
elif isinstance(x, float):
91+
return self._compute_cdf(x, params)
92+
else:
93+
raise TypeError(f"Unsupported type for x: {type(x)}")
5594

95+
def _params_validation(mixture_form: str, params: Dict[str, Any], data_collector: Any) -> Any:
96+
"""
97+
Валидация параметров для NormalMeanMixtures.
98+
Проверяет состав параметров, типы, скалярность, знаки.
5699
"""
57100

58-
dataclass_fields = fields(data_collector)
59-
if len(params) != len(dataclass_fields):
60-
raise ValueError(f"Expected {len(dataclass_fields)} arguments, got {len(params)}")
61-
names_and_types = dict((field.name, field.type) for field in dataclass_fields)
62-
for pair in params.items():
63-
if pair[0] not in names_and_types:
64-
raise ValueError(f"Unexpected parameter {pair[0]}")
65-
if not isinstance(pair[1], names_and_types[pair[0]]):
66-
raise ValueError(f"Type missmatch: {pair[0]} should be {names_and_types[pair[0]]}, not {type(pair[1])}")
67-
return data_collector(**params)
101+
# Проверка, что mixture_form корректен
102+
if mixture_form not in ("classical", "canonical"):
103+
raise ValueError(f"Invalid mixture_form '{mixture_form}', expected 'classical' or 'canonical'")
104+
105+
# Проверяем наличие и отсутствие лишних параметров
106+
dataclass_fields = {field.name for field in fields(data_collector)}
107+
params_keys = set(params.keys())
108+
109+
if params_keys != dataclass_fields:
110+
extra = params_keys - dataclass_fields
111+
missing = dataclass_fields - params_keys
112+
msgs = []
113+
if extra:
114+
msgs.append(f"Unexpected parameters: {extra}")
115+
if missing:
116+
msgs.append(f"Missing parameters: {missing}")
117+
raise ValueError(", ".join(msgs))
118+
119+
# Проверяем параметры
120+
for key, val in params.items():
121+
if key == "distribution":
122+
if not (isinstance(val, rv_frozen) or (isinstance(val, type) and issubclass(val, rv_continuous))):
123+
raise ValueError(
124+
f"Parameter 'distribution' must be a scipy.stats distribution class or frozen instance, got {type(val)}"
125+
)
126+
else:
127+
if not np.isscalar(val):
128+
raise ValueError(f"Parameter '{key}' must be a scalar, got {type(val)}")
129+
if not isinstance(val, (int, float)):
130+
raise ValueError(f"Parameter '{key}' must be int or float, got {type(val)}")
131+
132+
# Проверяем специальные условия для параметров:
133+
if mixture_form == "classical":
134+
if key in {"beta", "gamma"} and val <= 0:
135+
raise ValueError(f"Parameter '{key}' must be positive for classical form, got {val}")
136+
elif mixture_form == "canonical":
137+
if key == "sigma" and val <= 0:
138+
raise ValueError(f"Parameter 'sigma' must be positive for canonical form, got {val}")
139+
140+
# Если всё успешно, создаём dataclass
141+
return data_collector(**params)

0 commit comments

Comments
 (0)