11from dataclasses import dataclass
2- from functools import lru_cache
3- from typing import Any , Type , Dict
2+ from typing import Any , Type , Dict , Tuple
43
54import numpy as np
65from scipy .special import binom
76from scipy .stats import norm , rv_continuous
87from scipy .stats .distributions import rv_frozen
98
109from src .algorithms .support_algorithms .integrator import Integrator
11- from src .algorithms .support_algorithms .rqmc import RQMCIntegrator
12- from src .algorithms .support_algorithms .log_rqmc import LogRQMC
13- from src .algorithms .support_algorithms .quad_integrator import QuadIntegrator
1410from src .mixtures .abstract_mixture import AbstractMixtures
1511
1612@dataclass
1713class _NMVMClassicDataCollector :
1814 alpha : float | int | np .int64
19- beta : float | int | np .int64
15+ beta : float | int | np .int64
2016 gamma : float | int | np .int64
2117 distribution : rv_frozen | rv_continuous
2218
2319@dataclass
2420class _NMVMCanonicalDataCollector :
2521 alpha : float | int | np .int64
26- mu : float | int | np .int64
22+ mu : float | int | np .int64
2723 distribution : rv_frozen | rv_continuous
2824
2925class NormalMeanVarianceMixtures (AbstractMixtures ):
@@ -33,96 +29,66 @@ class NormalMeanVarianceMixtures(AbstractMixtures):
3329 def __init__ (
3430 self ,
3531 mixture_form : str ,
36- integrator_cls : Type [Integrator ] = RQMCIntegrator ,
37- integrator_params : Dict [str , Any ] = None ,
32+ integrator_cls : Type [Integrator ],
33+ integrator_params : Dict [str , Any ] | None = None ,
3834 ** kwargs : Any
3935 ) -> None :
4036 super ().__init__ (mixture_form , integrator_cls = integrator_cls , integrator_params = integrator_params , ** kwargs )
4137
42- def _compute_moment (self , n : int ) -> tuple [float , float ]:
38+ def _compute_moment (self , n : int ) -> Tuple [float , float ]:
4339 def integrand (u : float ) -> float :
44- result = 0.0
40+ s = 0.0
4541 for k in range (n + 1 ):
4642 for l in range (k + 1 ):
4743 if self .mixture_form == "classical" :
48- result += (
49- binom (n , n - k )
50- * binom (k , k - l )
51- * (self .params .beta ** (k - l ))
44+ coef = binom (n , n - k ) * binom (k , k - l )
45+ term = (
46+ (self .params .beta ** (k - l ))
5247 * (self .params .gamma ** l )
5348 * (self .params .distribution .ppf (u ) ** (k - l / 2 ))
49+ * (self .params .alpha ** (n - k ))
5450 * norm .moment (l )
5551 )
5652 else :
57- result += (
58- binom (n , n - k )
59- * binom (k , k - l )
60- * (self .params .mu ** (k - l ))
53+ coef = binom (n , n - k ) * binom (k , k - l )
54+ term = (
55+ (self .params .nu ** (k - l ))
6156 * (self .params .distribution .ppf (u ) ** (k - l / 2 ))
57+ * (self .params .alpha ** (n - k ))
6258 * norm .moment (l )
6359 )
64- return result
60+ s += coef * term if self .mixture_form == "classical" else term
61+ return s
6562
66- integrator = self .integrator_cls (** (self .integrator_params or {}))
67- result = integrator .compute (integrand )
68- return result .value , result .error
63+ res = self .integrator_cls (** (self .integrator_params or {})).compute (integrand )
64+ return res .value , res .error
6965
70- def _compute_cdf (self , x : float ) -> tuple [float , float ]:
66+ def _compute_cdf (self , x : float ) -> Tuple [float , float ]:
7167 def integrand (u : float ) -> float :
72- ppf = self .params .distribution .ppf (u )
68+ p = self .params .distribution .ppf (u )
7369 if self .mixture_form == "classical" :
74- point = (x - self .params .alpha ) / (np .sqrt (ppf ) * self .params .gamma ) - (self .params .beta / self .params .gamma * np .sqrt (ppf ))
75- else :
76- point = (x - self .params .alpha ) / np .sqrt (ppf ) - (self .params .mu * np .sqrt (ppf ))
77- return norm .cdf (point )
70+ return norm .cdf ((x - self .params .alpha - self .params .beta * p ) / abs (self .params .gamma ))
71+ return norm .cdf ((x - self .params .alpha ) / np .sqrt (p ) - self .params .mu * np .sqrt (p ))
7872
79- integrator = self .integrator_cls (** (self .integrator_params or {}))
80- result = integrator .compute (integrand )
81- return result .value , result .error
73+ res = self .integrator_cls (** (self .integrator_params or {})).compute (integrand )
74+ return res .value , res .error
8275
83- def _compute_pdf (self , x : float ) -> tuple [float , float ]:
76+ def _compute_pdf (self , x : float ) -> Tuple [float , float ]:
8477 def integrand (u : float ) -> float :
85- ppf = self .params .distribution .ppf (u )
78+ p = self .params .distribution .ppf (u )
8679 if self .mixture_form == "classical" :
87- return (
88- 1 / np .sqrt (2 * np .pi * ppf * self .params .gamma ** 2 )
89- * np .exp (- ((x - self .params .alpha ) ** 2 + self .params .beta ** 2 * ppf ** 2 ) / (2 * ppf * self .params .gamma ** 2 ))
90- )
91- else :
92- return (
93- 1 / np .sqrt (2 * np .pi * ppf )
94- * np .exp (- ((x - self .params .alpha ) ** 2 + self .params .mu ** 2 * ppf ** 2 ) / (2 * ppf ))
95- )
80+ return (1 / abs (self .params .gamma )) * norm .pdf ((x - self .params .alpha - self .params .beta * p )/ abs (self .params .gamma ))
81+ return (1 / abs (self .params .mu )) * norm .pdf ((x - p )/ abs (self .params .mu ))
9682
97- integrator = self .integrator_cls (** (self .integrator_params or {}))
98- result = integrator .compute (integrand )
99- if self .mixture_form == "classical" :
100- val = np .exp (self .params .beta * (x - self .params .alpha ) / self .params .gamma ** 2 ) * result .value
101- else :
102- val = np .exp (self .params .mu * (x - self .params .alpha )) * result .value
103- return val , result .error
83+ res = self .integrator_cls (** (self .integrator_params or {})).compute (integrand )
84+ return res .value , res .error
10485
105- def _compute_logpdf (self , x : float ) -> tuple [float , float ]:
86+ def _compute_logpdf (self , x : float ) -> Tuple [float , float ]:
10687 def integrand (u : float ) -> float :
107- ppf = self .params .distribution .ppf (u )
88+ p = self .params .distribution .ppf (u )
10889 if self .mixture_form == "classical" :
109- return - (
110- (x - self .params .alpha ) ** 2
111- + ppf ** 2 * self .params .beta ** 2
112- + ppf * self .params .gamma ** 2 * np .log (2 * np .pi * ppf * self .params .gamma ** 2 )
113- ) / (2 * ppf * self .params .gamma ** 2 )
114- else :
115- return - ((x - self .params .alpha ) ** 2 + ppf ** 2 * self .params .mu ** 2 + ppf * np .log (2 * np .pi * ppf )) / (2 * ppf )
90+ return np .log (1 / abs (self .params .gamma )) + norm .logpdf ((x - self .params .alpha - self .params .beta * p )/ abs (self .params .gamma ))
91+ return np .log (1 / abs (self .params .mu )) + norm .logpdf ((x - p )/ abs (self .params .mu ))
11692
117- integrator = self .integrator_cls (** (self .integrator_params or {}))
118- result = integrator .compute (integrand )
119- if self .mixture_form == "classical" :
120- val = self .params .beta * (x - self .params .alpha ) / self .params .gamma ** 2 + result .value
121- else :
122- val = self .params .mu * (x - self .params .alpha ) + result .value
123- return val , result .error
124-
125- @lru_cache ()
126- def _integrand_func (self , u : float , d : float , gamma : float ) -> float :
127- ppf = self .params .distribution .ppf (u )
128- return (1 / np .sqrt (2 * np .pi * ppf * abs (gamma ) ** 2 )) * np .exp (- d / (2 * ppf ))
93+ res = self .integrator_cls (** (self .integrator_params or {})).compute (integrand )
94+ return res .value , res .error
0 commit comments