Skip to content

Commit 19a9ca1

Browse files
committed
feat: Implement dynamic integrator selection, enabling the use of custom integrators
1 parent d7b56dd commit 19a9ca1

File tree

7 files changed

+187
-86
lines changed

7 files changed

+187
-86
lines changed

src/algorithms/semiparam_algorithms/nvm_semi_param_algorithms/g_estimation_given_mu_rqmc_based.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from scipy.integrate import quad_vec
88
from scipy.special import gamma
99

10-
from src.algorithms.support_algorithms.rqmc import RQMC
10+
from src.algorithms.support_algorithms.integrator import Integrator
11+
from src.algorithms.support_algorithms.rqmc import RQMCIntegrator
1112
from src.estimators.estimate_result import EstimateResult
1213

1314
MU_DEFAULT_VALUE = 1.0
@@ -30,6 +31,7 @@ def v_sequence_default_value(n: float) -> float:
3031
GRID_SIZE_DEFAULT_VALUE: int = 200
3132
INTEGRATION_TOLERANCE_DEFAULT_VALUE: float = 1e-2
3233
INTEGRATION_LIMIT_DEFAULT_VALUE: int = 50
34+
INTEGRATOR_DEFAULT = RQMCIntegrator()
3335

3436

3537
class SemiParametricGEstimationGivenMuRQMCBased:
@@ -69,6 +71,7 @@ def __init__(self, sample: Optional[_typing.NDArray[np.float64]] = None, **kwarg
6971
self.grid_size,
7072
self.integration_tolerance,
7173
self.integration_limit,
74+
self.integrator
7275
) = self._validate_kwargs(self.n, **kwargs)
7376
self.denominator: float = 2 * math.pi * self.n
7477
self.precompute_gamma_grid()
@@ -78,7 +81,7 @@ def __init__(self, sample: Optional[_typing.NDArray[np.float64]] = None, **kwarg
7881
@staticmethod
7982
def _validate_kwargs(
8083
n: int, **kwargs: Unpack[ParamsAnnotation]
81-
) -> tuple[float, float, float, float, List[float], int, float, int]:
84+
) -> tuple[float, float, float, float, List[float], int, float, int, Integrator]:
8285
mu: float = kwargs.get("mu", MU_DEFAULT_VALUE)
8386
gmm: float = kwargs.get("gmm", GAMMA_DEFAULT_VALUE)
8487
u_value: float = kwargs.get("u_value", U_SEQUENCE_DEFAULT_VALUE(n))
@@ -87,7 +90,8 @@ def _validate_kwargs(
8790
grid_size: int = kwargs.get("grid_size", GRID_SIZE_DEFAULT_VALUE)
8891
integration_tolerance: float = kwargs.get("integration_tolerance", INTEGRATION_TOLERANCE_DEFAULT_VALUE)
8992
integration_limit: int = kwargs.get("integration_limit", INTEGRATION_LIMIT_DEFAULT_VALUE)
90-
return mu, gmm, u_value, v_value, x_data, grid_size, integration_tolerance, integration_limit
93+
integrator: Integrator = kwargs.get("integrator_type", INTEGRATOR_DEFAULT)
94+
return mu, gmm, u_value, v_value, x_data, grid_size, integration_tolerance, integration_limit, integrator
9195

9296
def conjugate_psi(self, u: float) -> complex:
9397
return complex((u**2) / 2, self.mu * u)
@@ -162,15 +166,16 @@ def second_v_integrand(self, v: float, x: float) -> np.ndarray:
162166
x_power = self.x_powers[x][idx]
163167
return (self.second_u_integrals[idx] * x_power) / gamma_val
164168

165-
def compute_integrals_for_x(self, x: float) -> float:
169+
def compute_integrals_for_x(self, x: float, integrator: Integrator = None) -> float:
166170
"""Compute integrals using RQMC for v-integration."""
167-
first_integral = RQMC(lambda t: np.sum(self.first_v_integrand(t * self.v_value, x)) * self.v_value).rqmc()[0]
171+
integrator = integrator or self.integrator
172+
first_integral = integrator.compute_integral(func=lambda t: np.sum(self.first_v_integrand(t * self.v_value, x)) * self.v_value).value
168173

169-
second_integral = RQMC(lambda t: np.sum(self.second_v_integrand(-t * self.v_value, x)) * self.v_value).rqmc()[0]
174+
second_integral = integrator.compute_integral(func=lambda t: np.sum(self.second_v_integrand(-t * self.v_value, x)) * self.v_value).value
170175

171176
total = (first_integral + second_integral) / self.denominator
172177
return max(0.0, total.real)
173178

174179
def algorithm(self, sample: np._typing.NDArray) -> EstimateResult:
175-
y_data = [self.compute_integrals_for_x(x) for x in self.x_data]
180+
y_data = [self.compute_integrals_for_x(x, self.integrator) for x in self.x_data]
176181
return EstimateResult(list_value=y_data, success=True)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from dataclasses import dataclass
2+
from typing import Any, Protocol, Callable, Optional
3+
4+
5+
@dataclass
6+
class IntegrationResult:
7+
value: float
8+
error: float
9+
message: Optional[dict[str, Any]] | None = None
10+
11+
12+
class Integrator(Protocol):
13+
14+
"""Base class for integral calculation"""
15+
16+
def compute_integral(self, func: Callable, params: dict) -> IntegrationResult:
17+
...
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from typing import Callable
2+
3+
from scipy.integrate import quad
4+
from src.algorithms.support_algorithms.integrator import IntegrationResult
5+
6+
class QuadIntegrator:
7+
8+
def compute_integral(self, func: Callable, params: dict) -> IntegrationResult:
9+
10+
"""
11+
Compute integral via quad integrator
12+
13+
Args:
14+
func: integrated function
15+
params: Parameters of integration algorithm
16+
17+
Returns: moment approximation and error tolerance
18+
"""
19+
20+
full_output_requested = params.pop('full_output', False)
21+
quad_res = quad(func, **params)
22+
if full_output_requested:
23+
value, error, message = quad_res
24+
else:
25+
value, error = quad_res
26+
message = None
27+
return IntegrationResult(value, error, message)

src/algorithms/support_algorithms/rqmc.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import scipy
66
from numba import njit
77

8+
from src.algorithms.support_algorithms.integrator import IntegrationResult
9+
810
BITS = 30
911
"""Number of bits in XOR. Should be less than 64"""
1012
NUMBA_FAST_MATH = True
@@ -126,7 +128,8 @@ def _update(
126128
127129
Returns:Updated mean of all rows
128130
129-
"""
131+
132+
"""
130133
values = []
131134
sum_of_new: float = 0.0
132135
for i in range(self.count):
@@ -212,9 +215,9 @@ def _xor_float(a: float, b: float) -> float:
212215
Returns: XOR float value
213216
214217
"""
215-
a = int(a * (2**BITS))
216-
b = int(b * (2**BITS))
217-
return np.bitwise_xor(a, b) / 2**BITS
218+
a = int(a * (2 ** BITS))
219+
b = int(b * (2 ** BITS))
220+
return np.bitwise_xor(a, b) / 2 ** BITS
218221

219222
def __call__(self) -> tuple[float, float]:
220223
"""Interface for users
@@ -223,3 +226,20 @@ def __call__(self) -> tuple[float, float]:
223226
224227
"""
225228
return self.rqmc()
229+
230+
231+
class RQMCIntegrator:
232+
233+
def compute_integral(self, func: Callable, params: dict) -> IntegrationResult:
234+
"""
235+
Compute integral via RQMC integrator
236+
237+
Args:
238+
func: integrated function
239+
params: Parameters of integration algorithm
240+
241+
Returns: moment approximation and error tolerance
242+
"""
243+
244+
rqmc = RQMC(func, **params)()
245+
return IntegrationResult(rqmc[0], rqmc[1])

src/mixtures/nm_mixture.py

Lines changed: 55 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
from typing import Any
33

44
import numpy as np
5-
from scipy.integrate import quad
65
from scipy.special import binom
76
from scipy.stats import norm, rv_continuous
87
from scipy.stats.distributions import rv_frozen
98

109
from src.algorithms.support_algorithms.log_rqmc import LogRQMC
11-
from src.algorithms.support_algorithms.rqmc import RQMC
10+
from src.algorithms.support_algorithms.integrator import Integrator
11+
from src.algorithms.support_algorithms.rqmc import RQMCIntegrator
12+
from src.algorithms.support_algorithms.quad_integrator import QuadIntegrator
1213
from src.mixtures.abstract_mixture import AbstractMixtures
1314

1415

@@ -59,156 +60,175 @@ def _params_validation(self, data_collector: Any, params: dict[str, float | rv_c
5960
raise ValueError("Gamma cant be zero")
6061
return data_class
6162

62-
def _classical_moment(self, n: int, params: dict) -> tuple[float, float]:
63+
def _classical_moment(self, n: int, params: dict, integrator: Integrator = None) -> tuple[float, float]:
6364
"""
6465
Compute n-th moment of classical NMM
6566
6667
Args:
6768
n (): Moment ordinal
6869
params (): Parameters of integration algorithm
70+
integrator (): type of integrator to computing
6971
7072
Returns: moment approximation and error tolerance
7173
7274
"""
7375
mixture_moment = 0
7476
error_tolerance = 0
77+
integrator = integrator or QuadIntegrator()
7578
for k in range(0, n + 1):
7679
for l in range(0, k + 1):
77-
coefficient = binom(n, n - k) * binom(k, k - l) * (self.params.beta ** (k - l)) * (self.params.gamma**l)
78-
mixing_moment = quad(lambda u: self.params.distribution.ppf(u) ** (k - l), 0, 1, **params)
79-
error_tolerance += (self.params.beta ** (k - l)) * mixing_moment[1]
80-
mixture_moment += coefficient * (self.params.alpha ** (n - k)) * mixing_moment[0] * norm.moment(l)
81-
return mixture_moment, error_tolerance
80+
coefficient = binom(n, n - k) * binom(k, k - l) * (self.params.beta ** (k - l)) * (
81+
self.params.gamma ** l)
82+
mixing_moment = integrator.compute_integral(lambda u: self.params.distribution.ppf(u) ** (k - l),
83+
{"a": 0, "b": 1, **params})
84+
error_tolerance += (self.params.beta ** (k - l)) * mixing_moment.error
85+
mixture_moment += coefficient * (self.params.alpha ** (n - k)) * mixing_moment.value * norm.moment(l)
86+
return mixture_moment, error_tolerance
87+
8288

83-
def _canonical_moment(self, n: int, params: dict) -> tuple[float, float]:
89+
def _canonical_moment(self, n: int, params: dict, integrator: Integrator = None) -> tuple[float, float]:
8490
"""
8591
Compute n-th moment of canonical NMM
8692
8793
Args:
8894
n (): Moment ordinal
8995
params (): Parameters of integration algorithm
96+
integrator (): type of integrator to computing
9097
9198
Returns: moment approximation and error tolerance
9299
93100
"""
94101
mixture_moment = 0
95102
error_tolerance = 0
103+
integrator = integrator or QuadIntegrator()
96104
for k in range(0, n + 1):
97-
coefficient = binom(n, n - k) * (self.params.sigma**k)
98-
mixing_moment = quad(lambda u: self.params.distribution.ppf(u) ** (n - k), 0, 1, **params)
99-
error_tolerance += mixing_moment[1]
100-
mixture_moment += coefficient * mixing_moment[0] * norm.moment(k)
105+
coefficient = binom(n, n - k) * (self.params.sigma ** k)
106+
mixing_moment = integrator.compute_integral(lambda u: self.params.distribution.ppf(u) ** (n - k),
107+
{"a": 0, "b": 1, **params})
108+
error_tolerance += mixing_moment.error
109+
mixture_moment += coefficient * mixing_moment.value * norm.moment(k)
101110
return mixture_moment, error_tolerance
102111

103-
def compute_moment(self, n: int, params: dict) -> tuple[float, float]:
112+
def compute_moment(self, n: int, params: dict, integrator: Integrator = None) -> tuple[float, float]:
104113
"""
105114
Compute n-th moment of NMM
106115
107116
Args:
108117
n (): Moment ordinal
109118
params (): Parameters of integration algorithm
119+
integrator (): type of integrator to computing
110120
111121
Returns: moment approximation and error tolerance
112122
113123
"""
114124
if isinstance(self.params, _NMMClassicDataCollector):
115-
return self._classical_moment(n, params)
116-
return self._canonical_moment(n, params)
125+
return self._classical_moment(n, params, integrator)
126+
return self._canonical_moment(n, params, integrator)
117127

118-
def _canonical_compute_cdf(self, x: float, params: dict) -> tuple[float, float]:
128+
def _canonical_compute_cdf(self, x: float, params: dict, integrator: Integrator = None) -> tuple[float, float]:
119129
"""
120130
Equation for canonical cdf
121131
Args:
122132
x (): point
123133
params (): parameters of RQMC algorithm
134+
integrator (): type of integrator to computing
124135
125136
Returns: computed cdf and error tolerance
126137
127138
"""
128-
rqmc = RQMC(lambda u: norm.cdf((x - self.params.distribution.ppf(u)) / np.abs(self.params.sigma)), **params)
129-
return rqmc()
139+
integrator = integrator or RQMCIntegrator()
140+
rqmc = integrator.compute_integral(func=lambda u: norm.cdf((x - self.params.distribution.ppf(u)) / np.abs(self.params.sigma)), **params)
141+
return rqmc.value, rqmc.error
130142

131-
def _classical_compute_cdf(self, x: float, params: dict) -> tuple[float, float]:
143+
def _classical_compute_cdf(self, x: float, params: dict, integrator: Integrator = None) -> tuple[float, float]:
132144
"""
133145
Equation for classic cdf
134146
Args:
135147
x (): point
136148
params (): parameters of RQMC algorithm
149+
integrator (): type of integrator to computing
137150
138151
Returns: computed cdf and error tolerance
139152
140153
"""
141-
rqmc = RQMC(
154+
integrator = integrator or RQMCIntegrator()
155+
rqmc = integrator.compute_integral(func=
142156
lambda u: norm.cdf(
143157
(x - self.params.alpha - self.params.beta * self.params.distribution.ppf(u)) / np.abs(self.params.gamma)
144158
),
145159
**params
146160
)
147-
return rqmc()
161+
return rqmc.value, rqmc.error
148162

149-
def compute_cdf(self, x: float, params: dict) -> tuple[float, float]:
163+
def compute_cdf(self, x: float, params: dict, integrator: Integrator = None) -> tuple[float, float]:
150164
"""
151165
Choose equation for cdf estimation depends on Mixture form
152166
Args:
153167
x (): point
154168
params (): parameters of RQMC algorithm
169+
integrator (): type of integrator to computing
155170
156171
Returns: Computed pdf and error tolerance
157172
158173
"""
159174
if isinstance(self.params, _NMMCanonicalDataCollector):
160-
return self._canonical_compute_cdf(x, params)
161-
return self._classical_compute_cdf(x, params)
175+
return self._canonical_compute_cdf(x, params, integrator)
176+
return self._classical_compute_cdf(x, params, integrator)
162177

163-
def _canonical_compute_pdf(self, x: float, params: dict) -> tuple[float, float]:
178+
def _canonical_compute_pdf(self, x: float, params: dict, integrator: Integrator = None) -> tuple[float, float]:
164179
"""
165180
Equation for canonical pdf
166181
Args:
167182
x (): point
168183
params (): parameters of RQMC algorithm
184+
integrator (): type of integrator to computing
169185
170186
Returns: computed pdf and error tolerance
171187
172188
"""
173-
rqmc = RQMC(
189+
integrator = integrator or RQMCIntegrator()
190+
rqmc = integrator.compute_integral(func=
174191
lambda u: (1 / np.abs(self.params.sigma))
175192
* norm.pdf((x - self.params.distribution.ppf(u)) / np.abs(self.params.sigma)),
176193
**params
177194
)
178-
return rqmc()
195+
return rqmc.value, rqmc.error
179196

180-
def _classical_compute_pdf(self, x: float, params: dict) -> tuple[float, float]:
197+
def _classical_compute_pdf(self, x: float, params: dict, integrator: Integrator = None) -> tuple[float, float]:
181198
"""
182199
Equation for classic pdf
183200
Args:
184201
x (): point
185202
params (): parameters of RQMC algorithm
203+
integrator (): type of integrator to computing
186204
187205
Returns: computed pdf and error tolerance
188206
189207
"""
190-
rqmc = RQMC(
208+
integrator = integrator or RQMCIntegrator()
209+
rqmc = integrator.compute_integral(func=
191210
lambda u: (1 / np.abs(self.params.gamma))
192211
* norm.pdf(
193212
(x - self.params.alpha - self.params.beta * self.params.distribution.ppf(u)) / np.abs(self.params.gamma)
194213
),
195214
**params
196215
)
197-
return rqmc()
216+
return rqmc.value, rqmc.error
198217

199-
def compute_pdf(self, x: float, params: dict) -> tuple[float, float]:
218+
def compute_pdf(self, x: float, params: dict, integrator: Integrator = None) -> tuple[float, float]:
200219
"""
201220
Choose equation for pdf estimation depends on Mixture form
202221
Args:
203222
x (): point
204223
params (): parameters of RQMC algorithm
224+
integrator (): type of integrator to computing
205225
206226
Returns: Computed pdf and error tolerance
207227
208228
"""
209229
if isinstance(self.params, _NMMCanonicalDataCollector):
210-
return self._canonical_compute_pdf(x, params)
211-
return self._classical_compute_pdf(x, params)
230+
return self._canonical_compute_pdf(x, params, integrator)
231+
return self._classical_compute_pdf(x, params, integrator)
212232

213233
def _classical_compute_log_pdf(self, x: float, params: dict) -> tuple[float, float]:
214234
"""
@@ -258,4 +278,4 @@ def compute_logpdf(self, x: float, params: dict) -> tuple[float, float]:
258278
"""
259279
if isinstance(self.params, _NMMCanonicalDataCollector):
260280
return self._canonical_compute_log_pdf(x, params)
261-
return self._classical_compute_log_pdf(x, params)
281+
return self._classical_compute_log_pdf(x, params)

0 commit comments

Comments
 (0)