Skip to content

Commit a2789f7

Browse files
committed
feat: add batch support and refactor mixture evaluation logic
- Updated `moment`, `pdf`, `cdf`, and `logpdf` methods to support both scalar and list inputs for efficient batch computation. - Extracted shared evaluation logic (e.g., `moment`, `pdf`, `cdf`, `logpdf`) into the `AbstractMixtures` base class to eliminate code duplication across subclasses. - Centralized input validation, RQMC integration, and distribution handling within the base class for better maintainability. - Fixed type annotations to align with abstract method signatures and resolve `mypy` type-checking issues. These changes improve code clarity, enable vectorized evaluations, and make future extensions easier to implement.
1 parent 2af163d commit a2789f7

File tree

4 files changed

+228
-389
lines changed

4 files changed

+228
-389
lines changed

src/mixtures/abstract_mixture.py

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from abc import ABCMeta, abstractmethod
22
from dataclasses import fields
3-
from typing import Any
3+
from typing import Any, List, Tuple, Union, Dict
4+
import numpy as np
5+
from numpy.typing import NDArray
46

57
from scipy.stats import rv_continuous
68
from scipy.stats.distributions import rv_frozen
@@ -15,7 +17,6 @@ class AbstractMixtures(metaclass=ABCMeta):
1517
@abstractmethod
1618
def __init__(self, mixture_form: str, **kwargs: Any) -> None:
1719
"""
18-
1920
Args:
2021
mixture_form: Form of Mixture classical or Canonical
2122
**kwargs: Parameters of Mixture
@@ -28,16 +29,68 @@ def __init__(self, mixture_form: str, **kwargs: Any) -> None:
2829
raise AssertionError(f"Unknown mixture form: {mixture_form}")
2930

3031
@abstractmethod
31-
def compute_moment(self, n: int, params: dict) -> tuple[float, float]: ...
32+
def _compute_moment(self, n: int, params: Dict) -> Tuple[float, float]:
33+
...
34+
35+
def compute_moment(
36+
self, x: Union[List[int], int, NDArray[np.float64]], params: Dict
37+
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
38+
if isinstance(x, np.ndarray):
39+
return np.array([self._compute_moment(xp, params) for xp in x], dtype=object)
40+
elif isinstance(x, list):
41+
return [self._compute_moment(xp, params) for xp in x]
42+
elif isinstance(x, int):
43+
return self._compute_moment(x, params)
44+
else:
45+
raise TypeError(f"Unsupported type for x: {type(x)}")
3246

3347
@abstractmethod
34-
def compute_cdf(self, x: float, params: dict) -> tuple[float, float]: ...
48+
def _compute_pdf(self, x: float, params: Dict) -> Tuple[float, float]:
49+
...
50+
51+
def compute_pdf(
52+
self, x: Union[List[float], float, NDArray[np.float64]], params: Dict
53+
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
54+
if isinstance(x, np.ndarray):
55+
return np.array([self._compute_pdf(xp, params) for xp in x], dtype=object)
56+
elif isinstance(x, list):
57+
return [self._compute_pdf(xp, params) for xp in x]
58+
elif isinstance(x, float):
59+
return self._compute_pdf(x, params)
60+
else:
61+
raise TypeError(f"Unsupported type for x: {type(x)}")
3562

3663
@abstractmethod
37-
def compute_pdf(self, x: float, params: dict) -> tuple[float, float]: ...
64+
def _compute_logpdf(self, x: float, params: Dict) -> Tuple[float, float]:
65+
...
66+
67+
def compute_logpdf(
68+
self, x: Union[List[float], float, NDArray[np.float64]], params: Dict
69+
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
70+
if isinstance(x, np.ndarray):
71+
return np.array([self._compute_logpdf(xp, params) for xp in x], dtype=object)
72+
elif isinstance(x, list):
73+
return [self._compute_logpdf(xp, params) for xp in x]
74+
elif isinstance(x, float):
75+
return self._compute_logpdf(x, params)
76+
else:
77+
raise TypeError(f"Unsupported type for x: {type(x)}")
3878

3979
@abstractmethod
40-
def compute_logpdf(self, x: float, params: dict) -> tuple[float, float]: ...
80+
def _compute_cdf(self, x: float, rqmc_params: Dict[str, Any]) -> Tuple[float, float]:
81+
...
82+
83+
def compute_cdf(
84+
self, x: Union[List[float], float, NDArray[np.float64]], params: Dict
85+
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
86+
if isinstance(x, np.ndarray):
87+
return np.array([self._compute_cdf(xp, params) for xp in x], dtype=object)
88+
elif isinstance(x, list):
89+
return [self._compute_cdf(xp, params) for xp in x]
90+
elif isinstance(x, float):
91+
return self._compute_cdf(x, params)
92+
else:
93+
raise TypeError(f"Unsupported type for x: {type(x)}")
4194

4295
def _params_validation(self, data_collector: Any, params: dict[str, float | rv_continuous | rv_frozen]) -> Any:
4396
"""Mixture Parameters Validation
@@ -64,4 +117,4 @@ def _params_validation(self, data_collector: Any, params: dict[str, float | rv_c
64117
raise ValueError(f"Unexpected parameter {pair[0]}")
65118
if not isinstance(pair[1], names_and_types[pair[0]]):
66119
raise ValueError(f"Type missmatch: {pair[0]} should be {names_and_types[pair[0]]}, not {type(pair[1])}")
67-
return data_collector(**params)
120+
return data_collector(**params)

0 commit comments

Comments
 (0)