Skip to content

Commit 12b3205

Browse files
committed
fixed
1 parent 4296e34 commit 12b3205

File tree

4 files changed

+52
-79
lines changed

4 files changed

+52
-79
lines changed

src/mixtures/abstract_mixture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,4 +126,4 @@ def _params_validation(
126126
raise ValueError(
127127
f"Type mismatch: {name} should be {names_and_types[name]}, not {type(value)}"
128128
)
129-
return data_collector(**params)
129+
return data_collector(**params)

src/mixtures/nm_mixture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def mix(u: float) -> float:
6969
error += coeff * res.error * norm.moment(k)
7070
return mixture_moment, error
7171

72-
def _compute_cdf(self, x: float) -> Tuple[float, float]:
72+
def _compute_cdf(self, x: float, params: Dict[str, Any]) -> Tuple[float, float]:
7373
if self.mixture_form == "classical":
7474
def fn(u: float) -> float:
7575
return norm.cdf((x - self.params.alpha - self.params.beta * self.params.distribution.ppf(u)) / abs(self.params.gamma))

src/mixtures/nmv_mixture.py

Lines changed: 37 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,25 @@
11
from dataclasses import dataclass
2-
from functools import lru_cache
3-
from typing import Any, Type, Dict
2+
from typing import Any, Type, Dict, Tuple
43

54
import numpy as np
65
from scipy.special import binom
76
from scipy.stats import norm, rv_continuous
87
from scipy.stats.distributions import rv_frozen
98

109
from src.algorithms.support_algorithms.integrator import Integrator
11-
from src.algorithms.support_algorithms.rqmc import RQMCIntegrator
12-
from src.algorithms.support_algorithms.log_rqmc import LogRQMC
13-
from src.algorithms.support_algorithms.quad_integrator import QuadIntegrator
1410
from src.mixtures.abstract_mixture import AbstractMixtures
1511

1612
@dataclass
1713
class _NMVMClassicDataCollector:
1814
alpha: float | int | np.int64
19-
beta: float | int | np.int64
15+
beta: float | int | np.int64
2016
gamma: float | int | np.int64
2117
distribution: rv_frozen | rv_continuous
2218

2319
@dataclass
2420
class _NMVMCanonicalDataCollector:
2521
alpha: float | int | np.int64
26-
mu: float | int | np.int64
22+
mu: float | int | np.int64
2723
distribution: rv_frozen | rv_continuous
2824

2925
class NormalMeanVarianceMixtures(AbstractMixtures):
@@ -33,96 +29,66 @@ class NormalMeanVarianceMixtures(AbstractMixtures):
3329
def __init__(
3430
self,
3531
mixture_form: str,
36-
integrator_cls: Type[Integrator] = RQMCIntegrator,
37-
integrator_params: Dict[str, Any] = None,
32+
integrator_cls: Type[Integrator],
33+
integrator_params: Dict[str, Any] | None = None,
3834
**kwargs: Any
3935
) -> None:
4036
super().__init__(mixture_form, integrator_cls=integrator_cls, integrator_params=integrator_params, **kwargs)
4137

42-
def _compute_moment(self, n: int) -> tuple[float, float]:
38+
def _compute_moment(self, n: int) -> Tuple[float, float]:
4339
def integrand(u: float) -> float:
44-
result = 0.0
40+
s = 0.0
4541
for k in range(n + 1):
4642
for l in range(k + 1):
4743
if self.mixture_form == "classical":
48-
result += (
49-
binom(n, n - k)
50-
* binom(k, k - l)
51-
* (self.params.beta ** (k - l))
44+
coef = binom(n, n - k) * binom(k, k - l)
45+
term = (
46+
(self.params.beta ** (k - l))
5247
* (self.params.gamma ** l)
5348
* (self.params.distribution.ppf(u) ** (k - l/2))
49+
* (self.params.alpha ** (n - k))
5450
* norm.moment(l)
5551
)
5652
else:
57-
result += (
58-
binom(n, n - k)
59-
* binom(k, k - l)
60-
* (self.params.mu ** (k - l))
53+
coef = binom(n, n - k) * binom(k, k - l)
54+
term = (
55+
(self.params.nu ** (k - l))
6156
* (self.params.distribution.ppf(u) ** (k - l/2))
57+
* (self.params.alpha ** (n - k))
6258
* norm.moment(l)
6359
)
64-
return result
60+
s += coef * term if self.mixture_form == "classical" else term
61+
return s
6562

66-
integrator = self.integrator_cls(**(self.integrator_params or {}))
67-
result = integrator.compute(integrand)
68-
return result.value, result.error
63+
res = self.integrator_cls(**(self.integrator_params or {})).compute(integrand)
64+
return res.value, res.error
6965

70-
def _compute_cdf(self, x: float) -> tuple[float, float]:
66+
def _compute_cdf(self, x: float) -> Tuple[float, float]:
7167
def integrand(u: float) -> float:
72-
ppf = self.params.distribution.ppf(u)
68+
p = self.params.distribution.ppf(u)
7369
if self.mixture_form == "classical":
74-
point = (x - self.params.alpha) / (np.sqrt(ppf) * self.params.gamma) - (self.params.beta / self.params.gamma * np.sqrt(ppf))
75-
else:
76-
point = (x - self.params.alpha) / np.sqrt(ppf) - (self.params.mu * np.sqrt(ppf))
77-
return norm.cdf(point)
70+
return norm.cdf((x - self.params.alpha - self.params.beta * p) / abs(self.params.gamma))
71+
return norm.cdf((x - self.params.alpha) / np.sqrt(p) - self.params.mu * np.sqrt(p))
7872

79-
integrator = self.integrator_cls(**(self.integrator_params or {}))
80-
result = integrator.compute(integrand)
81-
return result.value, result.error
73+
res = self.integrator_cls(**(self.integrator_params or {})).compute(integrand)
74+
return res.value, res.error
8275

83-
def _compute_pdf(self, x: float) -> tuple[float, float]:
76+
def _compute_pdf(self, x: float) -> Tuple[float, float]:
8477
def integrand(u: float) -> float:
85-
ppf = self.params.distribution.ppf(u)
78+
p = self.params.distribution.ppf(u)
8679
if self.mixture_form == "classical":
87-
return (
88-
1 / np.sqrt(2 * np.pi * ppf * self.params.gamma ** 2)
89-
* np.exp(-((x - self.params.alpha) ** 2 + self.params.beta ** 2 * ppf ** 2) / (2 * ppf * self.params.gamma ** 2))
90-
)
91-
else:
92-
return (
93-
1 / np.sqrt(2 * np.pi * ppf)
94-
* np.exp(-((x - self.params.alpha) ** 2 + self.params.mu ** 2 * ppf ** 2) / (2 * ppf))
95-
)
80+
return (1/abs(self.params.gamma)) * norm.pdf((x - self.params.alpha - self.params.beta * p)/abs(self.params.gamma))
81+
return (1/abs(self.params.mu)) * norm.pdf((x - p)/abs(self.params.mu))
9682

97-
integrator = self.integrator_cls(**(self.integrator_params or {}))
98-
result = integrator.compute(integrand)
99-
if self.mixture_form == "classical":
100-
val = np.exp(self.params.beta * (x - self.params.alpha) / self.params.gamma ** 2) * result.value
101-
else:
102-
val = np.exp(self.params.mu * (x - self.params.alpha)) * result.value
103-
return val, result.error
83+
res = self.integrator_cls(**(self.integrator_params or {})).compute(integrand)
84+
return res.value, res.error
10485

105-
def _compute_logpdf(self, x: float) -> tuple[float, float]:
86+
def _compute_logpdf(self, x: float) -> Tuple[float, float]:
10687
def integrand(u: float) -> float:
107-
ppf = self.params.distribution.ppf(u)
88+
p = self.params.distribution.ppf(u)
10889
if self.mixture_form == "classical":
109-
return -(
110-
(x - self.params.alpha) ** 2
111-
+ ppf ** 2 * self.params.beta ** 2
112-
+ ppf * self.params.gamma ** 2 * np.log(2 * np.pi * ppf * self.params.gamma ** 2)
113-
) / (2 * ppf * self.params.gamma ** 2)
114-
else:
115-
return -((x - self.params.alpha) ** 2 + ppf ** 2 * self.params.mu ** 2 + ppf * np.log(2 * np.pi * ppf)) / (2 * ppf)
90+
return np.log(1/abs(self.params.gamma)) + norm.logpdf((x - self.params.alpha - self.params.beta * p)/abs(self.params.gamma))
91+
return np.log(1/abs(self.params.mu)) + norm.logpdf((x - p)/abs(self.params.mu))
11692

117-
integrator = self.integrator_cls(**(self.integrator_params or {}))
118-
result = integrator.compute(integrand)
119-
if self.mixture_form == "classical":
120-
val = self.params.beta * (x - self.params.alpha) / self.params.gamma ** 2 + result.value
121-
else:
122-
val = self.params.mu * (x - self.params.alpha) + result.value
123-
return val, result.error
124-
125-
@lru_cache()
126-
def _integrand_func(self, u: float, d: float, gamma: float) -> float:
127-
ppf = self.params.distribution.ppf(u)
128-
return (1 / np.sqrt(2 * np.pi * ppf * abs(gamma) ** 2)) * np.exp(-d / (2 * ppf))
93+
res = self.integrator_cls(**(self.integrator_params or {})).compute(integrand)
94+
return res.value, res.error

src/mixtures/nv_mixture.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,13 @@ def __init__(
3535
integrator_params: Dict[str, Any] = None,
3636
**kwargs: Any
3737
) -> None:
38-
super().__init__(mixture_form, **kwargs)
38+
super().__init__(mixture_form, integrator_cls=integrator_cls, integrator_params=integrator_params, **kwargs)
3939
self.integrator_cls = integrator_cls
4040
self.integrator_params = integrator_params or {}
4141

42-
def _compute_moment(self, n: int, __: Dict[str, Any]) -> tuple[float, float]:
42+
def _compute_moment(self, n: int) -> tuple[float, float]:
4343
gamma = getattr(self.params, 'gamma', 1)
44+
4445
def integrand(u: float) -> float:
4546
return sum(
4647
binom(n, k)
@@ -50,34 +51,40 @@ def integrand(u: float) -> float:
5051
* norm.moment(k)
5152
for k in range(n + 1)
5253
)
54+
5355
integrator = self.integrator_cls(**self.integrator_params)
5456
result = integrator.compute(integrand)
5557
return result.value, result.error
5658

57-
def _compute_cdf(self, x: float, __: Dict[str, Any]) -> tuple[float, float]:
59+
def _compute_cdf(self, x: float) -> tuple[float, float]:
5860
gamma = getattr(self.params, 'gamma', 1)
5961
param_norm = norm(0, gamma)
62+
6063
def integrand(u: float) -> float:
6164
return param_norm.cdf((x - self.params.alpha) / np.sqrt(self.params.distribution.ppf(u)))
65+
6266
integrator = self.integrator_cls(**self.integrator_params)
6367
result = integrator.compute(integrand)
6468
return result.value, result.error
6569

66-
def _compute_pdf(self, x: float, __: Dict[str, Any]) -> tuple[float, float]:
70+
def _compute_pdf(self, x: float) -> tuple[float, float]:
6771
gamma = getattr(self.params, 'gamma', 1)
6872
d = (x - self.params.alpha) ** 2 / gamma ** 2
73+
6974
def integrand(u: float) -> float:
7075
return self._integrand_func(u, d, gamma)
76+
7177
integrator = self.integrator_cls(**self.integrator_params)
7278
result = integrator.compute(integrand)
7379
return result.value, result.error
7480

75-
def _compute_logpdf(self, x: float, __: Dict[str, Any]) -> tuple[float, float]:
81+
def _compute_logpdf(self, x: float) -> tuple[float, float]:
7682
gamma = getattr(self.params, 'gamma', 1)
7783
d = (x - self.params.alpha) ** 2 / gamma ** 2
84+
7885
def integrand(u: float) -> float:
7986
return self._log_integrand_func(u, d, gamma)
80-
# For log-pdf you may choose LogRQMC or any integrator that supports it
87+
8188
integrator = self.integrator_cls(**self.integrator_params)
8289
result = integrator.compute(integrand)
8390
return result.value, result.error

0 commit comments

Comments
 (0)