Skip to content

Commit e0b1aa9

Browse files
committed
fixed
2 parents 0d635b3 + 1cebc58 commit e0b1aa9

File tree

10 files changed

+256
-120
lines changed

10 files changed

+256
-120
lines changed

jupiter_examples/nm_sigma_estimation_comparison.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@
250250
" \"\"\"\n",
251251
" generator = NMGenerator()\n",
252252
" mixture = NormalMeanMixtures(\"canonical\", sigma=real_sigma, distribution=distribution)\n",
253-
" return generator.canonical_generate(mixture, sample_len)\n",
253+
" return generator.generate(mixture, sample_len)\n",
254254
"\n",
255255
"def estimate_sigma_eigenvalue_based(sample, real_sigma, search_area, a, b):\n",
256256
" sample_len = len(sample)\n",

src/generators/nm_generator.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
class NMGenerator(AbstractGenerator):
1010

1111
@staticmethod
12-
def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
12+
def generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
1313
"""Generate a sample of given size. Classical form of NMM
1414
1515
Args:
@@ -27,25 +27,6 @@ def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
2727
raise ValueError("Mixture must be NormalMeanMixtures")
2828
mixing_values = mixture.params.distribution.rvs(size=size)
2929
normal_values = scipy.stats.norm.rvs(size=size)
30+
if mixture.mixture_form == "canonical":
31+
return mixing_values + mixture.params.sigma * normal_values
3032
return mixture.params.alpha + mixture.params.beta * mixing_values + mixture.params.gamma * normal_values
31-
32-
@staticmethod
33-
def canonical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
34-
"""Generate a sample of given size. Canonical form of NMM
35-
36-
Args:
37-
mixture: Normal Mean Mixture
38-
size: length of sample
39-
40-
Returns: sample of given size
41-
42-
Raises:
43-
ValueError: If mixture is not a Normal Mean Mixture
44-
45-
"""
46-
47-
if not isinstance(mixture, NormalMeanMixtures):
48-
raise ValueError("Mixture must be NormalMeanMixtures")
49-
mixing_values = mixture.params.distribution.rvs(size=size)
50-
normal_values = scipy.stats.norm.rvs(size=size)
51-
return mixing_values + mixture.params.sigma * normal_values

src/generators/nmv_generator.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
class NMVGenerator(AbstractGenerator):
1010

1111
@staticmethod
12-
def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
12+
def generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
1313
"""Generate a sample of given size. Classical form of NMVM
1414
1515
Args:
@@ -27,29 +27,10 @@ def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
2727
raise ValueError("Mixture must be NormalMeanMixtures")
2828
mixing_values = mixture.params.distribution.rvs(size=size)
2929
normal_values = scipy.stats.norm.rvs(size=size)
30+
if mixture.mixture_form == "canonical":
31+
return mixture.params.alpha + mixture.params.mu * mixing_values + (mixing_values ** 0.5) * normal_values
3032
return (
3133
mixture.params.alpha
3234
+ mixture.params.beta * mixing_values
3335
+ mixture.params.gamma * (mixing_values**0.5) * normal_values
34-
)
35-
36-
@staticmethod
37-
def canonical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
38-
"""Generate a sample of given size. Canonical form of NMVM
39-
40-
Args:
41-
mixture: Normal Mean Variance Mixtures
42-
size: length of sample
43-
44-
Returns: sample of given size
45-
46-
Raises:
47-
ValueError: If mixture type is not Normal Mean Variance Mixtures
48-
49-
"""
50-
51-
if not isinstance(mixture, NormalMeanVarianceMixtures):
52-
raise ValueError("Mixture must be NormalMeanMixtures")
53-
mixing_values = mixture.params.distribution.rvs(size=size)
54-
normal_values = scipy.stats.norm.rvs(size=size)
55-
return mixture.params.alpha + mixture.params.mu * mixing_values + (mixing_values**0.5) * normal_values
36+
)

src/generators/nv_generator.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
class NVGenerator(AbstractGenerator):
1010

1111
@staticmethod
12-
def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
12+
def generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
1313
"""Generate a sample of given size. Classical form of NVM
1414
1515
Args:
@@ -27,25 +27,6 @@ def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
2727
raise ValueError("Mixture must be NormalMeanMixtures")
2828
mixing_values = mixture.params.distribution.rvs(size=size)
2929
normal_values = scipy.stats.norm.rvs(size=size)
30-
return mixture.params.alpha + mixture.params.gamma * (mixing_values**0.5) * normal_values
31-
32-
@staticmethod
33-
def canonical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
34-
"""Generate a sample of given size. Canonical form of NVM
35-
36-
Args:
37-
mixture: Normal Variance Mixtures
38-
size: length of sample
39-
40-
Returns: sample of given size
41-
42-
Raises:
43-
ValueError: If mixture type is not Normal Variance Mixtures
44-
45-
"""
46-
47-
if not isinstance(mixture, NormalVarianceMixtures):
48-
raise ValueError("Mixture must be NormalMeanMixtures")
49-
mixing_values = mixture.params.distribution.rvs(size=size)
50-
normal_values = scipy.stats.norm.rvs(size=size)
51-
return mixture.params.alpha + (mixing_values**0.5) * normal_values
30+
if mixture.mixture_form == "canonical":
31+
return mixture.params.alpha + (mixing_values ** 0.5) * normal_values
32+
return mixture.params.alpha + mixture.params.gamma * (mixing_values**0.5) * normal_values

src/mixtures/abstract_mixture.py

Lines changed: 95 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,39 @@
11
from abc import ABCMeta, abstractmethod
22
from dataclasses import fields
3-
from typing import Any
3+
from typing import Any, List, Tuple, Union, Dict, Type
4+
import numpy as np
5+
from numpy.typing import NDArray
46

57
from scipy.stats import rv_continuous
68
from scipy.stats.distributions import rv_frozen
79

10+
from src.algorithms.support_algorithms.integrator import Integrator
11+
from src.algorithms.support_algorithms.rqmc import RQMCIntegrator # default integrator
812

913
class AbstractMixtures(metaclass=ABCMeta):
1014
"""Base class for Mixtures"""
1115

1216
_classical_collector: Any
1317
_canonical_collector: Any
1418

15-
@abstractmethod
16-
def __init__(self, mixture_form: str, **kwargs: Any) -> None:
19+
def __init__(
20+
self,
21+
mixture_form: str,
22+
integrator_cls: Type[Integrator] = RQMCIntegrator,
23+
integrator_params: Dict[str, Any] = None,
24+
**kwargs: Any
25+
) -> None:
1726
"""
18-
1927
Args:
20-
mixture_form: Form of Mixture classical or Canonical
21-
**kwargs: Parameters of Mixture
28+
mixture_form: Form of Mixture classical or canonical
29+
integrator_cls: Class implementing Integrator protocol (default: RQMCIntegrator)
30+
integrator_params: Parameters for integrator constructor (default: {{}})
31+
**kwargs: Parameters of Mixture (alpha, gamma, etc.)
2232
"""
33+
self.mixture_form = mixture_form
34+
self.integrator_cls = integrator_cls
35+
self.integrator_params = integrator_params or {}
36+
2337
if mixture_form == "classical":
2438
self.params = self._params_validation(self._classical_collector, kwargs)
2539
elif mixture_form == "canonical":
@@ -28,40 +42,88 @@ def __init__(self, mixture_form: str, **kwargs: Any) -> None:
2842
raise AssertionError(f"Unknown mixture form: {mixture_form}")
2943

3044
@abstractmethod
31-
def compute_moment(self, n: int, params: dict) -> tuple[float, float]: ...
45+
def _compute_moment(self, n: int) -> Tuple[float, float]:
46+
...
47+
48+
def compute_moment(
49+
self,
50+
x: Union[List[int], int, NDArray[np.float64]]
51+
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
52+
if isinstance(x, np.ndarray):
53+
return np.array([self._compute_moment(xp) for xp in x], dtype=object)
54+
elif isinstance(x, list):
55+
return [self._compute_moment(xp) for xp in x]
56+
elif isinstance(x, int):
57+
return self._compute_moment(x)
58+
else:
59+
raise TypeError(f"Unsupported type for x: {type(x)}")
3260

3361
@abstractmethod
34-
def compute_cdf(self, x: float, params: dict) -> tuple[float, float]: ...
62+
def _compute_pdf(self, x: float) -> Tuple[float, float]:
63+
...
64+
65+
def compute_pdf(
66+
self,
67+
x: Union[List[float], float, NDArray[np.float64]]
68+
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
69+
if isinstance(x, np.ndarray):
70+
return np.array([self._compute_pdf(xp) for xp in x], dtype=object)
71+
elif isinstance(x, list):
72+
return [self._compute_pdf(xp) for xp in x]
73+
elif isinstance(x, float):
74+
return self._compute_pdf(x)
75+
else:
76+
raise TypeError(f"Unsupported type for x: {type(x)}")
3577

3678
@abstractmethod
37-
def compute_pdf(self, x: float, params: dict) -> tuple[float, float]: ...
79+
def _compute_logpdf(self, x: float) -> Tuple[float, float]:
80+
...
81+
82+
def compute_logpdf(
83+
self,
84+
x: Union[List[float], float, NDArray[np.float64]]
85+
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
86+
if isinstance(x, np.ndarray):
87+
return np.array([self._compute_logpdf(xp) for xp in x], dtype=object)
88+
elif isinstance(x, list):
89+
return [self._compute_logpdf(xp) for xp in x]
90+
elif isinstance(x, float):
91+
return self._compute_logpdf(x)
92+
else:
93+
raise TypeError(f"Unsupported type for x: {type(x)}")
3894

3995
@abstractmethod
40-
def compute_logpdf(self, x: float, params: dict) -> tuple[float, float]: ...
41-
42-
def _params_validation(self, data_collector: Any, params: dict[str, float | rv_continuous | rv_frozen]) -> Any:
43-
"""Mixture Parameters Validation
44-
45-
Args:
46-
data_collector: Dataclass that collect parameters of Mixture
47-
params: Input parameters
48-
49-
Returns: Instance of dataclass
50-
51-
Raises:
52-
ValueError: If given parameters is unexpected
53-
ValueError: If parameter type is invalid
54-
ValueError: If parameters age not given
55-
56-
"""
57-
96+
def _compute_cdf(self, x: float) -> Tuple[float, float]:
97+
...
98+
99+
def compute_cdf(
100+
self,
101+
x: Union[List[float], float, NDArray[np.float64]]
102+
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
103+
if isinstance(x, np.ndarray):
104+
return np.array([self._compute_cdf(xp) for xp in x], dtype=object)
105+
elif isinstance(x, list):
106+
return [self._compute_cdf(xp) for xp in x]
107+
elif isinstance(x, float):
108+
return self._compute_cdf(x)
109+
else:
110+
raise TypeError(f"Unsupported type for x: {type(x)}")
111+
112+
def _params_validation(
113+
self,
114+
data_collector: Any,
115+
params: dict[str, float | rv_continuous | rv_frozen]
116+
) -> Any:
117+
"""Mixture Parameters Validation"""
58118
dataclass_fields = fields(data_collector)
59119
if len(params) != len(dataclass_fields):
60120
raise ValueError(f"Expected {len(dataclass_fields)} arguments, got {len(params)}")
61-
names_and_types = dict((field.name, field.type) for field in dataclass_fields)
62-
for pair in params.items():
63-
if pair[0] not in names_and_types:
64-
raise ValueError(f"Unexpected parameter {pair[0]}")
65-
if not isinstance(pair[1], names_and_types[pair[0]]):
66-
raise ValueError(f"Type missmatch: {pair[0]} should be {names_and_types[pair[0]]}, not {type(pair[1])}")
121+
names_and_types = {field.name: field.type for field in dataclass_fields}
122+
for name, value in params.items():
123+
if name not in names_and_types:
124+
raise ValueError(f"Unexpected parameter {name}")
125+
if not isinstance(value, names_and_types[name]):
126+
raise ValueError(
127+
f"Type mismatch: {name} should be {names_and_types[name]}, not {type(value)}"
128+
)
67129
return data_collector(**params)

src/procedures/semiparametric/nvm_semiparametric/g_estimation_given_mu_rqmc_based.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def v_sequence_default_value(n: float) -> float:
3232
INTEGRATION_LIMIT_DEFAULT_VALUE: int = 50
3333

3434

35-
class NMVEstimationDensityInvMTquadRQMCBased:
35+
class SemiParametricGEstimationGivenMuRQMCBased:
3636
"""Estimation of mixing density function g (xi density function) of NVM mixture represented in canonical form Y =
3737
alpha + mu*xi + sqrt(xi)*N, where alpha = 0 and mu is given.
3838
@@ -52,13 +52,13 @@ class ParamsAnnotation(TypedDict, total=False):
5252
integration_tolerance: float
5353
integration_limit: int
5454

55-
def __init__(self, sample: Optional[np.ndarray] = None, **kwargs: Unpack[ParamsAnnotation]):
55+
def __init__(self, sample: Optional[_typing.NDArray[np.float64]] = None, **kwargs: Unpack[ParamsAnnotation]):
5656
self.x_powers: Dict[float, np.ndarray] = {}
5757
self.second_u_integrals: np.ndarray
5858
self.first_u_integrals: np.ndarray
5959
self.gamma_grid: np.ndarray
6060
self.v_grid: np.ndarray
61-
self.sample: np.ndarray = np.array([]) if sample is None else sample
61+
self.sample: _typing.NDArray[np.float64] = np.array([]) if sample is None else sample
6262
self.n: int = len(self.sample)
6363
(
6464
self.mu,
@@ -171,6 +171,6 @@ def compute_integrals_for_x(self, x: float) -> float:
171171
total = (first_integral + second_integral) / self.denominator
172172
return max(0.0, total.real)
173173

174-
def compute(self, sample: np.ndarray) -> EstimateResult:
174+
def algorithm(self, sample: np._typing.NDArray) -> EstimateResult:
175175
y_data = [self.compute_integrals_for_x(x) for x in self.x_data]
176176
return EstimateResult(list_value=y_data, success=True)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from dataclasses import dataclass
2+
from typing import Any, Protocol, Callable, Optional
3+
4+
5+
@dataclass
6+
class IntegrationResult:
7+
value: float
8+
error: float
9+
message: Optional[dict[str, Any]] | None = None
10+
11+
12+
class Integrator(Protocol):
13+
14+
"""Base class for integral calculation"""
15+
16+
def __init__(self) -> None:
17+
...
18+
19+
def compute(self, func: Callable) -> IntegrationResult:
20+
...

0 commit comments

Comments
 (0)