11from dataclasses import dataclass
2+ from functools import lru_cache
23from typing import Any , Type , Dict , Tuple
34
45import numpy as np
78from scipy .stats .distributions import rv_frozen
89
910from src .algorithms .support_algorithms .integrator import Integrator
11+ from src .algorithms .support_algorithms .rqmc import RQMCIntegrator
12+ from src .algorithms .support_algorithms .log_rqmc import LogRQMC
1013from src .mixtures .abstract_mixture import AbstractMixtures
1114
1215@dataclass
@@ -29,35 +32,37 @@ class NormalMeanVarianceMixtures(AbstractMixtures):
2932 def __init__ (
3033 self ,
3134 mixture_form : str ,
32- integrator_cls : Type [Integrator ],
33- integrator_params : Dict [str , Any ] | None = None ,
35+ integrator_cls : Type [Integrator ] = RQMCIntegrator ,
36+ integrator_params : Dict [str , Any ] = None ,
3437 ** kwargs : Any
3538 ) -> None :
3639 super ().__init__ (mixture_form , integrator_cls = integrator_cls , integrator_params = integrator_params , ** kwargs )
3740
3841 def _compute_moment (self , n : int ) -> Tuple [float , float ]:
42+ gamma = getattr (self .params , 'gamma' , None )
43+
3944 def integrand (u : float ) -> float :
4045 s = 0.0
4146 for k in range (n + 1 ):
4247 for l in range (k + 1 ):
43- if self .mixture_form == "classical" :
44- coef = binom (n , n - k ) * binom (k , k - l )
48+ if self .mixture_form == 'classical' :
4549 term = (
46- (self .params .beta ** (k - l ))
50+ binom (n , n - k )
51+ * binom (k , k - l )
52+ * (self .params .beta ** (k - l ))
4753 * (self .params .gamma ** l )
4854 * (self .params .distribution .ppf (u ) ** (k - l / 2 ))
49- * (self .params .alpha ** (n - k ))
5055 * norm .moment (l )
5156 )
5257 else :
53- coef = binom (n , n - k ) * binom (k , k - l )
5458 term = (
55- (self .params .nu ** (k - l ))
59+ binom (n , n - k )
60+ * binom (k , k - l )
61+ * (self .params .mu ** (k - l ))
5662 * (self .params .distribution .ppf (u ) ** (k - l / 2 ))
57- * (self .params .alpha ** (n - k ))
5863 * norm .moment (l )
5964 )
60- s += coef * term if self . mixture_form == "classical" else term
65+ s += term
6166 return s
6267
6368 res = self .integrator_cls (** (self .integrator_params or {})).compute (integrand )
@@ -66,8 +71,8 @@ def integrand(u: float) -> float:
6671 def _compute_cdf (self , x : float ) -> Tuple [float , float ]:
6772 def integrand (u : float ) -> float :
6873 p = self .params .distribution .ppf (u )
69- if self .mixture_form == " classical" :
70- return norm .cdf ((x - self .params .alpha - self . params . beta * p ) / abs ( self .params .gamma ))
74+ if self .mixture_form == ' classical' :
75+ return norm .cdf ((x - self .params .alpha ) / ( np . sqrt ( p ) * self .params .gamma ))
7176 return norm .cdf ((x - self .params .alpha ) / np .sqrt (p ) - self .params .mu * np .sqrt (p ))
7277
7378 res = self .integrator_cls (** (self .integrator_params or {})).compute (integrand )
@@ -76,19 +81,33 @@ def integrand(u: float) -> float:
7681 def _compute_pdf (self , x : float ) -> Tuple [float , float ]:
7782 def integrand (u : float ) -> float :
7883 p = self .params .distribution .ppf (u )
79- if self .mixture_form == "classical" :
80- return (1 / abs (self .params .gamma )) * norm .pdf ((x - self .params .alpha - self .params .beta * p )/ abs (self .params .gamma ))
81- return (1 / abs (self .params .mu )) * norm .pdf ((x - p )/ abs (self .params .mu ))
84+ if self .mixture_form == 'classical' :
85+ return (
86+ 1 / np .sqrt (2 * np .pi * p * self .params .gamma ** 2 )
87+ * np .exp (- ((x - self .params .alpha ) ** 2 + self .params .beta ** 2 * p ** 2 ) / (2 * p * self .params .gamma ** 2 ))
88+ )
89+ return (
90+ 1 / np .sqrt (2 * np .pi * p )
91+ * np .exp (- ((x - self .params .alpha ) ** 2 + self .params .mu ** 2 * p ** 2 ) / (2 * p ))
92+ )
8293
8394 res = self .integrator_cls (** (self .integrator_params or {})).compute (integrand )
84- return res .value , res .error
95+ if self .mixture_form == 'classical' :
96+ val = np .exp (self .params .beta * (x - self .params .alpha ) / self .params .gamma ** 2 ) * res .value
97+ else :
98+ val = np .exp (self .params .mu * (x - self .params .alpha )) * res .value
99+ return val , res .error
85100
86101 def _compute_logpdf (self , x : float ) -> Tuple [float , float ]:
87102 def integrand (u : float ) -> float :
88103 p = self .params .distribution .ppf (u )
89- if self .mixture_form == " classical" :
90- return np . log ( 1 / abs ( self .params .gamma )) + norm . logpdf (( x - self .params .alpha - self .params .beta * p ) / abs ( self .params .gamma ) )
91- return np . log ( 1 / abs ( self .params .mu )) + norm . logpdf (( x - p ) / abs ( self .params .mu ) )
104+ if self .mixture_form == ' classical' :
105+ return - (( x - self .params .alpha ) ** 2 + p ** 2 * self .params .beta ** 2 + p * self .params .gamma ** 2 * np . log ( 2 * np . pi * p * self . params . gamma ** 2 )) / ( 2 * p * self .params .gamma ** 2 )
106+ return - (( x - self .params .alpha ) ** 2 + p ** 2 * self .params .mu ** 2 + p * np . log ( 2 * np . pi * p )) / ( 2 * p )
92107
93108 res = self .integrator_cls (** (self .integrator_params or {})).compute (integrand )
94- return res .value , res .error
109+ if self .mixture_form == 'classical' :
110+ val = self .params .beta * (x - self .params .alpha ) / self .params .gamma ** 2 + res .value
111+ else :
112+ val = self .params .mu * (x - self .params .alpha ) + res .value
113+ return val , res .error
0 commit comments