Skip to content

Commit 51a7303

Browse files
committed
feat: integration parameters are now part of the constructor
1 parent 1280bdf commit 51a7303

File tree

8 files changed

+174
-143
lines changed

8 files changed

+174
-143
lines changed

src/algorithms/semiparam_algorithms/nvm_semi_param_algorithms/g_estimation_given_mu_rqmc_based.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def v_sequence_default_value(n: float) -> float:
3131
GRID_SIZE_DEFAULT_VALUE: int = 200
3232
INTEGRATION_TOLERANCE_DEFAULT_VALUE: float = 1e-2
3333
INTEGRATION_LIMIT_DEFAULT_VALUE: int = 50
34-
INTEGRATOR_DEFAULT = RQMCIntegrator()
3534

3635

3736
class SemiParametricGEstimationGivenMuRQMCBased:
@@ -70,8 +69,7 @@ def __init__(self, sample: Optional[_typing.NDArray[np.float64]] = None, **kwarg
7069
self.x_data,
7170
self.grid_size,
7271
self.integration_tolerance,
73-
self.integration_limit,
74-
self.integrator
72+
self.integration_limit
7573
) = self._validate_kwargs(self.n, **kwargs)
7674
self.denominator: float = 2 * math.pi * self.n
7775
self.precompute_gamma_grid()
@@ -81,7 +79,7 @@ def __init__(self, sample: Optional[_typing.NDArray[np.float64]] = None, **kwarg
8179
@staticmethod
8280
def _validate_kwargs(
8381
n: int, **kwargs: Unpack[ParamsAnnotation]
84-
) -> tuple[float, float, float, float, List[float], int, float, int, Integrator]:
82+
) -> tuple[float, float, float, float, List[float], int, float, int]:
8583
mu: float = kwargs.get("mu", MU_DEFAULT_VALUE)
8684
gmm: float = kwargs.get("gmm", GAMMA_DEFAULT_VALUE)
8785
u_value: float = kwargs.get("u_value", U_SEQUENCE_DEFAULT_VALUE(n))
@@ -90,8 +88,7 @@ def _validate_kwargs(
9088
grid_size: int = kwargs.get("grid_size", GRID_SIZE_DEFAULT_VALUE)
9189
integration_tolerance: float = kwargs.get("integration_tolerance", INTEGRATION_TOLERANCE_DEFAULT_VALUE)
9290
integration_limit: int = kwargs.get("integration_limit", INTEGRATION_LIMIT_DEFAULT_VALUE)
93-
integrator: Integrator = kwargs.get("integrator_type", INTEGRATOR_DEFAULT)
94-
return mu, gmm, u_value, v_value, x_data, grid_size, integration_tolerance, integration_limit, integrator
91+
return mu, gmm, u_value, v_value, x_data, grid_size, integration_tolerance, integration_limit
9592

9693
def conjugate_psi(self, u: float) -> complex:
9794
return complex((u**2) / 2, self.mu * u)
@@ -166,16 +163,15 @@ def second_v_integrand(self, v: float, x: float) -> np.ndarray:
166163
x_power = self.x_powers[x][idx]
167164
return (self.second_u_integrals[idx] * x_power) / gamma_val
168165

169-
def compute_integrals_for_x(self, x: float, integrator: Integrator = None) -> float:
166+
def compute_integrals_for_x(self, x: float, integrator: Integrator = RQMCIntegrator()) -> float:
170167
"""Compute integrals using RQMC for v-integration."""
171-
integrator = integrator or self.integrator
172-
first_integral = integrator.compute_integral(func=lambda t: np.sum(self.first_v_integrand(t * self.v_value, x)) * self.v_value).value
168+
first_integral = integrator.compute(func=lambda t: np.sum(self.first_v_integrand(t * self.v_value, x)) * self.v_value).value
173169

174-
second_integral = integrator.compute_integral(func=lambda t: np.sum(self.second_v_integrand(-t * self.v_value, x)) * self.v_value).value
170+
second_integral = integrator.compute(func=lambda t: np.sum(self.second_v_integrand(-t * self.v_value, x)) * self.v_value).value
175171

176172
total = (first_integral + second_integral) / self.denominator
177173
return max(0.0, total.real)
178174

179175
def algorithm(self, sample: np._typing.NDArray) -> EstimateResult:
180-
y_data = [self.compute_integrals_for_x(x, self.integrator) for x in self.x_data]
176+
y_data = [self.compute_integrals_for_x(x, RQMCIntegrator()) for x in self.x_data]
181177
return EstimateResult(list_value=y_data, success=True)

src/algorithms/support_algorithms/integrator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,8 @@ class Integrator(Protocol):
1313

1414
"""Base class for integral calculation"""
1515

16-
def compute_integral(self, func: Callable, params: dict) -> IntegrationResult:
16+
def __init__(self) -> None:
17+
...
18+
19+
def compute(self, func: Callable) -> IntegrationResult:
1720
...
Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,60 @@
1-
from typing import Callable
1+
from typing import Callable, Any, Sequence
22

33
from scipy.integrate import quad
44
from src.algorithms.support_algorithms.integrator import IntegrationResult
55

66
class QuadIntegrator:
77

8-
def compute_integral(self, func: Callable, params: dict) -> IntegrationResult:
8+
def __init__(
9+
self,
10+
a: float = 0,
11+
b: float = 1,
12+
args: tuple[Any, ...] = (),
13+
full_output: int = 0,
14+
epsabs: float | int = 1.49e-08,
15+
epsrel: float | int = 1.49e-08,
16+
limit: float | int = 50,
17+
points: Sequence[float | int] | None = None,
18+
weight: float | int | None = None,
19+
wvar: Any = None,
20+
wopts: Any = None,
21+
maxp1: float | int = 50,
22+
limlst: int = 50,
23+
complex_func: bool = False,
24+
):
25+
self.params = {
26+
'a': a,
27+
'b': b,
28+
'args': args,
29+
'full_output': full_output,
30+
'epsabs': epsabs,
31+
'epsrel': epsrel,
32+
'limit': limit,
33+
'points': points,
34+
'weight': weight,
35+
'wvar': wvar,
36+
'wopts': wopts,
37+
'maxp1': maxp1,
38+
'limlst': limlst,
39+
'complex_func': complex_func
40+
}
41+
42+
def compute(self, func: Callable) -> IntegrationResult:
943

1044
"""
1145
Compute integral via quad integrator
1246
1347
Args:
1448
func: integrated function
15-
params: Parameters of integration algorithm
1649
1750
Returns: moment approximation and error tolerance
1851
"""
1952

20-
full_output_requested = params.pop('full_output', False)
21-
quad_res = quad(func, **params)
22-
if full_output_requested:
23-
value, error, message = quad_res
53+
verbose = self.params.pop('full_output', False)
54+
result = quad(func, **self.params)
55+
if verbose:
56+
value, error, message = result
2457
else:
25-
value, error = quad_res
58+
value, error = result
2659
message = None
2760
return IntegrationResult(value, error, message)

src/algorithms/support_algorithms/rqmc.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,17 +229,48 @@ def __call__(self) -> tuple[float, float]:
229229

230230

231231
class RQMCIntegrator:
232+
"""
233+
Randomize Quasi Monte Carlo Method
234+
235+
Args:
236+
error_tolerance: pre-specified error tolerance
237+
count: number of rows of random values matrix
238+
base_n: number of columns of random values matrix
239+
i_max: allowed number of cycles
240+
a: parameter for quantile of normal distribution
241+
242+
"""
243+
244+
def __init__(
245+
self,
246+
error_tolerance: float = 1e-6,
247+
count: int = 25,
248+
base_n: int = 2 ** 6,
249+
i_max: int = 100,
250+
a: float = 0.00047,
251+
):
252+
self.error_tolerance = error_tolerance
253+
self.count = count
254+
self.base_n = base_n
255+
self.i_max = i_max
256+
self.a = a
232257

233-
def compute_integral(self, func: Callable, params: dict) -> IntegrationResult:
258+
def compute(self, func: Callable) -> IntegrationResult:
234259
"""
235260
Compute integral via RQMC integrator
236261
237262
Args:
238263
func: integrated function
239-
params: Parameters of integration algorithm
240264
241265
Returns: moment approximation and error tolerance
242266
"""
267+
result = RQMC(
268+
func,
269+
error_tolerance=self.error_tolerance,
270+
count=self.count,
271+
base_n=self.base_n,
272+
i_max=self.i_max,
273+
a=self.a,
274+
)()
275+
return IntegrationResult(result[0], result[1])
243276

244-
rqmc = RQMC(func, **params)()
245-
return IntegrationResult(rqmc[0], rqmc[1])

src/mixtures/abstract_mixture.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from scipy.stats import rv_continuous
66
from scipy.stats.distributions import rv_frozen
77

8+
from src.algorithms.support_algorithms.integrator import Integrator
89

910
class AbstractMixtures(metaclass=ABCMeta):
1011
"""Base class for Mixtures"""
@@ -28,13 +29,13 @@ def __init__(self, mixture_form: str, **kwargs: Any) -> None:
2829
raise AssertionError(f"Unknown mixture form: {mixture_form}")
2930

3031
@abstractmethod
31-
def compute_moment(self, n: int, params: dict) -> tuple[float, float]: ...
32+
def compute_moment(self, n: int, integrator: Integrator) -> tuple[float, float]: ...
3233

3334
@abstractmethod
34-
def compute_cdf(self, x: float, params: dict) -> tuple[float, float]: ...
35+
def compute_cdf(self, x: float, integrator: Integrator) -> tuple[float, float]: ...
3536

3637
@abstractmethod
37-
def compute_pdf(self, x: float, params: dict) -> tuple[float, float]: ...
38+
def compute_pdf(self, x: float, integrator: Integrator) -> tuple[float, float]: ...
3839

3940
@abstractmethod
4041
def compute_logpdf(self, x: float, params: dict) -> tuple[float, float]: ...

0 commit comments

Comments
 (0)