Skip to content

Commit 5f7295c

Browse files
committed
fixed
1 parent 1b5138d commit 5f7295c

File tree

2 files changed

+41
-21
lines changed

2 files changed

+41
-21
lines changed

src/mixtures/nm_mixture.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from scipy.stats.distributions import rv_frozen
88

99
from src.algorithms.support_algorithms.integrator import Integrator
10+
from src.algorithms.support_algorithms.rqmc import RQMCIntegrator
1011
from src.mixtures.abstract_mixture import AbstractMixtures
1112

1213
@dataclass
@@ -28,7 +29,7 @@ class NormalMeanMixtures(AbstractMixtures):
2829
def __init__(
2930
self,
3031
mixture_form: str,
31-
integrator_cls: Type[Integrator] = Integrator,
32+
integrator_cls: Type[Integrator] = RQMCIntegrator,
3233
integrator_params: Dict[str, Any] = None,
3334
**kwargs: Any
3435
) -> None:

src/mixtures/nmv_mixture.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dataclasses import dataclass
2+
from functools import lru_cache
23
from typing import Any, Type, Dict, Tuple
34

45
import numpy as np
@@ -7,6 +8,8 @@
78
from scipy.stats.distributions import rv_frozen
89

910
from src.algorithms.support_algorithms.integrator import Integrator
11+
from src.algorithms.support_algorithms.rqmc import RQMCIntegrator
12+
from src.algorithms.support_algorithms.log_rqmc import LogRQMC
1013
from src.mixtures.abstract_mixture import AbstractMixtures
1114

1215
@dataclass
@@ -29,35 +32,37 @@ class NormalMeanVarianceMixtures(AbstractMixtures):
2932
def __init__(
3033
self,
3134
mixture_form: str,
32-
integrator_cls: Type[Integrator],
33-
integrator_params: Dict[str, Any] | None = None,
35+
integrator_cls: Type[Integrator] = RQMCIntegrator,
36+
integrator_params: Dict[str, Any] = None,
3437
**kwargs: Any
3538
) -> None:
3639
super().__init__(mixture_form, integrator_cls=integrator_cls, integrator_params=integrator_params, **kwargs)
3740

3841
def _compute_moment(self, n: int) -> Tuple[float, float]:
42+
gamma = getattr(self.params, 'gamma', None)
43+
3944
def integrand(u: float) -> float:
4045
s = 0.0
4146
for k in range(n + 1):
4247
for l in range(k + 1):
43-
if self.mixture_form == "classical":
44-
coef = binom(n, n - k) * binom(k, k - l)
48+
if self.mixture_form == 'classical':
4549
term = (
46-
(self.params.beta ** (k - l))
50+
binom(n, n - k)
51+
* binom(k, k - l)
52+
* (self.params.beta ** (k - l))
4753
* (self.params.gamma ** l)
4854
* (self.params.distribution.ppf(u) ** (k - l/2))
49-
* (self.params.alpha ** (n - k))
5055
* norm.moment(l)
5156
)
5257
else:
53-
coef = binom(n, n - k) * binom(k, k - l)
5458
term = (
55-
(self.params.nu ** (k - l))
59+
binom(n, n - k)
60+
* binom(k, k - l)
61+
* (self.params.mu ** (k - l))
5662
* (self.params.distribution.ppf(u) ** (k - l/2))
57-
* (self.params.alpha ** (n - k))
5863
* norm.moment(l)
5964
)
60-
s += coef * term if self.mixture_form == "classical" else term
65+
s += term
6166
return s
6267

6368
res = self.integrator_cls(**(self.integrator_params or {})).compute(integrand)
@@ -66,8 +71,8 @@ def integrand(u: float) -> float:
6671
def _compute_cdf(self, x: float) -> Tuple[float, float]:
6772
def integrand(u: float) -> float:
6873
p = self.params.distribution.ppf(u)
69-
if self.mixture_form == "classical":
70-
return norm.cdf((x - self.params.alpha - self.params.beta * p) / abs(self.params.gamma))
74+
if self.mixture_form == 'classical':
75+
return norm.cdf((x - self.params.alpha) / (np.sqrt(p) * self.params.gamma))
7176
return norm.cdf((x - self.params.alpha) / np.sqrt(p) - self.params.mu * np.sqrt(p))
7277

7378
res = self.integrator_cls(**(self.integrator_params or {})).compute(integrand)
@@ -76,19 +81,33 @@ def integrand(u: float) -> float:
7681
def _compute_pdf(self, x: float) -> Tuple[float, float]:
7782
def integrand(u: float) -> float:
7883
p = self.params.distribution.ppf(u)
79-
if self.mixture_form == "classical":
80-
return (1/abs(self.params.gamma)) * norm.pdf((x - self.params.alpha - self.params.beta * p)/abs(self.params.gamma))
81-
return (1/abs(self.params.mu)) * norm.pdf((x - p)/abs(self.params.mu))
84+
if self.mixture_form == 'classical':
85+
return (
86+
1 / np.sqrt(2 * np.pi * p * self.params.gamma ** 2)
87+
* np.exp(-((x - self.params.alpha) ** 2 + self.params.beta ** 2 * p ** 2) / (2 * p * self.params.gamma ** 2))
88+
)
89+
return (
90+
1 / np.sqrt(2 * np.pi * p)
91+
* np.exp(-((x - self.params.alpha) ** 2 + self.params.mu ** 2 * p ** 2) / (2 * p))
92+
)
8293

8394
res = self.integrator_cls(**(self.integrator_params or {})).compute(integrand)
84-
return res.value, res.error
95+
if self.mixture_form == 'classical':
96+
val = np.exp(self.params.beta * (x - self.params.alpha) / self.params.gamma ** 2) * res.value
97+
else:
98+
val = np.exp(self.params.mu * (x - self.params.alpha)) * res.value
99+
return val, res.error
85100

86101
def _compute_logpdf(self, x: float) -> Tuple[float, float]:
87102
def integrand(u: float) -> float:
88103
p = self.params.distribution.ppf(u)
89-
if self.mixture_form == "classical":
90-
return np.log(1/abs(self.params.gamma)) + norm.logpdf((x - self.params.alpha - self.params.beta * p)/abs(self.params.gamma))
91-
return np.log(1/abs(self.params.mu)) + norm.logpdf((x - p)/abs(self.params.mu))
104+
if self.mixture_form == 'classical':
105+
return -((x - self.params.alpha) ** 2 + p ** 2 * self.params.beta ** 2 + p * self.params.gamma ** 2 * np.log(2 * np.pi * p * self.params.gamma ** 2)) / (2 * p * self.params.gamma ** 2)
106+
return -((x - self.params.alpha) ** 2 + p ** 2 * self.params.mu ** 2 + p * np.log(2 * np.pi * p)) / (2 * p)
92107

93108
res = self.integrator_cls(**(self.integrator_params or {})).compute(integrand)
94-
return res.value, res.error
109+
if self.mixture_form == 'classical':
110+
val = self.params.beta * (x - self.params.alpha) / self.params.gamma ** 2 + res.value
111+
else:
112+
val = self.params.mu * (x - self.params.alpha) + res.value
113+
return val, res.error

0 commit comments

Comments
 (0)