Skip to content

Commit 62f39f1

Browse files
author
TheodorDM
committed
test(families): add test suite for normal distribution
- Add 16 tests covering all distribution characteristics - Test multiple parameterizations and conversions - Verify against scipy.stats.norm for correctness
1 parent 2d024ec commit 62f39f1

File tree

1 file changed

+290
-0
lines changed

1 file changed

+290
-0
lines changed
Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
"""
2+
Tests for Normal Distribution Family Configuration
3+
4+
This module tests the functionality of the normal distribution family
5+
defined in config.py, including parameterizations, characteristics,
6+
and sampling.
7+
"""
8+
9+
import math
10+
from typing import cast
11+
12+
import numpy as np
13+
import pytest
14+
from scipy.stats import norm
15+
16+
from pysatl_core.distributions.characteristics import GenericCharacteristic
17+
from pysatl_core.families import ParametricFamilyRegister, configure_family_register
18+
from pysatl_core.families.config import (
19+
ExpParametrization,
20+
MeanPrecParametrization,
21+
MeanVarParametrization,
22+
)
23+
24+
# Import PySATL components
25+
from pysatl_core.families.registry import _reset_families_register_for_tests
26+
from pysatl_core.types import UnivariateContinuous
27+
28+
__author__ = "Fedor Myznikov"
29+
__copyright__ = "Copyright (c) 2025 PySATL project"
30+
__license__ = "SPDX-License-Identifier: MIT"
31+
32+
33+
class TestNormalFamily:
34+
"""Test suite for Normal distribution family."""
35+
36+
def setup_method(self):
37+
"""Setup before each test method."""
38+
_reset_families_register_for_tests()
39+
configure_family_register()
40+
self.normal_family = ParametricFamilyRegister.get("Normal Family")
41+
42+
def test_family_registration(self):
43+
"""Test that normal family is properly registered."""
44+
family = ParametricFamilyRegister.get("Normal Family")
45+
assert family.name == "Normal Family"
46+
47+
# Check parameterizations
48+
expected_parametrizations = {"meanVar", "meanPrec", "exponential"}
49+
assert set(family.parametrization_names) == expected_parametrizations
50+
assert family.base_parametrization_name == "meanVar"
51+
52+
def test_mean_var_parametrization_creation(self):
53+
"""Test creation of distribution with mean-variance parametrization."""
54+
dist = self.normal_family(mu=2.0, sigma=1.5)
55+
56+
assert dist.distr_name == "Normal Family"
57+
assert dist.distribution_type == UnivariateContinuous
58+
59+
# Приводим к конкретному типу параметризации
60+
params = cast(MeanVarParametrization, dist.parameters)
61+
assert params.mu == 2.0
62+
assert params.sigma == 1.5
63+
assert params.name == "meanVar"
64+
65+
def test_mean_prec_parametrization_creation(self):
66+
"""Test creation of distribution with mean-precision parametrization."""
67+
dist = self.normal_family(mu=2.0, tau=0.25, parametrization_name="meanPrec")
68+
69+
# Приводим к конкретному типу параметризации
70+
params = cast(MeanPrecParametrization, dist.parameters)
71+
assert params.mu == 2.0
72+
assert params.tau == 0.25
73+
assert params.name == "meanPrec"
74+
75+
def test_exponential_parametrization_creation(self):
76+
"""Test creation of distribution with exponential parametrization."""
77+
# For N(2, 1.5): a = -1/(2*1.5²) = -0.222..., b = 2/1.5² = 0.888...
78+
dist = self.normal_family(a=-0.222, b=0.888, parametrization_name="exponential")
79+
80+
# Приводим к конкретному типу параметризации
81+
params = cast(ExpParametrization, dist.parameters)
82+
assert params.a == -0.222
83+
assert params.b == 0.888
84+
assert params.name == "exponential"
85+
86+
def test_parametrization_constraints(self):
87+
"""Test parameter constraints validation."""
88+
# Sigma must be positive
89+
with pytest.raises(ValueError, match="sigma > 0"):
90+
self.normal_family(mu=0, sigma=-1.0)
91+
92+
# Tau must be positive
93+
with pytest.raises(ValueError, match="tau > 0"):
94+
self.normal_family(mu=0, tau=-1.0, parametrization_name="meanPrec")
95+
96+
# a must be negative
97+
with pytest.raises(ValueError, match="a < 0"):
98+
self.normal_family(a=1.0, b=0.0, parametrization_name="exponential")
99+
100+
def test_pdf_calculation(self):
101+
"""Test PDF calculation against scipy.stats.norm."""
102+
dist = self.normal_family(mu=2.0, sigma=1.5)
103+
pdf = dist.computation_strategy.query_method("pdf", dist)
104+
105+
# Test points
106+
test_points = [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0]
107+
108+
for x in test_points:
109+
# Our implementation
110+
our_pdf = pdf(x)
111+
# Scipy reference
112+
scipy_pdf = norm.pdf(x, loc=2.0, scale=1.5)
113+
114+
assert abs(our_pdf - scipy_pdf) < 1e-10
115+
116+
def test_cdf_calculation(self):
117+
"""Test CDF calculation against scipy.stats.norm."""
118+
dist = self.normal_family(mu=2.0, sigma=1.5)
119+
cdf = dist.computation_strategy.query_method("cdf", dist)
120+
121+
test_points = [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0]
122+
123+
for x in test_points:
124+
our_cdf = cdf(x)
125+
scipy_cdf = norm.cdf(x, loc=2.0, scale=1.5)
126+
127+
assert abs(our_cdf - scipy_cdf) < 1e-10
128+
129+
def test_ppf_calculation(self):
130+
"""Test PPF calculation against scipy.stats.norm."""
131+
dist = self.normal_family(mu=2.0, sigma=1.5)
132+
ppf = dist.computation_strategy.query_method("ppf", dist)
133+
134+
test_probabilities = [0.001, 0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99, 0.999]
135+
136+
for p in test_probabilities:
137+
our_ppf = ppf(p)
138+
scipy_ppf = norm.ppf(p, loc=2.0, scale=1.5)
139+
140+
assert abs(our_ppf - scipy_ppf) < 1e-10
141+
142+
def test_characteristic_function(self):
143+
"""Test characteristic function calculation."""
144+
dist = self.normal_family(mu=2.0, sigma=1.5)
145+
char_func = dist.computation_strategy.query_method("char_func", dist)
146+
147+
test_points = [-2.0, -1.0, 0.0, 1.0, 2.0]
148+
149+
for t in test_points:
150+
cf_value = char_func(t)
151+
152+
# Characteristic function of N(μ, σ²) is exp(iμt - ½σ²t²)
153+
expected_real = math.exp(-0.5 * (1.5 * t) ** 2) * math.cos(2.0 * t)
154+
expected_imag = math.exp(-0.5 * (1.5 * t) ** 2) * math.sin(2.0 * t)
155+
156+
assert abs(cf_value.real - expected_real) < 1e-10
157+
assert abs(cf_value.imag - expected_imag) < 1e-10
158+
159+
def test_moments(self):
160+
"""Test moment calculations."""
161+
dist = self.normal_family(mu=2.0, sigma=1.5)
162+
163+
mean_func = dist.computation_strategy.query_method("mean", dist)
164+
var_func = dist.computation_strategy.query_method("var", dist)
165+
skew_func = dist.computation_strategy.query_method("skewness", dist)
166+
kurt_func = dist.computation_strategy.query_method("excess_kurtosis", dist)
167+
168+
our_mean = mean_func(None)
169+
our_var = var_func(None)
170+
our_skew = skew_func(None)
171+
our_kurt = kurt_func(None)
172+
173+
assert abs(our_mean - 2.0) < 1e-10
174+
assert abs(our_var - 2.25) < 1e-10
175+
assert abs(our_skew - 0.0) < 1e-10
176+
assert abs(our_kurt - 0.0) < 1e-10
177+
178+
def test_parametrization_conversions(self):
179+
"""Test conversions between different parameterizations."""
180+
# Create with mean-variance
181+
dist_mv = self.normal_family(mu=2.0, sigma=1.5)
182+
183+
# Convert to base should return same for meanVar
184+
base_params = self.normal_family.to_base(dist_mv.parameters)
185+
base_params = cast(MeanVarParametrization, base_params)
186+
assert base_params.mu == 2.0
187+
assert base_params.sigma == 1.5
188+
189+
# Test meanPrec conversion
190+
dist_mp = self.normal_family(mu=2.0, tau=0.25, parametrization_name="meanPrec")
191+
base_from_mp = self.normal_family.to_base(dist_mp.parameters)
192+
base_from_mp = cast(MeanVarParametrization, base_from_mp)
193+
assert abs(base_from_mp.mu - 2.0) < 1e-10
194+
assert abs(base_from_mp.sigma - 2.0) < 1e-10 # sigma = 1/sqrt(tau) = 1/sqrt(0.25) = 2
195+
196+
# Test exponential conversion
197+
dist_exp = self.normal_family(a=-0.222, b=0.888, parametrization_name="exponential")
198+
base_from_exp = self.normal_family.to_base(dist_exp.parameters)
199+
base_from_exp = cast(MeanVarParametrization, base_from_exp)
200+
# Should be approximately N(2, 1.5)
201+
assert abs(base_from_exp.mu - 2.0) < 0.1
202+
assert abs(base_from_exp.sigma - 1.5) < 0.1
203+
204+
def test_analytical_computations_caching(self):
205+
"""Test that analytical computations are properly cached."""
206+
dist = self.normal_family(mu=0.0, sigma=1.0)
207+
208+
# Access analytical computations multiple times
209+
comp1 = dist.analytical_computations
210+
comp2 = dist.analytical_computations
211+
212+
# Should be the same object (cached)
213+
assert comp1 is comp2
214+
215+
# Should contain expected characteristics
216+
expected_chars = {
217+
"pdf",
218+
"cdf",
219+
"ppf",
220+
"char_func",
221+
"mean",
222+
"var",
223+
"skewness",
224+
"raw_kurtosis",
225+
"excess_kurtosis",
226+
}
227+
assert set(comp1.keys()) == expected_chars
228+
229+
def test_array_input_support(self):
230+
"""Test that characteristics support array inputs."""
231+
dist = self.normal_family(mu=0.0, sigma=1.0)
232+
233+
# Test with numpy array input
234+
x_array = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
235+
236+
pdf = dist.computation_strategy.query_method("pdf", dist)
237+
cdf = dist.computation_strategy.query_method("cdf", dist)
238+
239+
pdf_array = pdf(x_array)
240+
cdf_array = cdf(x_array)
241+
242+
# Results should be arrays of same shape
243+
assert pdf_array.shape == x_array.shape
244+
assert cdf_array.shape == x_array.shape
245+
246+
# Compare with scipy
247+
scipy_pdf = norm.pdf(x_array)
248+
scipy_cdf = norm.cdf(x_array)
249+
250+
np.testing.assert_array_almost_equal(pdf_array, scipy_pdf, decimal=10)
251+
np.testing.assert_array_almost_equal(cdf_array, scipy_cdf, decimal=10)
252+
253+
254+
class TestNormalFamilyEdgeCases:
255+
"""Test edge cases and error conditions."""
256+
257+
def setup_method(self):
258+
"""Setup before each test method."""
259+
_reset_families_register_for_tests()
260+
configure_family_register()
261+
self.normal_family = ParametricFamilyRegister.get("Normal Family")
262+
263+
def test_invalid_parameterization(self):
264+
"""Test error for invalid parameterization name."""
265+
with pytest.raises(KeyError):
266+
self.normal_family.distribution(parametrization_name="invalid_name", mu=0, sigma=1)
267+
268+
def test_missing_parameters(self):
269+
"""Test error for missing required parameters."""
270+
with pytest.raises(TypeError):
271+
self.normal_family.distribution(mu=0) # Missing sigma
272+
273+
def test_invalid_probability_ppf(self):
274+
"""Test PPF with invalid probability values."""
275+
dist = self.normal_family(mu=0.0, sigma=1.0)
276+
ppf_char = GenericCharacteristic[float, float]("ppf")
277+
278+
# Test boundaries
279+
assert ppf_char(dist, 0.0) == float("-inf")
280+
assert ppf_char(dist, 1.0) == float("inf")
281+
282+
# Test invalid probabilities
283+
with pytest.raises(ValueError):
284+
ppf_char(dist, -0.1)
285+
with pytest.raises(ValueError):
286+
ppf_char(dist, 1.1)
287+
288+
289+
if __name__ == "__main__":
290+
pytest.main([__file__, "-v"])

0 commit comments

Comments
 (0)