Skip to content

Commit 1cebc58

Browse files
authored
Merge pull request #46 from PySATL/feat/unify-mixture
feat: add batch support and refactor mixture evaluation logic
2 parents 8ec30a0 + 5f7295c commit 1cebc58

File tree

4 files changed

+307
-475
lines changed

4 files changed

+307
-475
lines changed

src/mixtures/abstract_mixture.py

Lines changed: 93 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,39 @@
11
from abc import ABCMeta, abstractmethod
22
from dataclasses import fields
3-
from typing import Any
3+
from typing import Any, List, Tuple, Union, Dict, Type
4+
import numpy as np
5+
from numpy.typing import NDArray
46

57
from scipy.stats import rv_continuous
68
from scipy.stats.distributions import rv_frozen
79

810
from src.algorithms.support_algorithms.integrator import Integrator
11+
from src.algorithms.support_algorithms.rqmc import RQMCIntegrator # default integrator
912

1013
class AbstractMixtures(metaclass=ABCMeta):
1114
"""Base class for Mixtures"""
1215

1316
_classical_collector: Any
1417
_canonical_collector: Any
1518

16-
@abstractmethod
17-
def __init__(self, mixture_form: str, **kwargs: Any) -> None:
19+
def __init__(
20+
self,
21+
mixture_form: str,
22+
integrator_cls: Type[Integrator] = RQMCIntegrator,
23+
integrator_params: Dict[str, Any] = None,
24+
**kwargs: Any
25+
) -> None:
1826
"""
19-
2027
Args:
21-
mixture_form: Form of Mixture classical or Canonical
22-
**kwargs: Parameters of Mixture
28+
mixture_form: Form of Mixture classical or canonical
29+
integrator_cls: Class implementing Integrator protocol (default: RQMCIntegrator)
30+
integrator_params: Parameters for integrator constructor (default: {{}})
31+
**kwargs: Parameters of Mixture (alpha, gamma, etc.)
2332
"""
2433
self.mixture_form = mixture_form
34+
self.integrator_cls = integrator_cls
35+
self.integrator_params = integrator_params or {}
36+
2537
if mixture_form == "classical":
2638
self.params = self._params_validation(self._classical_collector, kwargs)
2739
elif mixture_form == "canonical":
@@ -30,40 +42,88 @@ def __init__(self, mixture_form: str, **kwargs: Any) -> None:
3042
raise AssertionError(f"Unknown mixture form: {mixture_form}")
3143

3244
@abstractmethod
33-
def compute_moment(self, n: int, integrator: Integrator) -> tuple[float, float]: ...
45+
def _compute_moment(self, n: int) -> Tuple[float, float]:
46+
...
47+
48+
def compute_moment(
49+
self,
50+
x: Union[List[int], int, NDArray[np.float64]]
51+
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
52+
if isinstance(x, np.ndarray):
53+
return np.array([self._compute_moment(xp) for xp in x], dtype=object)
54+
elif isinstance(x, list):
55+
return [self._compute_moment(xp) for xp in x]
56+
elif isinstance(x, int):
57+
return self._compute_moment(x)
58+
else:
59+
raise TypeError(f"Unsupported type for x: {type(x)}")
3460

3561
@abstractmethod
36-
def compute_cdf(self, x: float, integrator: Integrator) -> tuple[float, float]: ...
62+
def _compute_pdf(self, x: float) -> Tuple[float, float]:
63+
...
64+
65+
def compute_pdf(
66+
self,
67+
x: Union[List[float], float, NDArray[np.float64]]
68+
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
69+
if isinstance(x, np.ndarray):
70+
return np.array([self._compute_pdf(xp) for xp in x], dtype=object)
71+
elif isinstance(x, list):
72+
return [self._compute_pdf(xp) for xp in x]
73+
elif isinstance(x, float):
74+
return self._compute_pdf(x)
75+
else:
76+
raise TypeError(f"Unsupported type for x: {type(x)}")
3777

3878
@abstractmethod
39-
def compute_pdf(self, x: float, integrator: Integrator) -> tuple[float, float]: ...
79+
def _compute_logpdf(self, x: float) -> Tuple[float, float]:
80+
...
81+
82+
def compute_logpdf(
83+
self,
84+
x: Union[List[float], float, NDArray[np.float64]]
85+
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
86+
if isinstance(x, np.ndarray):
87+
return np.array([self._compute_logpdf(xp) for xp in x], dtype=object)
88+
elif isinstance(x, list):
89+
return [self._compute_logpdf(xp) for xp in x]
90+
elif isinstance(x, float):
91+
return self._compute_logpdf(x)
92+
else:
93+
raise TypeError(f"Unsupported type for x: {type(x)}")
4094

4195
@abstractmethod
42-
def compute_logpdf(self, x: float, params: dict) -> tuple[float, float]: ...
43-
44-
def _params_validation(self, data_collector: Any, params: dict[str, float | rv_continuous | rv_frozen]) -> Any:
45-
"""Mixture Parameters Validation
46-
47-
Args:
48-
data_collector: Dataclass that collect parameters of Mixture
49-
params: Input parameters
50-
51-
Returns: Instance of dataclass
52-
53-
Raises:
54-
ValueError: If given parameters is unexpected
55-
ValueError: If parameter type is invalid
56-
ValueError: If parameters age not given
57-
58-
"""
59-
96+
def _compute_cdf(self, x: float) -> Tuple[float, float]:
97+
...
98+
99+
def compute_cdf(
100+
self,
101+
x: Union[List[float], float, NDArray[np.float64]]
102+
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
103+
if isinstance(x, np.ndarray):
104+
return np.array([self._compute_cdf(xp) for xp in x], dtype=object)
105+
elif isinstance(x, list):
106+
return [self._compute_cdf(xp) for xp in x]
107+
elif isinstance(x, float):
108+
return self._compute_cdf(x)
109+
else:
110+
raise TypeError(f"Unsupported type for x: {type(x)}")
111+
112+
def _params_validation(
113+
self,
114+
data_collector: Any,
115+
params: dict[str, float | rv_continuous | rv_frozen]
116+
) -> Any:
117+
"""Mixture Parameters Validation"""
60118
dataclass_fields = fields(data_collector)
61119
if len(params) != len(dataclass_fields):
62120
raise ValueError(f"Expected {len(dataclass_fields)} arguments, got {len(params)}")
63-
names_and_types = dict((field.name, field.type) for field in dataclass_fields)
64-
for pair in params.items():
65-
if pair[0] not in names_and_types:
66-
raise ValueError(f"Unexpected parameter {pair[0]}")
67-
if not isinstance(pair[1], names_and_types[pair[0]]):
68-
raise ValueError(f"Type missmatch: {pair[0]} should be {names_and_types[pair[0]]}, not {type(pair[1])}")
121+
names_and_types = {field.name: field.type for field in dataclass_fields}
122+
for name, value in params.items():
123+
if name not in names_and_types:
124+
raise ValueError(f"Unexpected parameter {name}")
125+
if not isinstance(value, names_and_types[name]):
126+
raise ValueError(
127+
f"Type mismatch: {name} should be {names_and_types[name]}, not {type(value)}"
128+
)
69129
return data_collector(**params)

0 commit comments

Comments
 (0)