diff --git a/src/pysatl_core/families/__init__.py b/src/pysatl_core/families/__init__.py index 8ac209d..71937e2 100644 --- a/src/pysatl_core/families/__init__.py +++ b/src/pysatl_core/families/__init__.py @@ -7,11 +7,12 @@ between different parameter formats. """ -__author__ = "Leonid Elkin, Mikhail, Mikhailov" +__author__ = "Leonid Elkin, Mikhail, Mikhailov, Fedor Myznikov" __copyright__ = "Copyright (c) 2025 PySATL project" __license__ = "SPDX-License-Identifier: MIT" +from pysatl_core.families.configuration import configure_families_register from pysatl_core.families.distribution import ParametricFamilyDistribution from pysatl_core.families.parametric_family import ParametricFamily from pysatl_core.families.parametrizations import ( @@ -30,4 +31,5 @@ "ParametricFamilyDistribution", "constraint", "parametrization", + "configure_families_register", ] diff --git a/src/pysatl_core/families/configuration.py b/src/pysatl_core/families/configuration.py new file mode 100644 index 0000000..5ccec9d --- /dev/null +++ b/src/pysatl_core/families/configuration.py @@ -0,0 +1,359 @@ +""" +Distribution Families Configuration +==================================== + +This module defines and configures parametric distribution families for the PySATL library: + +- :class:`Normal Family` — Gaussian distribution with multiple parameterizations. + +Notes +----- +- All families are registered in the global ParametricFamilyRegister. +- Each family supports multiple parameterizations with automatic conversions. +- Analytical implementations are provided where available, with fallbacks to numerical methods. +- Families are designed to be extensible with additional characteristics and parameterizations. +""" + +from __future__ import annotations + +__author__ = "Fedor Myznikov" +__copyright__ = "Copyright (c) 2025 PySATL project" +__license__ = "SPDX-License-Identifier: MIT" + +import math +from dataclasses import dataclass +from functools import lru_cache +from typing import TYPE_CHECKING, cast + +import numpy as np +import numpy.typing as npt +from scipy.special import erf, erfinv + +from pysatl_core.distributions.strategies import DefaultSamplingUnivariateStrategy +from pysatl_core.families.parametric_family import ParametricFamily +from pysatl_core.families.parametrizations import ( + Parametrization, + constraint, + parametrization, +) +from pysatl_core.families.registry import ParametricFamilyRegister +from pysatl_core.types import UnivariateContinuous + +if TYPE_CHECKING: + from typing import Any + + +PDF = "pdf" +CDF = "cdf" +PPF = "ppf" +CF = "char_func" +MEAN = "mean" +VAR = "var" +SKEW = "skewness" +RAWKURT = "raw_kurtosis" +EXKURT = "excess_kurtosis" + + +@lru_cache(maxsize=1) +def configure_families_register() -> ParametricFamilyRegister: + """ + Configure and register all distribution families in the global registry. + + This function initializes all parametric families with their respective + parameterizations, characteristics, and sampling strategies. It should be + called during application startup to make distributions available. + + Returns + ------- + ParametricFamilyRegister + The global registry of parametric families. + """ + _configure_normal_family() + return ParametricFamilyRegister() + + +@dataclass +class NormalMeanVarParametrization(Parametrization): + """ + Mean-variance parametrization of normal distribution. + + Parameters + ---------- + mu : float + Mean of the distribution + sigma : float + Standard deviation of the distribution + """ + + mu: float + sigma: float + + @constraint(description="sigma > 0") + def check_sigma_positive(self) -> bool: + """Check that standard deviation is positive.""" + return self.sigma > 0 + + +@dataclass +class NormalMeanPrecParametrization(Parametrization): + """ + Mean-precision parametrization of normal distribution. + + Parameters + ---------- + mu : float + Mean of the distribution + tau : float + Precision parameter (inverse variance) + """ + + mu: float + tau: float + + @constraint(description="tau > 0") + def check_tau_positive(self) -> bool: + """Check that precision parameter is positive.""" + return self.tau > 0 + + def transform_to_base_parametrization(self) -> Parametrization: + """ + Transform to mean-variance parametrization. + + Returns + ------- + Parametrization + Mean-variance parametrization instance + """ + sigma = math.sqrt(1 / self.tau) + return NormalMeanVarParametrization(mu=self.mu, sigma=sigma) + + +@dataclass +class NormalExpParametrization(Parametrization): + """ + Exponential family parametrization of normal distribution. + Uses the form: y = exp(a*x² + b*x + c) + + Parameters + ---------- + a : float + Quadratic term coefficient in exponential form + b : float + Linear term coefficient in exponential form + """ + + a: float + b: float + + @property + def calculate_c(self) -> float: + """ + Calculate the normalization constant c. + + Returns + ------- + float + Normalization constant + """ + return (self.b**2) / (4 * self.a) - (1 / 2) * math.log(math.pi / (-self.a)) + + @constraint(description="a < 0") + def check_a_negative(self) -> bool: + """Check that quadratic term coefficient is negative.""" + return self.a < 0 + + def transform_to_base_parametrization(self) -> Parametrization: + """ + Transform to mean-variance parametrization. + Returns + ------- + Parametrization + Mean-variance parametrization instance + """ + mu = -self.b / (2 * self.a) + sigma = math.sqrt(-1 / (2 * self.a)) + return NormalMeanVarParametrization(mu=mu, sigma=sigma) + + +def _configure_normal_family() -> None: + NORMAL_DOC = """ + Normal (Gaussian) distribution. + + The normal distribution is a continuous probability distribution characterized + by its bell-shaped curve. It is symmetric about its mean and is defined by + two parameters: mean (μ) and standard deviation (σ). + + Probability density function: + f(x) = 1/(σ√(2π)) * exp(-(x-μ)²/(2σ²)) + + The normal distribution is widely used in statistics, natural sciences, + and social sciences as a simple model for complex random phenomena. + """ + + def normal_pdf( + parameters: Parametrization, x: npt.NDArray[np.float64] + ) -> npt.NDArray[np.float64]: + """ + Probability density function for normal distribution. + + Parameters + ---------- + parameters : Parametrization () + Distribution parameters object with fields: + - mu: float (mean) + - sigma: float (standard deviation) + x : npt.NDArray[np.float64] + Points at which to evaluate the probability density function + + Returns + ------- + npt.NDArray[np.float64] + Probability density values at points x + """ + parameters = cast(NormalMeanVarParametrization, parameters) + + sigma = parameters.sigma + mu = parameters.mu + + coefficient = 1.0 / (sigma * np.sqrt(2 * np.pi)) + exponent = -((x - mu) ** 2) / (2 * sigma**2) + + return cast(npt.NDArray[np.float64], coefficient * np.exp(exponent)) + + def normal_cdf( + parameters: Parametrization, x: npt.NDArray[np.float64] + ) -> npt.NDArray[np.float64]: + """ + Cumulative distribution function for normal distribution. + + Parameters + ---------- + parameters : Parametrization + Distribution parameters object with fields: + - mu: float (mean) + - sigma: float (standard deviation) + x : npt.NDArray[np.float64] + Points at which to evaluate the cumulative distribution function + + Returns + ------- + npt.NDArray[np.float64] + Probabilities P(X ≤ x) for each point x + """ + parameters = cast(NormalMeanVarParametrization, parameters) + + z = (x - parameters.mu) / (parameters.sigma * np.sqrt(2)) + return cast(npt.NDArray[np.float64], 0.5 * (1 + erf(z))) + + def normal_ppf( + parameters: Parametrization, p: npt.NDArray[np.float64] + ) -> npt.NDArray[np.float64]: + """ + Percent point function (inverse CDF) for normal distribution. + + Parameters + ---------- + parameters : Parametrization + Distribution parameters object with fields: + - mu: float (mean) + - sigma: float (standard deviation) + p : npt.NDArray[np.float64] + Probability from [0, 1] + + Returns + ------- + npt.NDArray[np.float64] + Quantiles corresponding to probabilities p + + Raises + ------ + ValueError + If probability is outside [0, 1] + """ + if np.any((p < 0) | (p > 1)): + raise ValueError("Probability must be in [0, 1]") + + parameters = cast(NormalMeanVarParametrization, parameters) + + return cast( + npt.NDArray[np.float64], + parameters.mu + parameters.sigma * np.sqrt(2) * erfinv(2 * p - 1), + ) + + def normal_char_func( + parameters: Parametrization, t: npt.NDArray[np.float64] + ) -> npt.NDArray[np.complex128]: + """ + Characteristic function of normal distribution. + + Parameters + ---------- + parameters : Parametrization + Distribution parameters object with fields: + - mu: float (mean) + - sigma: float (standard deviation) + x : npt.NDArray[np.float64] + Points at which to evaluate the characteristic function + + Returns + ------- + npt.NDArray[np.complex128] + Characteristic function values at points x + """ + parameters = cast(NormalMeanVarParametrization, parameters) + + sigma = parameters.sigma + mu = parameters.mu + return cast(npt.NDArray[np.complex128], np.exp(1j * mu * t - 0.5 * (sigma**2) * (t**2))) + + def mean_func(parameters: Parametrization, _: Any) -> float: + """Mean of normal distribution.""" + parameters = cast(NormalMeanVarParametrization, parameters) + return parameters.mu + + def var_func(parameters: Parametrization, _: Any) -> float: + """Variance of normal distribution.""" + parameters = cast(NormalMeanVarParametrization, parameters) + return parameters.sigma**2 + + def skew_func(_1: Parametrization, _2: Any) -> int: + """Skewness of normal distribution (always 0).""" + return 0 + + def raw_kurt_func(_1: Parametrization, _2: Any) -> int: + """Raw kurtosis of normal distribution (always 3).""" + return 3 + + def ex_kurt_func(_1: Parametrization, _2: Any) -> int: + """Excess kurtosis of normal distribution (always 0).""" + return 0 + + Normal = ParametricFamily( + name="Normal Family", + distr_type=UnivariateContinuous, + distr_parametrizations=["meanVar", "meanPrec", "exponential"], + distr_characteristics={ + PDF: normal_pdf, + CDF: normal_cdf, + PPF: normal_ppf, + CF: normal_char_func, + MEAN: mean_func, + VAR: var_func, + SKEW: skew_func, + RAWKURT: raw_kurt_func, + EXKURT: ex_kurt_func, + }, + sampling_strategy=DefaultSamplingUnivariateStrategy(), + ) + Normal.__doc__ = NORMAL_DOC + + parametrization(family=Normal, name="meanVar")(NormalMeanVarParametrization) + parametrization(family=Normal, name="meanPrec")(NormalMeanPrecParametrization) + parametrization(family=Normal, name="exponential")(NormalExpParametrization) + + ParametricFamilyRegister.register(Normal) + + +def reset_families_register() -> None: + configure_families_register.cache_clear() + ParametricFamilyRegister._reset() diff --git a/src/pysatl_core/families/parametric_family.py b/src/pysatl_core/families/parametric_family.py index 0cdcba8..e226bb3 100644 --- a/src/pysatl_core/families/parametric_family.py +++ b/src/pysatl_core/families/parametric_family.py @@ -9,7 +9,7 @@ from __future__ import annotations -__author__ = "Leonid Elkin, Mikhail, Mikhailov" +__author__ = "Leonid Elkin, Mikhail, Mikhailov, Fedor Myznikov" __copyright__ = "Copyright (c) 2025 PySATL project" __license__ = "SPDX-License-Identifier: MIT" @@ -322,8 +322,8 @@ def distribution( parametrization_class = self._parametrizations[parametrization_name] parameters = parametrization_class(**parameters_values) - base_parameters = self.to_base(parameters) parameters.validate() + base_parameters = self.to_base(parameters) distribution_type = self._distr_type(base_parameters) return ParametricFamilyDistribution(self.name, distribution_type, parameters) diff --git a/src/pysatl_core/families/registry.py b/src/pysatl_core/families/registry.py index ba12dc6..ddf9f9b 100644 --- a/src/pysatl_core/families/registry.py +++ b/src/pysatl_core/families/registry.py @@ -8,7 +8,7 @@ from __future__ import annotations -__author__ = "Leonid Elkin, Mikhail, Mikhailov" +__author__ = "Leonid Elkin, Mikhail, Mikhailov, Fedor Myznikov" __copyright__ = "Copyright (c) 2025 PySATL project" __license__ = "SPDX-License-Identifier: MIT" @@ -93,7 +93,14 @@ def register(cls, family: ParametricFamily) -> None: raise ValueError(f"Family {family.name} already found in register") self._registered_families[family.name] = family + @classmethod + def _reset(cls) -> None: + """ + Clear the registry (for testing purposes). -def _reset_families_register_for_tests() -> None: - """Reset the cached distribution type register (test helper).""" - ParametricFamilyRegister._instance = None + This method removes all registered families and resets the singleton instance. + It should only be used in tests. + """ + if cls._instance is not None: + cls._instance._registered_families.clear() + cls._instance = None diff --git a/tests/conftest.py b/tests/conftest.py index f8c3e4c..b68e839 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ import pytest from pysatl_core.distributions.registry import _reset_distribution_type_register_for_tests -from pysatl_core.families.registry import _reset_families_register_for_tests +from pysatl_core.families.configuration import reset_families_register pytest.importorskip("scipy") @@ -14,5 +14,5 @@ @pytest.fixture(autouse=True) def _fresh_registries() -> Generator[None, Any, None]: _reset_distribution_type_register_for_tests() - _reset_families_register_for_tests() + reset_families_register() yield diff --git a/tests/unit/families/test_configuration.py b/tests/unit/families/test_configuration.py new file mode 100644 index 0000000..77d857b --- /dev/null +++ b/tests/unit/families/test_configuration.py @@ -0,0 +1,267 @@ +""" +Tests for Normal Distribution Family Configuration + +This module tests the functionality of the normal distribution family +defined in config.py, including parameterizations, characteristics, +and sampling. +""" + +__author__ = "Fedor Myznikov" +__copyright__ = "Copyright (c) 2025 PySATL project" +__license__ = "SPDX-License-Identifier: MIT" + +import math +from typing import cast + +import numpy as np +import pytest +from scipy.stats import norm + +from pysatl_core.distributions.characteristics import GenericCharacteristic +from pysatl_core.families.configuration import ( + NormalExpParametrization, + NormalMeanPrecParametrization, + NormalMeanVarParametrization, + configure_families_register, +) +from pysatl_core.families.registry import ParametricFamilyRegister +from pysatl_core.types import UnivariateContinuous + + +class TestNormalFamily: + """Test suite for Normal distribution family.""" + + # Precision for floating point comparisons + CALCULATION_PRECISION = 1e-10 + + def setup_method(self): + """Setup before each test method.""" + registry = configure_families_register() + self.normal_family = registry.get("Normal Family") + self.normal_dist_example = self.normal_family(mu=2.0, sigma=1.5) + + def test_family_registration(self): + """Test that normal family is properly registered.""" + family = ParametricFamilyRegister.get("Normal Family") + assert family.name == "Normal Family" + + # Check parameterizations + expected_parametrizations = {"meanVar", "meanPrec", "exponential"} + assert set(family.parametrization_names) == expected_parametrizations + assert family.base_parametrization_name == "meanVar" + + def test_mean_var_parametrization_creation(self): + """Test creation of distribution with mean-variance parametrization.""" + dist = self.normal_family(mu=2.0, sigma=1.5) + + assert dist.distr_name == "Normal Family" + assert dist.distribution_type == UnivariateContinuous + + params = cast(NormalMeanVarParametrization, dist.parameters) + assert params.mu == 2.0 + assert params.sigma == 1.5 + assert params.name == "meanVar" + + def test_mean_prec_parametrization_creation(self): + """Test creation of distribution with mean-precision parametrization.""" + dist = self.normal_family(mu=2.0, tau=0.25, parametrization_name="meanPrec") + + params = cast(NormalMeanPrecParametrization, dist.parameters) + assert params.mu == 2.0 + assert params.tau == 0.25 + assert params.name == "meanPrec" + + def test_exponential_parametrization_creation(self): + """Test creation of distribution with exponential parametrization.""" + # For N(2, 1.5): a = -1/(2*1.5²) = -0.222..., b = 2/1.5² = 0.888... + dist = self.normal_family(a=-0.222, b=0.888, parametrization_name="exponential") + + params = cast(NormalExpParametrization, dist.parameters) + assert params.a == -0.222 + assert params.b == 0.888 + assert params.name == "exponential" + + def test_parametrization_constraints(self): + """Test parameter constraints validation.""" + # Sigma must be positive + with pytest.raises(ValueError, match="sigma > 0"): + self.normal_family(mu=0, sigma=-1.0) + + # Tau must be positive + with pytest.raises(ValueError, match="tau > 0"): + self.normal_family(mu=0, tau=-1.0, parametrization_name="meanPrec") + + # a must be negative + with pytest.raises(ValueError, match="a < 0"): + self.normal_family(a=1.0, b=0.0, parametrization_name="exponential") + + def test_pdf_calculation(self): + """Test PDF calculation against scipy.stats.norm.""" + pdf = self.normal_dist_example.computation_strategy.query_method( + "pdf", self.normal_dist_example + ) + test_points = [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0] + + for x in test_points: + # Our implementation + our_pdf = pdf(x) + # Scipy reference + scipy_pdf = norm.pdf(x, loc=2.0, scale=1.5) + + assert abs(our_pdf - scipy_pdf) < self.CALCULATION_PRECISION + + def test_cdf_calculation(self): + """Test CDF calculation against scipy.stats.norm.""" + cdf = self.normal_dist_example.computation_strategy.query_method( + "cdf", self.normal_dist_example + ) + test_points = [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0] + + for x in test_points: + our_cdf = cdf(x) + scipy_cdf = norm.cdf(x, loc=2.0, scale=1.5) + + assert abs(our_cdf - scipy_cdf) < self.CALCULATION_PRECISION + + def test_ppf_calculation(self): + """Test PPF calculation against scipy.stats.norm.""" + ppf = self.normal_dist_example.computation_strategy.query_method( + "ppf", self.normal_dist_example + ) + test_probabilities = [0.001, 0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99, 0.999] + + for p in test_probabilities: + our_ppf = ppf(p) + scipy_ppf = norm.ppf(p, loc=2.0, scale=1.5) + + assert abs(our_ppf - scipy_ppf) < self.CALCULATION_PRECISION + + @pytest.mark.parametrize( + "char_func_arg", + [ + -2.0, + -1.0, + 0.0, + 1.0, + 2.0, + ], + ) + def test_characteristic_function(self, char_func_arg): + """Test characteristic function calculation at specific points.""" + char_func = self.normal_dist_example.computation_strategy.query_method( + "char_func", self.normal_dist_example + ) + cf_value = char_func(char_func_arg) + + expected_real = math.exp(-0.5 * (1.5 * char_func_arg) ** 2) * math.cos(2.0 * char_func_arg) + expected_imag = math.exp(-0.5 * (1.5 * char_func_arg) ** 2) * math.sin(2.0 * char_func_arg) + + assert abs(cf_value.real - expected_real) < self.CALCULATION_PRECISION + assert abs(cf_value.imag - expected_imag) < self.CALCULATION_PRECISION + + @pytest.mark.parametrize( + "char_func_getter, expected", + [ + (lambda distr: distr.computation_strategy.query_method("mean", distr)(None), 2.0), + (lambda distr: distr.computation_strategy.query_method("var", distr)(None), 2.25), + (lambda distr: distr.computation_strategy.query_method("skewness", distr)(None), 0.0), + ( + lambda distr: distr.computation_strategy.query_method("excess_kurtosis", distr)( + None + ), + 0.0, + ), + ], + ) + def test_moments(self, char_func_getter, expected): + """Test moment calculations using parameterized tests.""" + actual = char_func_getter(self.normal_dist_example) + assert abs(actual - expected) < self.CALCULATION_PRECISION + + @pytest.mark.parametrize( + "parametrization_name, params, expected_mu, expected_sigma", + [ + ("meanVar", {"mu": 2.0, "sigma": 1.5}, 2.0, 1.5), + ("meanPrec", {"mu": 2.0, "tau": 0.25}, 2.0, math.sqrt(1 / 0.25)), + ("exponential", {"a": -1 / (2 * 1.5**2), "b": 2 / (1.5**2)}, 2.0, 1.5), + ], + ) + def test_parametrization_conversions( + self, parametrization_name, params, expected_mu, expected_sigma + ): + """Test conversions between different parameterizations.""" + base_params = cast( + NormalMeanVarParametrization, + self.normal_family.to_base( + self.normal_family.get_parametrization(parametrization_name)(**params) + ), + ) + + assert abs(base_params.mu - expected_mu) < self.CALCULATION_PRECISION + assert abs(base_params.sigma - expected_sigma) < self.CALCULATION_PRECISION + + def test_analytical_computations_caching(self): + """Test that analytical computations are properly cached.""" + comp = self.normal_family(mu=0.0, sigma=1.0).analytical_computations + + expected_chars = { + "pdf", + "cdf", + "ppf", + "char_func", + "mean", + "var", + "skewness", + "raw_kurtosis", + "excess_kurtosis", + } + assert set(comp.keys()) == expected_chars + + def test_array_input_support(self): + """Test that PDF supports array inputs.""" + dist = self.normal_family(mu=0.0, sigma=1.0) + x_array = np.array([-2.0, -1.0, 0.0, 1.0, 2.0]) + + pdf = dist.computation_strategy.query_method("pdf", dist) + pdf_array = pdf(x_array) + + assert pdf_array.shape == x_array.shape + scipy_pdf = norm.pdf(x_array, loc=0.0, scale=1.0) + + np.testing.assert_array_almost_equal( + pdf_array, scipy_pdf, decimal=int(-math.log10(self.CALCULATION_PRECISION)) + ) + + +class TestNormalFamilyEdgeCases: + """Test edge cases and error conditions.""" + + def setup_method(self): + """Setup before each test method.""" + configure_families_register() + self.normal_family = ParametricFamilyRegister.get("Normal Family") + + def test_invalid_parameterization(self): + """Test error for invalid parameterization name.""" + with pytest.raises(KeyError): + self.normal_family.distribution(parametrization_name="invalid_name", mu=0, sigma=1) + + def test_missing_parameters(self): + """Test error for missing required parameters.""" + with pytest.raises(TypeError): + self.normal_family.distribution(mu=0) # Missing sigma + + def test_invalid_probability_ppf(self): + """Test PPF with invalid probability values.""" + dist = self.normal_family(mu=0.0, sigma=1.0) + ppf_char = GenericCharacteristic[float, float]("ppf") + + # Test boundaries + assert ppf_char(dist, 0.0) == float("-inf") + assert ppf_char(dist, 1.0) == float("inf") + + # Test invalid probabilities + with pytest.raises(ValueError): + ppf_char(dist, -0.1) + with pytest.raises(ValueError): + ppf_char(dist, 1.1)