Skip to content

Commit 1b5138d

Browse files
committed
fixed
1 parent 12b3205 commit 1b5138d

File tree

1 file changed

+30
-30
lines changed

1 file changed

+30
-30
lines changed

src/mixtures/nm_mixture.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,18 @@
11
from dataclasses import dataclass
2-
from typing import Any, Type, Dict, Tuple, Union, List
2+
from typing import Any, Type, Dict, Tuple
33

44
import numpy as np
55
from scipy.special import binom
66
from scipy.stats import norm, rv_continuous
77
from scipy.stats.distributions import rv_frozen
88

99
from src.algorithms.support_algorithms.integrator import Integrator
10-
from src.algorithms.support_algorithms.quad_integrator import QuadIntegrator
11-
from src.algorithms.support_algorithms.rqmc import RQMCIntegrator
12-
from src.algorithms.support_algorithms.log_rqmc import LogRQMC
1310
from src.mixtures.abstract_mixture import AbstractMixtures
1411

1512
@dataclass
1613
class _NMMClassicDataCollector:
1714
alpha: float | int | np.int64
18-
beta: float | int | np.int64
15+
beta: float | int | np.int64
1916
gamma: float | int | np.int64
2017
distribution: rv_frozen | rv_continuous
2118

@@ -31,7 +28,7 @@ class NormalMeanMixtures(AbstractMixtures):
3128
def __init__(
3229
self,
3330
mixture_form: str,
34-
integrator_cls: Type[Integrator] = RQMCIntegrator,
31+
integrator_cls: Type[Integrator] = Integrator,
3532
integrator_params: Dict[str, Any] = None,
3633
**kwargs: Any
3734
) -> None:
@@ -51,53 +48,56 @@ def _compute_moment(self, n: int) -> Tuple[float, float]:
5148
if self.mixture_form == "classical":
5249
for k in range(n + 1):
5350
for l in range(k + 1):
54-
coeff = binom(n, n - k) * binom(k, k - l) * (self.params.beta ** (k - l)) * (self.params.gamma ** l)
51+
coeff = binom(n, n - k) * binom(k, k - l)
5552
def mix(u: float) -> float:
56-
return self.params.distribution.ppf(u) ** (k - l)
57-
integrator = self.integrator_cls(**(self.integrator_params or {}))
58-
res = integrator.compute(mix)
59-
mixture_moment += coeff * (self.params.alpha ** (n - k)) * res.value * norm.moment(l)
60-
error += coeff * res.error * (self.params.alpha ** (n - k)) * norm.moment(l)
53+
return (
54+
self.params.distribution.ppf(u) ** (k - l)
55+
)
56+
res = self.integrator_cls(**(self.integrator_params or {})).compute(mix)
57+
mixture_moment += coeff * (self.params.beta ** (k - l)) * (self.params.gamma ** l) * (self.params.alpha ** (n - k)) * res.value * norm.moment(l)
58+
error += coeff * (self.params.beta ** (k - l)) * (self.params.gamma ** l) * (self.params.alpha ** (n - k)) * res.error * norm.moment(l)
6159
else:
6260
for k in range(n + 1):
63-
coeff = binom(n, n - k) * (self.params.sigma ** k)
61+
coeff = binom(n, n - k)
6462
def mix(u: float) -> float:
6563
return self.params.distribution.ppf(u) ** (n - k)
66-
integrator = self.integrator_cls(**(self.integrator_params or {}))
67-
res = integrator.compute(mix)
68-
mixture_moment += coeff * res.value * norm.moment(k)
69-
error += coeff * res.error * norm.moment(k)
64+
res = self.integrator_cls(**(self.integrator_params or {})).compute(mix)
65+
mixture_moment += coeff * (self.params.sigma ** k) * res.value * norm.moment(k)
66+
error += coeff * (self.params.sigma ** k) * res.error * norm.moment(k)
7067
return mixture_moment, error
7168

72-
def _compute_cdf(self, x: float, params: Dict[str, Any]) -> Tuple[float, float]:
69+
def _compute_cdf(self, x: float) -> Tuple[float, float]:
7370
if self.mixture_form == "classical":
7471
def fn(u: float) -> float:
75-
return norm.cdf((x - self.params.alpha - self.params.beta * self.params.distribution.ppf(u)) / abs(self.params.gamma))
72+
p = self.params.distribution.ppf(u)
73+
return norm.cdf((x - self.params.alpha - self.params.beta * p) / abs(self.params.gamma))
7674
else:
7775
def fn(u: float) -> float:
78-
return norm.cdf((x - self.params.distribution.ppf(u)) / abs(self.params.sigma))
79-
integrator = self.integrator_cls(**(self.integrator_params or {}))
80-
res = integrator.compute(fn)
76+
p = self.params.distribution.ppf(u)
77+
return norm.cdf((x - p) / abs(self.params.sigma))
78+
res = self.integrator_cls(**(self.integrator_params or {})).compute(fn)
8179
return res.value, res.error
8280

8381
def _compute_pdf(self, x: float) -> Tuple[float, float]:
8482
if self.mixture_form == "classical":
8583
def fn(u: float) -> float:
86-
return (1 / abs(self.params.gamma)) * norm.pdf((x - self.params.alpha - self.params.beta * self.params.distribution.ppf(u)) / abs(self.params.gamma))
84+
p = self.params.distribution.ppf(u)
85+
return (1 / abs(self.params.gamma)) * norm.pdf((x - self.params.alpha - self.params.beta * p) / abs(self.params.gamma))
8786
else:
8887
def fn(u: float) -> float:
89-
return (1 / abs(self.params.sigma)) * norm.pdf((x - self.params.distribution.ppf(u)) / abs(self.params.sigma))
90-
integrator = self.integrator_cls(**(self.integrator_params or {}))
91-
res = integrator.compute(fn)
88+
p = self.params.distribution.ppf(u)
89+
return (1 / abs(self.params.sigma)) * norm.pdf((x - p) / abs(self.params.sigma))
90+
res = self.integrator_cls(**(self.integrator_params or {})).compute(fn)
9291
return res.value, res.error
9392

9493
def _compute_logpdf(self, x: float) -> Tuple[float, float]:
9594
if self.mixture_form == "classical":
9695
def fn(u: float) -> float:
97-
return np.log(1 / abs(self.params.gamma)) + norm.logpdf((x - self.params.alpha - self.params.beta * self.params.distribution.ppf(u)) / abs(self.params.gamma))
96+
p = self.params.distribution.ppf(u)
97+
return np.log(1 / abs(self.params.gamma)) + norm.logpdf((x - self.params.alpha - self.params.beta * p) / abs(self.params.gamma))
9898
else:
9999
def fn(u: float) -> float:
100-
return np.log(1 / abs(self.params.sigma)) + norm.logpdf((x - self.params.distribution.ppf(u)) / abs(self.params.sigma))
101-
integrator = self.integrator_cls(**(self.integrator_params or {}))
102-
res = integrator.compute(fn)
100+
p = self.params.distribution.ppf(u)
101+
return np.log(1 / abs(self.params.sigma)) + norm.logpdf((x - p) / abs(self.params.sigma))
102+
res = self.integrator_cls(**(self.integrator_params or {})).compute(fn)
103103
return res.value, res.error

0 commit comments

Comments
 (0)