|
| 1 | +""" |
| 2 | +Tests for Normal Distribution Family Configuration |
| 3 | +
|
| 4 | +This module tests the functionality of the normal distribution family |
| 5 | +defined in config.py, including parameterizations, characteristics, |
| 6 | +and sampling. |
| 7 | +""" |
| 8 | + |
| 9 | +import math |
| 10 | +from typing import cast |
| 11 | + |
| 12 | +import numpy as np |
| 13 | +import pytest |
| 14 | +from scipy.stats import norm |
| 15 | + |
| 16 | +from pysatl_core.distributions.characteristics import GenericCharacteristic |
| 17 | +from pysatl_core.families import ParametricFamilyRegister, configure_family_register |
| 18 | +from pysatl_core.families.config import ( |
| 19 | + ExpParametrization, |
| 20 | + MeanPrecParametrization, |
| 21 | + MeanVarParametrization, |
| 22 | +) |
| 23 | + |
| 24 | +# Import PySATL components |
| 25 | +from pysatl_core.families.registry import _reset_families_register_for_tests |
| 26 | +from pysatl_core.types import UnivariateContinuous |
| 27 | + |
| 28 | +__author__ = "Fedor Myznikov" |
| 29 | +__copyright__ = "Copyright (c) 2025 PySATL project" |
| 30 | +__license__ = "SPDX-License-Identifier: MIT" |
| 31 | + |
| 32 | + |
| 33 | +class TestNormalFamily: |
| 34 | + """Test suite for Normal distribution family.""" |
| 35 | + |
| 36 | + def setup_method(self): |
| 37 | + """Setup before each test method.""" |
| 38 | + _reset_families_register_for_tests() |
| 39 | + configure_family_register() |
| 40 | + self.normal_family = ParametricFamilyRegister.get("Normal Family") |
| 41 | + |
| 42 | + def test_family_registration(self): |
| 43 | + """Test that normal family is properly registered.""" |
| 44 | + family = ParametricFamilyRegister.get("Normal Family") |
| 45 | + assert family.name == "Normal Family" |
| 46 | + |
| 47 | + # Check parameterizations |
| 48 | + expected_parametrizations = {"meanVar", "meanPrec", "exponential"} |
| 49 | + assert set(family.parametrization_names) == expected_parametrizations |
| 50 | + assert family.base_parametrization_name == "meanVar" |
| 51 | + |
| 52 | + def test_mean_var_parametrization_creation(self): |
| 53 | + """Test creation of distribution with mean-variance parametrization.""" |
| 54 | + dist = self.normal_family(mu=2.0, sigma=1.5) |
| 55 | + |
| 56 | + assert dist.distr_name == "Normal Family" |
| 57 | + assert dist.distribution_type == UnivariateContinuous |
| 58 | + |
| 59 | + # Приводим к конкретному типу параметризации |
| 60 | + params = cast(MeanVarParametrization, dist.parameters) |
| 61 | + assert params.mu == 2.0 |
| 62 | + assert params.sigma == 1.5 |
| 63 | + assert params.name == "meanVar" |
| 64 | + |
| 65 | + def test_mean_prec_parametrization_creation(self): |
| 66 | + """Test creation of distribution with mean-precision parametrization.""" |
| 67 | + dist = self.normal_family(mu=2.0, tau=0.25, parametrization_name="meanPrec") |
| 68 | + |
| 69 | + # Приводим к конкретному типу параметризации |
| 70 | + params = cast(MeanPrecParametrization, dist.parameters) |
| 71 | + assert params.mu == 2.0 |
| 72 | + assert params.tau == 0.25 |
| 73 | + assert params.name == "meanPrec" |
| 74 | + |
| 75 | + def test_exponential_parametrization_creation(self): |
| 76 | + """Test creation of distribution with exponential parametrization.""" |
| 77 | + # For N(2, 1.5): a = -1/(2*1.5²) = -0.222..., b = 2/1.5² = 0.888... |
| 78 | + dist = self.normal_family(a=-0.222, b=0.888, parametrization_name="exponential") |
| 79 | + |
| 80 | + # Приводим к конкретному типу параметризации |
| 81 | + params = cast(ExpParametrization, dist.parameters) |
| 82 | + assert params.a == -0.222 |
| 83 | + assert params.b == 0.888 |
| 84 | + assert params.name == "exponential" |
| 85 | + |
| 86 | + def test_parametrization_constraints(self): |
| 87 | + """Test parameter constraints validation.""" |
| 88 | + # Sigma must be positive |
| 89 | + with pytest.raises(ValueError, match="sigma > 0"): |
| 90 | + self.normal_family(mu=0, sigma=-1.0) |
| 91 | + |
| 92 | + # Tau must be positive |
| 93 | + with pytest.raises(ValueError, match="tau > 0"): |
| 94 | + self.normal_family(mu=0, tau=-1.0, parametrization_name="meanPrec") |
| 95 | + |
| 96 | + # a must be negative |
| 97 | + with pytest.raises(ValueError, match="a < 0"): |
| 98 | + self.normal_family(a=1.0, b=0.0, parametrization_name="exponential") |
| 99 | + |
| 100 | + def test_pdf_calculation(self): |
| 101 | + """Test PDF calculation against scipy.stats.norm.""" |
| 102 | + dist = self.normal_family(mu=2.0, sigma=1.5) |
| 103 | + pdf = dist.computation_strategy.query_method("pdf", dist) |
| 104 | + |
| 105 | + # Test points |
| 106 | + test_points = [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0] |
| 107 | + |
| 108 | + for x in test_points: |
| 109 | + # Our implementation |
| 110 | + our_pdf = pdf(x) |
| 111 | + # Scipy reference |
| 112 | + scipy_pdf = norm.pdf(x, loc=2.0, scale=1.5) |
| 113 | + |
| 114 | + assert abs(our_pdf - scipy_pdf) < 1e-10 |
| 115 | + |
| 116 | + def test_cdf_calculation(self): |
| 117 | + """Test CDF calculation against scipy.stats.norm.""" |
| 118 | + dist = self.normal_family(mu=2.0, sigma=1.5) |
| 119 | + cdf = dist.computation_strategy.query_method("cdf", dist) |
| 120 | + |
| 121 | + test_points = [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0] |
| 122 | + |
| 123 | + for x in test_points: |
| 124 | + our_cdf = cdf(x) |
| 125 | + scipy_cdf = norm.cdf(x, loc=2.0, scale=1.5) |
| 126 | + |
| 127 | + assert abs(our_cdf - scipy_cdf) < 1e-10 |
| 128 | + |
| 129 | + def test_ppf_calculation(self): |
| 130 | + """Test PPF calculation against scipy.stats.norm.""" |
| 131 | + dist = self.normal_family(mu=2.0, sigma=1.5) |
| 132 | + ppf = dist.computation_strategy.query_method("ppf", dist) |
| 133 | + |
| 134 | + test_probabilities = [0.001, 0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99, 0.999] |
| 135 | + |
| 136 | + for p in test_probabilities: |
| 137 | + our_ppf = ppf(p) |
| 138 | + scipy_ppf = norm.ppf(p, loc=2.0, scale=1.5) |
| 139 | + |
| 140 | + assert abs(our_ppf - scipy_ppf) < 1e-10 |
| 141 | + |
| 142 | + def test_characteristic_function(self): |
| 143 | + """Test characteristic function calculation.""" |
| 144 | + dist = self.normal_family(mu=2.0, sigma=1.5) |
| 145 | + char_func = dist.computation_strategy.query_method("char_func", dist) |
| 146 | + |
| 147 | + test_points = [-2.0, -1.0, 0.0, 1.0, 2.0] |
| 148 | + |
| 149 | + for t in test_points: |
| 150 | + cf_value = char_func(t) |
| 151 | + |
| 152 | + # Characteristic function of N(μ, σ²) is exp(iμt - ½σ²t²) |
| 153 | + expected_real = math.exp(-0.5 * (1.5 * t) ** 2) * math.cos(2.0 * t) |
| 154 | + expected_imag = math.exp(-0.5 * (1.5 * t) ** 2) * math.sin(2.0 * t) |
| 155 | + |
| 156 | + assert abs(cf_value.real - expected_real) < 1e-10 |
| 157 | + assert abs(cf_value.imag - expected_imag) < 1e-10 |
| 158 | + |
| 159 | + def test_moments(self): |
| 160 | + """Test moment calculations.""" |
| 161 | + dist = self.normal_family(mu=2.0, sigma=1.5) |
| 162 | + |
| 163 | + mean_func = dist.computation_strategy.query_method("mean", dist) |
| 164 | + var_func = dist.computation_strategy.query_method("var", dist) |
| 165 | + skew_func = dist.computation_strategy.query_method("skewness", dist) |
| 166 | + kurt_func = dist.computation_strategy.query_method("excess_kurtosis", dist) |
| 167 | + |
| 168 | + our_mean = mean_func(None) |
| 169 | + our_var = var_func(None) |
| 170 | + our_skew = skew_func(None) |
| 171 | + our_kurt = kurt_func(None) |
| 172 | + |
| 173 | + assert abs(our_mean - 2.0) < 1e-10 |
| 174 | + assert abs(our_var - 2.25) < 1e-10 |
| 175 | + assert abs(our_skew - 0.0) < 1e-10 |
| 176 | + assert abs(our_kurt - 0.0) < 1e-10 |
| 177 | + |
| 178 | + def test_parametrization_conversions(self): |
| 179 | + """Test conversions between different parameterizations.""" |
| 180 | + # Create with mean-variance |
| 181 | + dist_mv = self.normal_family(mu=2.0, sigma=1.5) |
| 182 | + |
| 183 | + # Convert to base should return same for meanVar |
| 184 | + base_params = self.normal_family.to_base(dist_mv.parameters) |
| 185 | + base_params = cast(MeanVarParametrization, base_params) |
| 186 | + assert base_params.mu == 2.0 |
| 187 | + assert base_params.sigma == 1.5 |
| 188 | + |
| 189 | + # Test meanPrec conversion |
| 190 | + dist_mp = self.normal_family(mu=2.0, tau=0.25, parametrization_name="meanPrec") |
| 191 | + base_from_mp = self.normal_family.to_base(dist_mp.parameters) |
| 192 | + base_from_mp = cast(MeanVarParametrization, base_from_mp) |
| 193 | + assert abs(base_from_mp.mu - 2.0) < 1e-10 |
| 194 | + assert abs(base_from_mp.sigma - 2.0) < 1e-10 # sigma = 1/sqrt(tau) = 1/sqrt(0.25) = 2 |
| 195 | + |
| 196 | + # Test exponential conversion |
| 197 | + dist_exp = self.normal_family(a=-0.222, b=0.888, parametrization_name="exponential") |
| 198 | + base_from_exp = self.normal_family.to_base(dist_exp.parameters) |
| 199 | + base_from_exp = cast(MeanVarParametrization, base_from_exp) |
| 200 | + # Should be approximately N(2, 1.5) |
| 201 | + assert abs(base_from_exp.mu - 2.0) < 0.1 |
| 202 | + assert abs(base_from_exp.sigma - 1.5) < 0.1 |
| 203 | + |
| 204 | + def test_analytical_computations_caching(self): |
| 205 | + """Test that analytical computations are properly cached.""" |
| 206 | + dist = self.normal_family(mu=0.0, sigma=1.0) |
| 207 | + |
| 208 | + # Access analytical computations multiple times |
| 209 | + comp1 = dist.analytical_computations |
| 210 | + comp2 = dist.analytical_computations |
| 211 | + |
| 212 | + # Should be the same object (cached) |
| 213 | + assert comp1 is comp2 |
| 214 | + |
| 215 | + # Should contain expected characteristics |
| 216 | + expected_chars = { |
| 217 | + "pdf", |
| 218 | + "cdf", |
| 219 | + "ppf", |
| 220 | + "char_func", |
| 221 | + "mean", |
| 222 | + "var", |
| 223 | + "skewness", |
| 224 | + "raw_kurtosis", |
| 225 | + "excess_kurtosis", |
| 226 | + } |
| 227 | + assert set(comp1.keys()) == expected_chars |
| 228 | + |
| 229 | + def test_array_input_support(self): |
| 230 | + """Test that characteristics support array inputs.""" |
| 231 | + dist = self.normal_family(mu=0.0, sigma=1.0) |
| 232 | + |
| 233 | + # Test with numpy array input |
| 234 | + x_array = np.array([-2.0, -1.0, 0.0, 1.0, 2.0]) |
| 235 | + |
| 236 | + pdf = dist.computation_strategy.query_method("pdf", dist) |
| 237 | + cdf = dist.computation_strategy.query_method("cdf", dist) |
| 238 | + |
| 239 | + pdf_array = pdf(x_array) |
| 240 | + cdf_array = cdf(x_array) |
| 241 | + |
| 242 | + # Results should be arrays of same shape |
| 243 | + assert pdf_array.shape == x_array.shape |
| 244 | + assert cdf_array.shape == x_array.shape |
| 245 | + |
| 246 | + # Compare with scipy |
| 247 | + scipy_pdf = norm.pdf(x_array) |
| 248 | + scipy_cdf = norm.cdf(x_array) |
| 249 | + |
| 250 | + np.testing.assert_array_almost_equal(pdf_array, scipy_pdf, decimal=10) |
| 251 | + np.testing.assert_array_almost_equal(cdf_array, scipy_cdf, decimal=10) |
| 252 | + |
| 253 | + |
| 254 | +class TestNormalFamilyEdgeCases: |
| 255 | + """Test edge cases and error conditions.""" |
| 256 | + |
| 257 | + def setup_method(self): |
| 258 | + """Setup before each test method.""" |
| 259 | + _reset_families_register_for_tests() |
| 260 | + configure_family_register() |
| 261 | + self.normal_family = ParametricFamilyRegister.get("Normal Family") |
| 262 | + |
| 263 | + def test_invalid_parameterization(self): |
| 264 | + """Test error for invalid parameterization name.""" |
| 265 | + with pytest.raises(KeyError): |
| 266 | + self.normal_family.distribution(parametrization_name="invalid_name", mu=0, sigma=1) |
| 267 | + |
| 268 | + def test_missing_parameters(self): |
| 269 | + """Test error for missing required parameters.""" |
| 270 | + with pytest.raises(TypeError): |
| 271 | + self.normal_family.distribution(mu=0) # Missing sigma |
| 272 | + |
| 273 | + def test_invalid_probability_ppf(self): |
| 274 | + """Test PPF with invalid probability values.""" |
| 275 | + dist = self.normal_family(mu=0.0, sigma=1.0) |
| 276 | + ppf_char = GenericCharacteristic[float, float]("ppf") |
| 277 | + |
| 278 | + # Test boundaries |
| 279 | + assert ppf_char(dist, 0.0) == float("-inf") |
| 280 | + assert ppf_char(dist, 1.0) == float("inf") |
| 281 | + |
| 282 | + # Test invalid probabilities |
| 283 | + with pytest.raises(ValueError): |
| 284 | + ppf_char(dist, -0.1) |
| 285 | + with pytest.raises(ValueError): |
| 286 | + ppf_char(dist, 1.1) |
| 287 | + |
| 288 | + |
| 289 | +if __name__ == "__main__": |
| 290 | + pytest.main([__file__, "-v"]) |
0 commit comments