Skip to content

Commit 8ec30a0

Browse files
authored
Merge pull request #38 from PySATL/feat/unify-mixture-generators
feat: unify mixture generators and support canonical forms
2 parents 476758e + 87bd31a commit 8ec30a0

File tree

6 files changed

+23
-79
lines changed

6 files changed

+23
-79
lines changed

jupiter_examples/nm_sigma_estimation_comparison.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@
250250
" \"\"\"\n",
251251
" generator = NMGenerator()\n",
252252
" mixture = NormalMeanMixtures(\"canonical\", sigma=real_sigma, distribution=distribution)\n",
253-
" return generator.canonical_generate(mixture, sample_len)\n",
253+
" return generator.generate(mixture, sample_len)\n",
254254
"\n",
255255
"def estimate_sigma_eigenvalue_based(sample, real_sigma, search_area, a, b):\n",
256256
" sample_len = len(sample)\n",

src/generators/nm_generator.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
class NMGenerator(AbstractGenerator):
1010

1111
@staticmethod
12-
def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
12+
def generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
1313
"""Generate a sample of given size. Classical form of NMM
1414
1515
Args:
@@ -27,25 +27,6 @@ def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
2727
raise ValueError("Mixture must be NormalMeanMixtures")
2828
mixing_values = mixture.params.distribution.rvs(size=size)
2929
normal_values = scipy.stats.norm.rvs(size=size)
30+
if mixture.mixture_form == "canonical":
31+
return mixing_values + mixture.params.sigma * normal_values
3032
return mixture.params.alpha + mixture.params.beta * mixing_values + mixture.params.gamma * normal_values
31-
32-
@staticmethod
33-
def canonical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
34-
"""Generate a sample of given size. Canonical form of NMM
35-
36-
Args:
37-
mixture: Normal Mean Mixture
38-
size: length of sample
39-
40-
Returns: sample of given size
41-
42-
Raises:
43-
ValueError: If mixture is not a Normal Mean Mixture
44-
45-
"""
46-
47-
if not isinstance(mixture, NormalMeanMixtures):
48-
raise ValueError("Mixture must be NormalMeanMixtures")
49-
mixing_values = mixture.params.distribution.rvs(size=size)
50-
normal_values = scipy.stats.norm.rvs(size=size)
51-
return mixing_values + mixture.params.sigma * normal_values

src/generators/nmv_generator.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
class NMVGenerator(AbstractGenerator):
1010

1111
@staticmethod
12-
def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
12+
def generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
1313
"""Generate a sample of given size. Classical form of NMVM
1414
1515
Args:
@@ -27,29 +27,10 @@ def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
2727
raise ValueError("Mixture must be NormalMeanMixtures")
2828
mixing_values = mixture.params.distribution.rvs(size=size)
2929
normal_values = scipy.stats.norm.rvs(size=size)
30+
if mixture.mixture_form == "canonical":
31+
return mixture.params.alpha + mixture.params.mu * mixing_values + (mixing_values ** 0.5) * normal_values
3032
return (
3133
mixture.params.alpha
3234
+ mixture.params.beta * mixing_values
3335
+ mixture.params.gamma * (mixing_values**0.5) * normal_values
34-
)
35-
36-
@staticmethod
37-
def canonical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
38-
"""Generate a sample of given size. Canonical form of NMVM
39-
40-
Args:
41-
mixture: Normal Mean Variance Mixtures
42-
size: length of sample
43-
44-
Returns: sample of given size
45-
46-
Raises:
47-
ValueError: If mixture type is not Normal Mean Variance Mixtures
48-
49-
"""
50-
51-
if not isinstance(mixture, NormalMeanVarianceMixtures):
52-
raise ValueError("Mixture must be NormalMeanMixtures")
53-
mixing_values = mixture.params.distribution.rvs(size=size)
54-
normal_values = scipy.stats.norm.rvs(size=size)
55-
return mixture.params.alpha + mixture.params.mu * mixing_values + (mixing_values**0.5) * normal_values
36+
)

src/generators/nv_generator.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
class NVGenerator(AbstractGenerator):
1010

1111
@staticmethod
12-
def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
12+
def generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
1313
"""Generate a sample of given size. Classical form of NVM
1414
1515
Args:
@@ -27,25 +27,6 @@ def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
2727
raise ValueError("Mixture must be NormalMeanMixtures")
2828
mixing_values = mixture.params.distribution.rvs(size=size)
2929
normal_values = scipy.stats.norm.rvs(size=size)
30-
return mixture.params.alpha + mixture.params.gamma * (mixing_values**0.5) * normal_values
31-
32-
@staticmethod
33-
def canonical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
34-
"""Generate a sample of given size. Canonical form of NVM
35-
36-
Args:
37-
mixture: Normal Variance Mixtures
38-
size: length of sample
39-
40-
Returns: sample of given size
41-
42-
Raises:
43-
ValueError: If mixture type is not Normal Variance Mixtures
44-
45-
"""
46-
47-
if not isinstance(mixture, NormalVarianceMixtures):
48-
raise ValueError("Mixture must be NormalMeanMixtures")
49-
mixing_values = mixture.params.distribution.rvs(size=size)
50-
normal_values = scipy.stats.norm.rvs(size=size)
51-
return mixture.params.alpha + (mixing_values**0.5) * normal_values
30+
if mixture.mixture_form == "canonical":
31+
return mixture.params.alpha + (mixing_values ** 0.5) * normal_values
32+
return mixture.params.alpha + mixture.params.gamma * (mixing_values**0.5) * normal_values

src/mixtures/abstract_mixture.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def __init__(self, mixture_form: str, **kwargs: Any) -> None:
2121
mixture_form: Form of Mixture classical or Canonical
2222
**kwargs: Parameters of Mixture
2323
"""
24+
self.mixture_form = mixture_form
2425
if mixture_form == "classical":
2526
self.params = self._params_validation(self._classical_collector, kwargs)
2627
elif mixture_form == "canonical":

tests/generators/nm_generator/test_mixing_normal.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,38 +16,38 @@ class TestMixingNormal:
1616
)
1717
def test_classic_generate_variance_0(self, mixing_variance: float, expected_variance: float) -> None:
1818
mixture = NormalMeanMixtures("classical", alpha=0, beta=mixing_variance**0.5, gamma=1, distribution=norm)
19-
sample = self.generator.classical_generate(mixture, self.test_mixture_size)
19+
sample = self.generator.generate(mixture, self.test_mixture_size)
2020
actual_variance = ndimage.variance(sample)
2121
assert actual_variance == pytest.approx(expected_variance, 0.1)
2222

2323
@pytest.mark.parametrize("beta", np.random.uniform(0, 100, size=50))
2424
def test_classic_generate_variance_1(self, beta: float) -> None:
2525
expected_variance = beta**2 + 1
2626
mixture = NormalMeanMixtures("classical", alpha=0, beta=beta, gamma=1, distribution=norm)
27-
sample = self.generator.classical_generate(mixture, self.test_mixture_size)
27+
sample = self.generator.generate(mixture, self.test_mixture_size)
2828
actual_variance = ndimage.variance(sample)
2929
assert actual_variance == pytest.approx(expected_variance, 0.1)
3030

3131
@pytest.mark.parametrize("beta, gamma", np.random.uniform(0, 100, size=(50, 2)))
3232
def test_classic_generate_variance_2(self, beta: float, gamma: float) -> None:
3333
expected_variance = beta**2 + gamma**2
3434
mixture = NormalMeanMixtures("classical", alpha=0, beta=beta, gamma=gamma, distribution=norm)
35-
sample = self.generator.classical_generate(mixture, self.test_mixture_size)
35+
sample = self.generator.generate(mixture, self.test_mixture_size)
3636
actual_variance = ndimage.variance(sample)
3737
assert actual_variance == pytest.approx(expected_variance, 0.1)
3838

3939
@pytest.mark.parametrize("beta, gamma", np.random.uniform(0, 10, size=(50, 2)))
4040
def test_classic_generate_mean(self, beta: float, gamma: float) -> None:
4141
expected_mean = 0
4242
mixture = NormalMeanMixtures("classical", alpha=0, beta=beta, gamma=gamma, distribution=norm)
43-
sample = self.generator.classical_generate(mixture, self.test_mixture_size)
43+
sample = self.generator.generate(mixture, self.test_mixture_size)
4444
actual_mean = np.mean(np.array(sample))
4545
assert abs(actual_mean - expected_mean) < 1
4646

4747
@pytest.mark.parametrize("expected_size", np.random.randint(0, 100, size=50))
4848
def test_classic_generate_size(self, expected_size: int) -> None:
4949
mixture = NormalMeanMixtures("classical", alpha=0, beta=1, gamma=1, distribution=norm)
50-
sample = self.generator.classical_generate(mixture, expected_size)
50+
sample = self.generator.generate(mixture, expected_size)
5151
actual_size = np.size(sample)
5252
assert actual_size == expected_size
5353

@@ -56,37 +56,37 @@ def test_classic_generate_size(self, expected_size: int) -> None:
5656
)
5757
def test_canonical_generate_variance_0(self, mixing_variance: float, expected_variance: float) -> None:
5858
mixture = NormalMeanMixtures("canonical", sigma=1, distribution=norm(0, mixing_variance**0.5))
59-
sample = self.generator.canonical_generate(mixture, self.test_mixture_size)
59+
sample = self.generator.generate(mixture, self.test_mixture_size)
6060
actual_variance = ndimage.variance(sample)
6161
assert actual_variance == pytest.approx(expected_variance, 0.1)
6262

6363
@pytest.mark.parametrize("sigma", np.random.uniform(0, 100, size=50))
6464
def test_canonical_generate_variance_1(self, sigma: float) -> None:
6565
expected_variance = sigma**2 + 1
6666
mixture = NormalMeanMixtures("canonical", sigma=sigma, distribution=norm)
67-
sample = self.generator.canonical_generate(mixture, self.test_mixture_size)
67+
sample = self.generator.generate(mixture, self.test_mixture_size)
6868
actual_variance = ndimage.variance(sample)
6969
assert actual_variance == pytest.approx(expected_variance, 0.1)
7070

7171
@pytest.mark.parametrize("mixing_variance, sigma", np.random.uniform(0, 100, size=(50, 2)))
7272
def test_canonical_generate_variance_2(self, mixing_variance: float, sigma: float) -> None:
7373
expected_variance = mixing_variance + sigma**2
7474
mixture = NormalMeanMixtures("canonical", sigma=sigma, distribution=norm(0, mixing_variance**0.5))
75-
sample = self.generator.canonical_generate(mixture, self.test_mixture_size)
75+
sample = self.generator.generate(mixture, self.test_mixture_size)
7676
actual_variance = ndimage.variance(sample)
7777
assert actual_variance == pytest.approx(expected_variance, 0.1)
7878

7979
@pytest.mark.parametrize("sigma", np.random.uniform(0, 10, size=50))
8080
def test_canonical_generate_mean(self, sigma: float) -> None:
8181
expected_mean = 0
8282
mixture = NormalMeanMixtures("canonical", sigma=sigma, distribution=norm)
83-
sample = self.generator.canonical_generate(mixture, self.test_mixture_size)
83+
sample = self.generator.generate(mixture, self.test_mixture_size)
8484
actual_mean = np.mean(np.array(sample))
8585
assert abs(actual_mean - expected_mean) < 1
8686

8787
@pytest.mark.parametrize("expected_size", [*np.random.randint(0, 100, size=50), 0, 1, 1000000])
8888
def test_canonical_generate_size(self, expected_size: int) -> None:
8989
mixture = NormalMeanMixtures("canonical", sigma=1, distribution=norm)
90-
sample = self.generator.canonical_generate(mixture, expected_size)
90+
sample = self.generator.generate(mixture, expected_size)
9191
actual_size = np.size(sample)
9292
assert actual_size == expected_size

0 commit comments

Comments
 (0)