Skip to content

Commit 0e290e1

Browse files
committed
feat: add batch support and refactor mixture evaluation logic
- Updated the `moment`, `pdf`, `cdf`, and `logpdf` methods to support both scalar and list inputs for efficient batch computation.
- Extracted shared evaluation logic (e.g., `moment`, `pdf`, `cdf`, `logpdf`) into the `AbstractMixtures` base class to eliminate code duplication across subclasses.
- Centralized input validation, RQMC integration, and distribution handling within the base class for better maintainability.
- Fixed type annotations to align with abstract method signatures and resolve `mypy` type-checking issues.

These changes improve code clarity, enable vectorized evaluations, and make future extensions easier to implement.
1 parent 2af163d commit 0e290e1

File tree

4 files changed

+259
-409
lines changed

4 files changed

+259
-409
lines changed

src/mixtures/abstract_mixture.py

Lines changed: 91 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from abc import ABCMeta, abstractmethod
22
from dataclasses import fields
3-
from typing import Any
3+
from typing import Any, List, Tuple, Union, Dict
4+
import numpy as np
5+
from numpy.typing import NDArray
46

57
from scipy.stats import rv_continuous
68
from scipy.stats.distributions import rv_frozen
@@ -15,7 +17,6 @@ class AbstractMixtures(metaclass=ABCMeta):
1517
@abstractmethod
1618
def __init__(self, mixture_form: str, **kwargs: Any) -> None:
1719
"""
18-
1920
Args:
2021
mixture_form: Form of Mixture classical or Canonical
2122
**kwargs: Parameters of Mixture
@@ -28,40 +29,103 @@ def __init__(self, mixture_form: str, **kwargs: Any) -> None:
2829
raise AssertionError(f"Unknown mixture form: {mixture_form}")
2930

3031
@abstractmethod
def _compute_moment(self, n: int, params: Dict) -> Tuple[float, float]:
    """Evaluate one moment; implemented by concrete mixture classes.

    Args:
        n: Single input value (presumably the order of the moment).
        params: Parameters forwarded unchanged from `compute_moment`.

    Returns:
        A pair of floats. NOTE(review): presumably (estimate, error
        tolerance) -- confirm against the subclasses.
    """
    ...


def compute_moment(
    self, x: Union[List[int], int, NDArray[np.float64]], params: Dict
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
    """Evaluate `_compute_moment` for a scalar or for a batch of inputs.

    Args:
        x: A single integer (Python ``int`` or NumPy integer scalar), a list
            of values, or a NumPy array of values.
        params: Parameters passed through to `_compute_moment` on every call.

    Returns:
        The `_compute_moment` result for a scalar, a list of results for a
        list, or an object-dtype array built from the results for an array.

    Raises:
        TypeError: If `x` is none of the supported input kinds.
    """
    if isinstance(x, np.ndarray):
        return np.array([self._compute_moment(xp, params) for xp in x], dtype=object)
    if isinstance(x, list):
        return [self._compute_moment(xp, params) for xp in x]
    # NumPy integer scalars (e.g. one element of np.arange) are not `int`
    # instances; accept them so scalar and array inputs behave consistently.
    if isinstance(x, (int, np.integer)):
        return self._compute_moment(x, params)
    raise TypeError(f"Unsupported type for x: {type(x)}")
3246

3347
@abstractmethod
def _compute_pdf(self, x: float, params: Dict) -> Tuple[float, float]:
    """Evaluate the density at one point; implemented by concrete mixtures.

    Args:
        x: Single evaluation point.
        params: Parameters forwarded unchanged from `compute_pdf`.

    Returns:
        A pair of floats. NOTE(review): presumably (estimate, error
        tolerance) -- confirm against the subclasses.
    """
    ...


def compute_pdf(
    self, x: Union[List[float], float, NDArray[np.float64]], params: Dict
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
    """Evaluate `_compute_pdf` for a scalar or for a batch of points.

    Args:
        x: A single real number (Python or NumPy scalar), a list of points,
            or a NumPy array of points.
        params: Parameters passed through to `_compute_pdf` on every call.

    Returns:
        The `_compute_pdf` result for a scalar, a list of results for a
        list, or an object-dtype array built from the results for an array.

    Raises:
        TypeError: If `x` is none of the supported input kinds.
    """
    if isinstance(x, np.ndarray):
        return np.array([self._compute_pdf(xp, params) for xp in x], dtype=object)
    if isinstance(x, list):
        return [self._compute_pdf(xp, params) for xp in x]
    # A float-only check would reject integer points such as 1 as well as
    # NumPy integer/float32 scalars, so accept every real scalar kind here.
    if isinstance(x, (int, float, np.integer, np.floating)):
        return self._compute_pdf(x, params)
    raise TypeError(f"Unsupported type for x: {type(x)}")
3562

3663
@abstractmethod
def _compute_logpdf(self, x: float, params: Dict) -> Tuple[float, float]:
    """Evaluate the log-density at one point; implemented by concrete mixtures.

    Args:
        x: Single evaluation point.
        params: Parameters forwarded unchanged from `compute_logpdf`.

    Returns:
        A pair of floats. NOTE(review): presumably (estimate, error
        tolerance) -- confirm against the subclasses.
    """
    ...


def compute_logpdf(
    self, x: Union[List[float], float, NDArray[np.float64]], params: Dict
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
    """Evaluate `_compute_logpdf` for a scalar or for a batch of points.

    Args:
        x: A single real number (Python or NumPy scalar), a list of points,
            or a NumPy array of points.
        params: Parameters passed through to `_compute_logpdf` on every call.

    Returns:
        The `_compute_logpdf` result for a scalar, a list of results for a
        list, or an object-dtype array built from the results for an array.

    Raises:
        TypeError: If `x` is none of the supported input kinds.
    """
    if isinstance(x, np.ndarray):
        return np.array([self._compute_logpdf(xp, params) for xp in x], dtype=object)
    if isinstance(x, list):
        return [self._compute_logpdf(xp, params) for xp in x]
    # A float-only check would reject integer points such as 1 as well as
    # NumPy integer/float32 scalars, so accept every real scalar kind here.
    if isinstance(x, (int, float, np.integer, np.floating)):
        return self._compute_logpdf(x, params)
    raise TypeError(f"Unsupported type for x: {type(x)}")
3878

3979
@abstractmethod
def _compute_cdf(self, x: float, rqmc_params: Dict[str, Any]) -> Tuple[float, float]:
    """Evaluate the CDF at one point; implemented by concrete mixtures.

    Args:
        x: Single evaluation point.
        rqmc_params: Parameters forwarded unchanged from `compute_cdf`
            (passed positionally, where they are called `params`).

    Returns:
        A pair of floats. NOTE(review): presumably (estimate, error
        tolerance) -- confirm against the subclasses.
    """
    ...


def compute_cdf(
    self, x: Union[List[float], float, NDArray[np.float64]], params: Dict
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
    """Evaluate `_compute_cdf` for a scalar or for a batch of points.

    Args:
        x: A single real number (Python or NumPy scalar), a list of points,
            or a NumPy array of points.
        params: Parameters passed through to `_compute_cdf` on every call.

    Returns:
        The `_compute_cdf` result for a scalar, a list of results for a
        list, or an object-dtype array built from the results for an array.

    Raises:
        TypeError: If `x` is none of the supported input kinds.
    """
    if isinstance(x, np.ndarray):
        return np.array([self._compute_cdf(xp, params) for xp in x], dtype=object)
    if isinstance(x, list):
        return [self._compute_cdf(xp, params) for xp in x]
    # A float-only check would reject integer points such as 1 as well as
    # NumPy integer/float32 scalars, so accept every real scalar kind here.
    if isinstance(x, (int, float, np.integer, np.floating)):
        return self._compute_cdf(x, params)
    raise TypeError(f"Unsupported type for x: {type(x)}")
4194

42-
def _params_validation(self, data_collector: Any, params: dict[str, float | rv_continuous | rv_frozen]) -> Any:
43-
"""Mixture Parameters Validation
95+
def _params_validation(
96+
self,
97+
data_collector: Any,
98+
params: Dict[str, Union[float, rv_continuous, rv_frozen]],
99+
) -> Any:
100+
"""Mixture Parameters Validation"""
44101

45-
Args:
46-
data_collector: Dataclass that collect parameters of Mixture
47-
params: Input parameters
102+
dataclass_fields = fields(data_collector)
103+
if len(params) != len(dataclass_fields):
104+
raise ValueError(f"Expected {len(dataclass_fields)} arguments, got {len(params)}")
48105

49-
Returns: Instance of dataclass
106+
names_and_types = {field.name: field.type for field in dataclass_fields}
50107

51-
Raises:
52-
ValueError: If given parameters is unexpected
53-
ValueError: If parameter type is invalid
54-
ValueError: If parameters age not given
108+
for key, value in params.items():
109+
if key not in names_and_types:
110+
raise ValueError(f"Unexpected parameter {key}")
55111

56-
"""
112+
# Проверка, что значение не список или массив (т.е. скаляр)
113+
if not np.isscalar(value):
114+
raise ValueError(f"Parameter '{key}' must be a scalar value, got {type(value)}")
57115

58-
dataclass_fields = fields(data_collector)
59-
if len(params) != len(dataclass_fields):
60-
raise ValueError(f"Expected {len(dataclass_fields)} arguments, got {len(params)}")
61-
names_and_types = dict((field.name, field.type) for field in dataclass_fields)
62-
for pair in params.items():
63-
if pair[0] not in names_and_types:
64-
raise ValueError(f"Unexpected parameter {pair[0]}")
65-
if not isinstance(pair[1], names_and_types[pair[0]]):
66-
raise ValueError(f"Type missmatch: {pair[0]} should be {names_and_types[pair[0]]}, not {type(pair[1])}")
67-
return data_collector(**params)
116+
# Проверка типа — либо число (int или float), либо распределение (rv_continuous/rv_frozen)
117+
expected_type = names_and_types[key]
118+
119+
if expected_type in [float, int]:
120+
if not isinstance(value, (int, float)):
121+
raise ValueError(f"Parameter '{key}' must be int or float, got {type(value)}")
122+
123+
elif expected_type in [rv_continuous, rv_frozen]:
124+
# Проверяем, что это объект распределения scipy.stats
125+
if not (hasattr(value, "pdf") and callable(value.pdf)):
126+
raise ValueError(f"Parameter '{key}' must be a scipy.stats distribution object")
127+
128+
else:
129+
pass
130+
131+
return data_collector(**params)

0 commit comments

Comments
 (0)