From 87bd31afcac4664e971480b3b8ef6a398f7b7219 Mon Sep 17 00:00:00 2001 From: plidan123 Date: Sat, 10 May 2025 15:38:30 +0300 Subject: [PATCH] feat(generator): unify mixture generators and support canonical forms This commit unifies the sample generators for all mixture types (NMM, NMV, NV) by integrating canonical form support directly into the main generator function. Previously, each mixture type had separate classical and canonical generators. Now, the generator checks mixture_form inside the function and switches behavior accordingly. In canonical form, the beta parameter is not used. This simplifies the generator API and reduces code duplication. - Updated nm_generator.py, nmv_generator.py, nv_generator.py - Remembered `mixture_form` in AbstractMixture class - Updated tests and notebook to use the new unified API BREAKING CHANGE: canonical_generate() methods were removed; use generate() instead. --- .../nm_sigma_estimation_comparison.ipynb | 2 +- src/generators/nm_generator.py | 25 +++-------------- src/generators/nmv_generator.py | 27 +++---------------- src/generators/nv_generator.py | 27 +++---------------- src/mixtures/abstract_mixture.py | 1 + .../nm_generator/test_mixing_normal.py | 20 +++++++------- 6 files changed, 23 insertions(+), 79 deletions(-) diff --git a/jupiter_examples/nm_sigma_estimation_comparison.ipynb b/jupiter_examples/nm_sigma_estimation_comparison.ipynb index f986d13..10f695f 100644 --- a/jupiter_examples/nm_sigma_estimation_comparison.ipynb +++ b/jupiter_examples/nm_sigma_estimation_comparison.ipynb @@ -250,7 +250,7 @@ " \"\"\"\n", " generator = NMGenerator()\n", " mixture = NormalMeanMixtures(\"canonical\", sigma=real_sigma, distribution=distribution)\n", - " return generator.canonical_generate(mixture, sample_len)\n", + " return generator.generate(mixture, sample_len)\n", "\n", "def estimate_sigma_eigenvalue_based(sample, real_sigma, search_area, a, b):\n", " sample_len = len(sample)\n", diff --git a/src/generators/nm_generator.py b/src/generators/nm_generator.py index 4c95b0f..5ecfe4a 100644 --- a/src/generators/nm_generator.py +++ b/src/generators/nm_generator.py @@ -9,7 +9,7 @@ class NMGenerator(AbstractGenerator): @staticmethod - def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray: + def generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray: """Generate a sample of given size. Classical form of NMM Args: @@ -27,25 +27,6 @@ def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray: raise ValueError("Mixture must be NormalMeanMixtures") mixing_values = mixture.params.distribution.rvs(size=size) normal_values = scipy.stats.norm.rvs(size=size) + if mixture.mixture_form == "canonical": + return mixing_values + mixture.params.sigma * normal_values return mixture.params.alpha + mixture.params.beta * mixing_values + mixture.params.gamma * normal_values - - @staticmethod - def canonical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray: - """Generate a sample of given size. Canonical form of NMM - - Args: - mixture: Normal Mean Mixture - size: length of sample - - Returns: sample of given size - - Raises: - ValueError: If mixture is not a Normal Mean Mixture - - """ - - if not isinstance(mixture, NormalMeanMixtures): - raise ValueError("Mixture must be NormalMeanMixtures") - mixing_values = mixture.params.distribution.rvs(size=size) - normal_values = scipy.stats.norm.rvs(size=size) - return mixing_values + mixture.params.sigma * normal_values diff --git a/src/generators/nmv_generator.py b/src/generators/nmv_generator.py index 5ff3221..8270a1a 100644 --- a/src/generators/nmv_generator.py +++ b/src/generators/nmv_generator.py @@ -9,7 +9,7 @@ class NMVGenerator(AbstractGenerator): @staticmethod - def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray: + def generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray: """Generate a sample of given size. Classical form of NMVM Args: @@ -27,29 +27,10 @@ def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray: raise ValueError("Mixture must be NormalMeanMixtures") mixing_values = mixture.params.distribution.rvs(size=size) normal_values = scipy.stats.norm.rvs(size=size) + if mixture.mixture_form == "canonical": + return mixture.params.alpha + mixture.params.mu * mixing_values + (mixing_values ** 0.5) * normal_values return ( mixture.params.alpha + mixture.params.beta * mixing_values + mixture.params.gamma * (mixing_values**0.5) * normal_values - ) - - @staticmethod - def canonical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray: - """Generate a sample of given size. Canonical form of NMVM - - Args: - mixture: Normal Mean Variance Mixtures - size: length of sample - - Returns: sample of given size - - Raises: - ValueError: If mixture type is not Normal Mean Variance Mixtures - - """ - - if not isinstance(mixture, NormalMeanVarianceMixtures): - raise ValueError("Mixture must be NormalMeanMixtures") - mixing_values = mixture.params.distribution.rvs(size=size) - normal_values = scipy.stats.norm.rvs(size=size) - return mixture.params.alpha + mixture.params.mu * mixing_values + (mixing_values**0.5) * normal_values + ) \ No newline at end of file diff --git a/src/generators/nv_generator.py b/src/generators/nv_generator.py index faa55f5..712e918 100644 --- a/src/generators/nv_generator.py +++ b/src/generators/nv_generator.py @@ -9,7 +9,7 @@ class NVGenerator(AbstractGenerator): @staticmethod - def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray: + def generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray: """Generate a sample of given size. Classical form of NVM Args: @@ -27,25 +27,6 @@ def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray: raise ValueError("Mixture must be NormalMeanMixtures") mixing_values = mixture.params.distribution.rvs(size=size) normal_values = scipy.stats.norm.rvs(size=size) - return mixture.params.alpha + mixture.params.gamma * (mixing_values**0.5) * normal_values - - @staticmethod - def canonical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray: - """Generate a sample of given size. Canonical form of NVM - - Args: - mixture: Normal Variance Mixtures - size: length of sample - - Returns: sample of given size - - Raises: - ValueError: If mixture type is not Normal Variance Mixtures - - """ - - if not isinstance(mixture, NormalVarianceMixtures): - raise ValueError("Mixture must be NormalMeanMixtures") - mixing_values = mixture.params.distribution.rvs(size=size) - normal_values = scipy.stats.norm.rvs(size=size) - return mixture.params.alpha + (mixing_values**0.5) * normal_values + if mixture.mixture_form == "canonical": + return mixture.params.alpha + (mixing_values ** 0.5) * normal_values + return mixture.params.alpha + mixture.params.gamma * (mixing_values**0.5) * normal_values \ No newline at end of file diff --git a/src/mixtures/abstract_mixture.py b/src/mixtures/abstract_mixture.py index 13c8e42..77b927c 100644 --- a/src/mixtures/abstract_mixture.py +++ b/src/mixtures/abstract_mixture.py @@ -20,6 +20,7 @@ def __init__(self, mixture_form: str, **kwargs: Any) -> None: mixture_form: Form of Mixture classical or Canonical **kwargs: Parameters of Mixture """ + self.mixture_form = mixture_form if mixture_form == "classical": self.params = self._params_validation(self._classical_collector, kwargs) elif mixture_form == "canonical": diff --git a/tests/generators/nm_generator/test_mixing_normal.py b/tests/generators/nm_generator/test_mixing_normal.py index ea85692..49af0ae 100644 --- a/tests/generators/nm_generator/test_mixing_normal.py +++ b/tests/generators/nm_generator/test_mixing_normal.py @@ -16,7 +16,7 @@ class TestMixingNormal: ) def test_classic_generate_variance_0(self, mixing_variance: float, expected_variance: float) -> None: mixture = NormalMeanMixtures("classical", alpha=0, beta=mixing_variance**0.5, gamma=1, distribution=norm) - sample = self.generator.classical_generate(mixture, self.test_mixture_size) + sample = self.generator.generate(mixture, self.test_mixture_size) actual_variance = ndimage.variance(sample) assert actual_variance == pytest.approx(expected_variance, 0.1) @@ -24,7 +24,7 @@ def test_classic_generate_variance_0(self, mixing_variance: float, expected_vari def test_classic_generate_variance_1(self, beta: float) -> None: expected_variance = beta**2 + 1 mixture = NormalMeanMixtures("classical", alpha=0, beta=beta, gamma=1, distribution=norm) - sample = self.generator.classical_generate(mixture, self.test_mixture_size) + sample = self.generator.generate(mixture, self.test_mixture_size) actual_variance = ndimage.variance(sample) assert actual_variance == pytest.approx(expected_variance, 0.1) @@ -32,7 +32,7 @@ def test_classic_generate_variance_1(self, beta: float) -> None: def test_classic_generate_variance_2(self, beta: float, gamma: float) -> None: expected_variance = beta**2 + gamma**2 mixture = NormalMeanMixtures("classical", alpha=0, beta=beta, gamma=gamma, distribution=norm) - sample = self.generator.classical_generate(mixture, self.test_mixture_size) + sample = self.generator.generate(mixture, self.test_mixture_size) actual_variance = ndimage.variance(sample) assert actual_variance == pytest.approx(expected_variance, 0.1) @@ -40,14 +40,14 @@ def test_classic_generate_variance_2(self, beta: float, gamma: float) -> None: def test_classic_generate_mean(self, beta: float, gamma: float) -> None: expected_mean = 0 mixture = NormalMeanMixtures("classical", alpha=0, beta=beta, gamma=gamma, distribution=norm) - sample = self.generator.classical_generate(mixture, self.test_mixture_size) + sample = self.generator.generate(mixture, self.test_mixture_size) actual_mean = np.mean(np.array(sample)) assert abs(actual_mean - expected_mean) < 1 @pytest.mark.parametrize("expected_size", np.random.randint(0, 100, size=50)) def test_classic_generate_size(self, expected_size: int) -> None: mixture = NormalMeanMixtures("classical", alpha=0, beta=1, gamma=1, distribution=norm) - sample = self.generator.classical_generate(mixture, expected_size) + sample = self.generator.generate(mixture, expected_size) actual_size = np.size(sample) assert actual_size == expected_size @@ -56,7 +56,7 @@ def test_classic_generate_size(self, expected_size: int) -> None: ) def test_canonical_generate_variance_0(self, mixing_variance: float, expected_variance: float) -> None: mixture = NormalMeanMixtures("canonical", sigma=1, distribution=norm(0, mixing_variance**0.5)) - sample = self.generator.canonical_generate(mixture, self.test_mixture_size) + sample = self.generator.generate(mixture, self.test_mixture_size) actual_variance = ndimage.variance(sample) assert actual_variance == pytest.approx(expected_variance, 0.1) @@ -64,7 +64,7 @@ def test_canonical_generate_variance_0(self, mixing_variance: float, expected_va def test_canonical_generate_variance_1(self, sigma: float) -> None: expected_variance = sigma**2 + 1 mixture = NormalMeanMixtures("canonical", sigma=sigma, distribution=norm) - sample = self.generator.canonical_generate(mixture, self.test_mixture_size) + sample = self.generator.generate(mixture, self.test_mixture_size) actual_variance = ndimage.variance(sample) assert actual_variance == pytest.approx(expected_variance, 0.1) @@ -72,7 +72,7 @@ def test_canonical_generate_variance_1(self, sigma: float) -> None: def test_canonical_generate_variance_2(self, mixing_variance: float, sigma: float) -> None: expected_variance = mixing_variance + sigma**2 mixture = NormalMeanMixtures("canonical", sigma=sigma, distribution=norm(0, mixing_variance**0.5)) - sample = self.generator.canonical_generate(mixture, self.test_mixture_size) + sample = self.generator.generate(mixture, self.test_mixture_size) actual_variance = ndimage.variance(sample) assert actual_variance == pytest.approx(expected_variance, 0.1) @@ -80,13 +80,13 @@ def test_canonical_generate_variance_2(self, mixing_variance: float, sigma: floa def test_canonical_generate_mean(self, sigma: float) -> None: expected_mean = 0 mixture = NormalMeanMixtures("canonical", sigma=sigma, distribution=norm) - sample = self.generator.canonical_generate(mixture, self.test_mixture_size) + sample = self.generator.generate(mixture, self.test_mixture_size) actual_mean = np.mean(np.array(sample)) assert abs(actual_mean - expected_mean) < 1 @pytest.mark.parametrize("expected_size", [*np.random.randint(0, 100, size=50), 0, 1, 1000000]) def test_canonical_generate_size(self, expected_size: int) -> None: mixture = NormalMeanMixtures("canonical", sigma=1, distribution=norm) - sample = self.generator.canonical_generate(mixture, expected_size) + sample = self.generator.generate(mixture, expected_size) actual_size = np.size(sample) assert actual_size == expected_size