Skip to content

Commit 476758e

Browse files
authored
Merge pull request #40 from PySATL/api/abstract-integrator
feat: dynamic integrator selection
2 parents 2af163d + 51a7303 commit 476758e

File tree

8 files changed

+244
-112
lines changed

8 files changed

+244
-112
lines changed

src/algorithms/semiparam_algorithms/nvm_semi_param_algorithms/g_estimation_given_mu_rqmc_based.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from scipy.integrate import quad_vec
88
from scipy.special import gamma
99

10-
from src.algorithms.support_algorithms.rqmc import RQMC
10+
from src.algorithms.support_algorithms.integrator import Integrator
11+
from src.algorithms.support_algorithms.rqmc import RQMCIntegrator
1112
from src.estimators.estimate_result import EstimateResult
1213

1314
MU_DEFAULT_VALUE = 1.0
@@ -68,7 +69,7 @@ def __init__(self, sample: Optional[_typing.NDArray[np.float64]] = None, **kwarg
6869
self.x_data,
6970
self.grid_size,
7071
self.integration_tolerance,
71-
self.integration_limit,
72+
self.integration_limit
7273
) = self._validate_kwargs(self.n, **kwargs)
7374
self.denominator: float = 2 * math.pi * self.n
7475
self.precompute_gamma_grid()
@@ -162,15 +163,15 @@ def second_v_integrand(self, v: float, x: float) -> np.ndarray:
162163
x_power = self.x_powers[x][idx]
163164
return (self.second_u_integrals[idx] * x_power) / gamma_val
164165

165-
def compute_integrals_for_x(self, x: float) -> float:
166+
def compute_integrals_for_x(self, x: float, integrator: Integrator = RQMCIntegrator()) -> float:
166167
"""Compute integrals using RQMC for v-integration."""
167-
first_integral = RQMC(lambda t: np.sum(self.first_v_integrand(t * self.v_value, x)) * self.v_value).rqmc()[0]
168+
first_integral = integrator.compute(func=lambda t: np.sum(self.first_v_integrand(t * self.v_value, x)) * self.v_value).value
168169

169-
second_integral = RQMC(lambda t: np.sum(self.second_v_integrand(-t * self.v_value, x)) * self.v_value).rqmc()[0]
170+
second_integral = integrator.compute(func=lambda t: np.sum(self.second_v_integrand(-t * self.v_value, x)) * self.v_value).value
170171

171172
total = (first_integral + second_integral) / self.denominator
172173
return max(0.0, total.real)
173174

174175
def algorithm(self, sample: np._typing.NDArray) -> EstimateResult:
175-
y_data = [self.compute_integrals_for_x(x) for x in self.x_data]
176+
y_data = [self.compute_integrals_for_x(x, RQMCIntegrator()) for x in self.x_data]
176177
return EstimateResult(list_value=y_data, success=True)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from dataclasses import dataclass
2+
from typing import Any, Protocol, Callable, Optional
3+
4+
5+
@dataclass
6+
class IntegrationResult:
7+
value: float
8+
error: float
9+
message: Optional[dict[str, Any]] | None = None
10+
11+
12+
class Integrator(Protocol):
13+
14+
"""Base class for integral calculation"""
15+
16+
def __init__(self) -> None:
17+
...
18+
19+
def compute(self, func: Callable) -> IntegrationResult:
20+
...
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from typing import Callable, Any, Sequence
2+
3+
from scipy.integrate import quad
4+
from src.algorithms.support_algorithms.integrator import IntegrationResult
5+
6+
class QuadIntegrator:
7+
8+
def __init__(
9+
self,
10+
a: float = 0,
11+
b: float = 1,
12+
args: tuple[Any, ...] = (),
13+
full_output: int = 0,
14+
epsabs: float | int = 1.49e-08,
15+
epsrel: float | int = 1.49e-08,
16+
limit: float | int = 50,
17+
points: Sequence[float | int] | None = None,
18+
weight: float | int | None = None,
19+
wvar: Any = None,
20+
wopts: Any = None,
21+
maxp1: float | int = 50,
22+
limlst: int = 50,
23+
complex_func: bool = False,
24+
):
25+
self.params = {
26+
'a': a,
27+
'b': b,
28+
'args': args,
29+
'full_output': full_output,
30+
'epsabs': epsabs,
31+
'epsrel': epsrel,
32+
'limit': limit,
33+
'points': points,
34+
'weight': weight,
35+
'wvar': wvar,
36+
'wopts': wopts,
37+
'maxp1': maxp1,
38+
'limlst': limlst,
39+
'complex_func': complex_func
40+
}
41+
42+
def compute(self, func: Callable) -> IntegrationResult:
43+
44+
"""
45+
Compute integral via quad integrator
46+
47+
Args:
48+
func: integrated function
49+
50+
Returns: moment approximation and error tolerance
51+
"""
52+
53+
verbose = self.params.pop('full_output', False)
54+
result = quad(func, **self.params)
55+
if verbose:
56+
value, error, message = result
57+
else:
58+
value, error = result
59+
message = None
60+
return IntegrationResult(value, error, message)

src/algorithms/support_algorithms/rqmc.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import scipy
66
from numba import njit
77

8+
from src.algorithms.support_algorithms.integrator import IntegrationResult
9+
810
BITS = 30
911
"""Number of bits in XOR. Should be less than 64"""
1012
NUMBA_FAST_MATH = True
@@ -126,7 +128,8 @@ def _update(
126128
127129
Returns:Updated mean of all rows
128130
129-
"""
131+
132+
"""
130133
values = []
131134
sum_of_new: float = 0.0
132135
for i in range(self.count):
@@ -212,9 +215,9 @@ def _xor_float(a: float, b: float) -> float:
212215
Returns: XOR float value
213216
214217
"""
215-
a = int(a * (2**BITS))
216-
b = int(b * (2**BITS))
217-
return np.bitwise_xor(a, b) / 2**BITS
218+
a = int(a * (2 ** BITS))
219+
b = int(b * (2 ** BITS))
220+
return np.bitwise_xor(a, b) / 2 ** BITS
218221

219222
def __call__(self) -> tuple[float, float]:
220223
"""Interface for users
@@ -223,3 +226,51 @@ def __call__(self) -> tuple[float, float]:
223226
224227
"""
225228
return self.rqmc()
229+
230+
231+
class RQMCIntegrator:
232+
"""
233+
Randomize Quasi Monte Carlo Method
234+
235+
Args:
236+
error_tolerance: pre-specified error tolerance
237+
count: number of rows of random values matrix
238+
base_n: number of columns of random values matrix
239+
i_max: allowed number of cycles
240+
a: parameter for quantile of normal distribution
241+
242+
"""
243+
244+
def __init__(
245+
self,
246+
error_tolerance: float = 1e-6,
247+
count: int = 25,
248+
base_n: int = 2 ** 6,
249+
i_max: int = 100,
250+
a: float = 0.00047,
251+
):
252+
self.error_tolerance = error_tolerance
253+
self.count = count
254+
self.base_n = base_n
255+
self.i_max = i_max
256+
self.a = a
257+
258+
def compute(self, func: Callable) -> IntegrationResult:
259+
"""
260+
Compute integral via RQMC integrator
261+
262+
Args:
263+
func: integrated function
264+
265+
Returns: moment approximation and error tolerance
266+
"""
267+
result = RQMC(
268+
func,
269+
error_tolerance=self.error_tolerance,
270+
count=self.count,
271+
base_n=self.base_n,
272+
i_max=self.i_max,
273+
a=self.a,
274+
)()
275+
return IntegrationResult(result[0], result[1])
276+

src/mixtures/abstract_mixture.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from scipy.stats import rv_continuous
66
from scipy.stats.distributions import rv_frozen
77

8+
from src.algorithms.support_algorithms.integrator import Integrator
89

910
class AbstractMixtures(metaclass=ABCMeta):
1011
"""Base class for Mixtures"""
@@ -28,13 +29,13 @@ def __init__(self, mixture_form: str, **kwargs: Any) -> None:
2829
raise AssertionError(f"Unknown mixture form: {mixture_form}")
2930

3031
@abstractmethod
31-
def compute_moment(self, n: int, params: dict) -> tuple[float, float]: ...
32+
def compute_moment(self, n: int, integrator: Integrator) -> tuple[float, float]: ...
3233

3334
@abstractmethod
34-
def compute_cdf(self, x: float, params: dict) -> tuple[float, float]: ...
35+
def compute_cdf(self, x: float, integrator: Integrator) -> tuple[float, float]: ...
3536

3637
@abstractmethod
37-
def compute_pdf(self, x: float, params: dict) -> tuple[float, float]: ...
38+
def compute_pdf(self, x: float, integrator: Integrator) -> tuple[float, float]: ...
3839

3940
@abstractmethod
4041
def compute_logpdf(self, x: float, params: dict) -> tuple[float, float]: ...

0 commit comments

Comments
 (0)