Families/normal family configuration #58
- Implement Normal distribution with meanVar, meanPrec, exponential parameterizations
- Add analytical implementations for PDF, CDF, PPF, characteristic function
- Implement moment calculations (mean, variance, skewness, kurtosis)
- Configure family registry and parametrization system
- Add 16 tests covering all distribution characteristics
- Test multiple parameterizations and conversions
- Verify against scipy.stats.norm for correctness
LeonidElkin
left a comment
In general, it's very cool, but there are many minor flaws.
src/pysatl_core/families/config.py
Outdated
| import numpy.typing as npt | ||
| from scipy.special import erf, erfinv | ||
|
|
||
| from pysatl_core.distributions import DefaultSamplingUnivariateStrategy |
Let's use full module paths inside our library: prefer pysatl_core.distributions.strategies over pysatl_core.distributions. The reason is the __init__.py inside the distributions module.
src/pysatl_core/families/config.py
Outdated
| def configure_family_register() -> None: | ||
| """ | ||
| Configure and register all distribution families in the global registry. | ||
| This function initializes all parametric families with their respective | ||
| parameterizations, characteristics, and sampling strategies. It should be | ||
| called during application startup to make distributions available. | ||
| """ | ||
| _configure_normal_family() |
This function needs the @lru_cache(maxsize=1) decorator and should return the ParametricFamilyRegister as well, like in distributions/registry.py:
    @lru_cache(maxsize=1)
    def distribution_type_register() -> DistributionTypeRegister:
        reg = DistributionTypeRegister()
        _configure(reg)
        return reg
But you can do it like this:
    @lru_cache(maxsize=1)
    def configure_family_register() -> ParametricFamilyRegister:
        _configure_normal_family()
        return ParametricFamilyRegister()
I do like the style of using ParametricFamilyRegister.register(Smth). It seems more natural for a singleton object, so there is no need to write reg = DistributionTypeRegister() and pass it through the configuration function. I think I'll do the same thing in the new graph implementation. But we really do need caching and returning the register object in such a function.
src/pysatl_core/families/config.py
Outdated
| @dataclass | ||
| class MeanVarParametrization(Parametrization): | ||
| """ | ||
| Mean-variance parametrization of normal distribution. | ||
| Parameters | ||
| ---------- | ||
| mu : float | ||
| Mean of the distribution | ||
| sigma : float | ||
| Standard deviation of the distribution | ||
| """ | ||
|
|
||
| mu: float | ||
| sigma: float | ||
|
|
||
| @constraint(description="sigma > 0") | ||
| def check_sigma_positive(self) -> bool: | ||
| """Check that standard deviation is positive.""" | ||
| return self.sigma > 0 | ||
|
|
||
|
|
||
| @dataclass | ||
| class MeanPrecParametrization(Parametrization): | ||
| """ | ||
| Mean-precision parametrization of normal distribution. | ||
| Parameters | ||
| ---------- | ||
| mu : float | ||
| Mean of the distribution | ||
| tau : float | ||
| Precision parameter (inverse variance) | ||
| """ | ||
|
|
||
| mu: float | ||
| tau: float | ||
|
|
||
| @constraint(description="tau > 0") | ||
| def check_tau_positive(self) -> bool: | ||
| """Check that precision parameter is positive.""" | ||
| return self.tau > 0 | ||
|
|
||
| def transform_to_base_parametrization(self) -> Parametrization: | ||
| """ | ||
| Transform to mean-variance parametrization. | ||
| Returns | ||
| ------- | ||
| Parametrization | ||
| Mean-variance parametrization instance | ||
| """ | ||
| sigma = math.sqrt(1 / self.tau) | ||
| return MeanVarParametrization(mu=self.mu, sigma=sigma) | ||
|
|
||
|
|
||
| @dataclass | ||
| class ExpParametrization(Parametrization): | ||
| """ | ||
| Exponential family parametrization of normal distribution. | ||
| Uses the form: y = exp(a*x² + b*x + c) | ||
| Parameters | ||
| ---------- | ||
| a : float | ||
| Quadratic term coefficient in exponential form | ||
| b : float | ||
| Linear term coefficient in exponential form | ||
| """ | ||
|
|
||
| a: float | ||
| b: float | ||
|
|
||
| @property | ||
| def calculate_c(self) -> float: | ||
| """ | ||
| Calculate the normalization constant c. | ||
| Returns | ||
| ------- | ||
| float | ||
| Normalization constant | ||
| """ | ||
| return (self.b**2) / (4 * self.a) - (1 / 2) * math.log(math.pi / (-self.a)) | ||
|
|
||
| @constraint(description="a < 0") | ||
| def check_a_negative(self) -> bool: | ||
| """Check that quadratic term coefficient is negative.""" | ||
| return self.a < 0 | ||
|
|
||
| def transform_to_base_parametrization(self) -> Parametrization: | ||
| """ | ||
| Transform to mean-variance parametrization. | ||
| Returns | ||
| ------- | ||
| Parametrization | ||
| Mean-variance parametrization instance | ||
| """ | ||
| mu = -self.b / (2 * self.a) | ||
| sigma = math.sqrt(-1 / (2 * self.a)) | ||
| return MeanVarParametrization(mu=mu, sigma=sigma) |
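The exponential form in the hunk above can be checked numerically. The sketch below assumes the density is written as exp(a*x² + b*x + c) with a = -1/(2σ²) and b = μ/σ², which is what the inverse formulas in transform_to_base_parametrization imply; the reference values come from the standard library's `statistics.NormalDist`, not from pysatl_core.

```python
import math
from statistics import NormalDist

mu, sigma = 2.0, 1.5

# Forward map from (mu, sigma) to the exponential-form coefficients.
a = -1.0 / (2.0 * sigma**2)
b = mu / sigma**2
# Normalization constant, same expression as calculate_c in the diff.
c = (b**2) / (4.0 * a) - 0.5 * math.log(math.pi / (-a))


def exp_form_pdf(x: float) -> float:
    """Normal density written as exp(a*x^2 + b*x + c)."""
    return math.exp(a * x**2 + b * x + c)


# Inverse map, same expressions as transform_to_base_parametrization.
mu_back = -b / (2.0 * a)
sigma_back = math.sqrt(-1.0 / (2.0 * a))

reference = NormalDist(mu, sigma)
max_error = max(
    abs(exp_form_pdf(x) - reference.pdf(x)) for x in (-3.0, 0.0, 2.0, 5.5)
)
```

For μ = 2 and σ = 1.5 this gives a = -2/9 and b = 8/9, which matches the rounded values used in the tests.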
Why don't you use the @parametrization decorator? You can create these classes after creating the Normal object and use the @Normal.parametrization syntax. It's smoother and more natural in terms of the families module. By the way, you don't have to use the @dataclass decorator: when you use @parametrization, the class is transformed into a dataclass.
It was decided to leave this issue unchanged in private conversation
src/pysatl_core/families/config.py
Outdated
| parametrization(family=Normal, name="meanVar")(MeanVarParametrization) | ||
| parametrization(family=Normal, name="meanPrec")(MeanPrecParametrization) | ||
| parametrization(family=Normal, name="exponential")(ExpParametrization) |
As already said, there is no need for this if you're using the @parametrization decorator.
It was decided to leave this issue unchanged in private conversation
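For readers following this thread, the two styles being discussed are equivalent in plain Python: a decorator is just a call applied to the class. The sketch below uses a hypothetical `Family` class whose `parametrization` method applies `dataclass` and records the class; it only illustrates the pattern and is not the pysatl_core implementation.

```python
from dataclasses import dataclass, fields, is_dataclass


class Family:
    """Hypothetical family object used only for this illustration."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.parametrizations = {}

    def parametrization(self, name: str):
        def decorator(cls: type) -> type:
            cls = dataclass(cls)  # the decorator itself makes it a dataclass
            self.parametrizations[name] = cls
            return cls

        return decorator


Normal = Family("Normal")


# Decorator style: registration happens where the class is defined.
@Normal.parametrization(name="meanVar")
class MeanVar:
    mu: float
    sigma: float


# Explicit-call style, as in the diff: define first, register afterwards.
class MeanPrec:
    mu: float
    tau: float


MeanPrec = Normal.parametrization(name="meanPrec")(MeanPrec)
```

Either way the class ends up registered under its name; the decorator form simply keeps the definition and the registration in one place.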
src/pysatl_core/families/config.py
Outdated
| def mean_func(parameters: Parametrization, __: Any = None) -> float: | ||
| """Mean of normal distribution.""" | ||
| parameters = cast(MeanVarParametrization, parameters) | ||
| return parameters.mu | ||
|
|
||
| def var_func(parameters: Parametrization, __: Any = None) -> float: | ||
| """Variance of normal distribution.""" | ||
| parameters = cast(MeanVarParametrization, parameters) | ||
| return parameters.sigma**2 | ||
|
|
||
| def skew_func(_: Parametrization, __: Any = None) -> int: | ||
| """Skewness of normal distribution (always 0).""" | ||
| return 0 | ||
|
|
||
| def raw_kurt_func(_: Parametrization, __: Any = None) -> int: | ||
| """Raw kurtosis of normal distribution (always 3).""" | ||
| return 3 | ||
|
|
||
| def ex_kurt_func(_: Parametrization, __: Any) -> int: | ||
| """Excess kurtosis of normal distribution (always 0).""" | ||
| return 0 |
The fact that some arguments are unused is definitely a design problem, but let's use the following naming convention in such cases: if there is one unused argument, we name it _; if there are several, we name them _1, _2, and so on.
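Applied to the functions in the hunk above, the convention could look like the sketch below. `Params` is a hypothetical stand-in for the parametrization class, introduced only so the example runs on its own.

```python
from typing import Any


class Params:
    """Hypothetical parameter holder used only for this sketch."""

    def __init__(self, mu: float, sigma: float) -> None:
        self.mu = mu
        self.sigma = sigma


def mean_func(parameters: Params, _: Any = None) -> float:
    """One unused argument: name it `_`."""
    return parameters.mu


def skew_func(_1: Params, _2: Any = None) -> int:
    """Several unused arguments: name them `_1`, `_2`, and so on."""
    return 0
```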
| # Test meanPrec conversion | ||
| dist_mp = self.normal_family(mu=2.0, tau=0.25, parametrization_name="meanPrec") | ||
| base_from_mp = self.normal_family.to_base(dist_mp.parameters) | ||
| base_from_mp = cast(MeanVarParametrization, base_from_mp) | ||
| assert abs(base_from_mp.mu - 2.0) < 1e-10 | ||
| assert abs(base_from_mp.sigma - 2.0) < 1e-10 # sigma = 1/sqrt(tau) = 1/sqrt(0.25) = 2 | ||
|
|
||
| # Test exponential conversion | ||
| dist_exp = self.normal_family(a=-0.222, b=0.888, parametrization_name="exponential") | ||
| base_from_exp = self.normal_family.to_base(dist_exp.parameters) | ||
| base_from_exp = cast(MeanVarParametrization, base_from_exp) | ||
| # Should be approximately N(2, 1.5) | ||
| assert abs(base_from_exp.mu - 2.0) < 0.1 | ||
| assert abs(base_from_exp.sigma - 1.5) < 0.1 |
Actually, you can move the non-base parametrizations into a separate test and use parametrized tests as well.
| def test_analytical_computations_caching(self): | ||
| """Test that analytical computations are properly cached.""" | ||
| dist = self.normal_family(mu=0.0, sigma=1.0) | ||
|
|
||
| # Access analytical computations multiple times | ||
| comp1 = dist.analytical_computations | ||
| comp2 = dist.analytical_computations | ||
|
|
||
| # Should be the same object (cached) | ||
| assert comp1 is comp2 | ||
|
|
||
| # Should contain expected characteristics | ||
| expected_chars = { | ||
| "pdf", | ||
| "cdf", | ||
| "ppf", | ||
| "char_func", | ||
| "mean", | ||
| "var", | ||
| "skewness", | ||
| "raw_kurtosis", | ||
| "excess_kurtosis", | ||
| } | ||
| assert set(comp1.keys()) == expected_chars |
Caching is already tested in test_distribution_cache. Leave only this part:
    def test_analytical_computations_caching(self):
        """Test that analytical computations are properly cached."""
        comp = self.normal_family(mu=0.0, sigma=1.0).analytical_computations
        # Should contain expected characteristics
        expected_chars = {
            "pdf",
            "cdf",
            "ppf",
            "char_func",
            "mean",
            "var",
            "skewness",
            "raw_kurtosis",
            "excess_kurtosis",
        }
        assert set(comp.keys()) == expected_chars
| def test_array_input_support(self): | ||
| """Test that characteristics support array inputs.""" | ||
| dist = self.normal_family(mu=0.0, sigma=1.0) | ||
|
|
||
| # Test with numpy array input | ||
| x_array = np.array([-2.0, -1.0, 0.0, 1.0, 2.0]) | ||
|
|
||
| pdf = dist.computation_strategy.query_method("pdf", dist) | ||
| cdf = dist.computation_strategy.query_method("cdf", dist) | ||
|
|
||
| pdf_array = pdf(x_array) | ||
| cdf_array = cdf(x_array) | ||
|
|
||
| # Results should be arrays of same shape | ||
| assert pdf_array.shape == x_array.shape | ||
| assert cdf_array.shape == x_array.shape | ||
|
|
||
| # Compare with scipy | ||
| scipy_pdf = norm.pdf(x_array) | ||
| scipy_cdf = norm.cdf(x_array) | ||
|
|
||
| np.testing.assert_array_almost_equal(pdf_array, scipy_pdf, decimal=10) | ||
| np.testing.assert_array_almost_equal(cdf_array, scipy_cdf, decimal=10) |
Don't test several characteristics in one test. Either leave only one, or write parametrized tests where you test every single characteristic.
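One way to read this suggestion is "one characteristic per test case". The sketch below is standard-library only: the `pdf` and `cdf` functions are hand-written stand-ins for the library's characteristics, `statistics.NormalDist` plays the role of the scipy reference, and each tuple in `CASES` is what would become one case of `pytest.mark.parametrize` in the real suite.

```python
import math
from statistics import NormalDist

MU, SIGMA = 0.0, 1.0
REFERENCE = NormalDist(MU, SIGMA)
POINTS = [-2.0, -1.0, 0.0, 1.0, 2.0]


def pdf(x: float) -> float:
    z = (x - MU) / SIGMA
    return math.exp(-0.5 * z * z) / (SIGMA * math.sqrt(2.0 * math.pi))


def cdf(x: float) -> float:
    return 0.5 * (1.0 + math.erf((x - MU) / (SIGMA * math.sqrt(2.0))))


# One entry per characteristic: (name, implementation, reference).
CASES = [
    ("pdf", pdf, REFERENCE.pdf),
    ("cdf", cdf, REFERENCE.cdf),
]


def check_characteristic(name, implementation, reference):
    """Body of a single test case: exactly one characteristic is checked."""
    for x in POINTS:
        assert math.isclose(
            implementation(x), reference(x), rel_tol=1e-9, abs_tol=1e-12
        ), (name, x)


for case in CASES:
    check_characteristic(*case)
```

A failure then names the characteristic that broke instead of hiding it behind an earlier assertion in the same test.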
|
|
||
| if __name__ == "__main__": | ||
| pytest.main([__file__, "-v"]) |
Why? Remove it
Add something like `self.normal_dist_example = self.normal_family(mu=2.0, sigma=1.5)` to the setup and use it wherever the parametrization and its parameters don't matter.
…ce and readability increase
7ca3ea7 to
13d09c4
Compare
LeonidElkin
left a comment
Minor changes are needed
| from dataclasses import dataclass | ||
| from typing import Any, cast | ||
| from functools import lru_cache | ||
| from typing import TYPE_CHECKING, Any, cast |
Remove the Any import from here. You are already importing it when type checking.
It is my fault, though. I gave you the wrong code sample. Double check me ;)
| try: | ||
| from pysatl_core.families.configuration import configure_family_register | ||
|
|
||
| configure_family_register.cache_clear() | ||
| except ImportError: | ||
| pass |
Why are you using try-except here? Why not import it at the top of the file?
| @pytest.mark.parametrize( | ||
| "t", | ||
| [ | ||
| (-2.0), | ||
| (-1.0), | ||
| (0.0), | ||
| (1.0), | ||
| (2.0), | ||
| ], | ||
| ) | ||
| def test_characteristic_function(self, t): |
Try not to use one-letter names. I have no clue what t means.
And remove redundant parentheses
| @pytest.mark.parametrize( | ||
| "char_func_getter, expected", | ||
| [ | ||
| (lambda d: d.computation_strategy.query_method("mean", d)(None), 2.0), | ||
| (lambda d: d.computation_strategy.query_method("var", d)(None), 2.25), | ||
| (lambda d: d.computation_strategy.query_method("skewness", d)(None), 0.0), | ||
| (lambda d: d.computation_strategy.query_method("excess_kurtosis", d)(None), 0.0), | ||
| ], | ||
| ) |
Same as the comment above, though I do know what d means ;) Still, use a name like dist or distribution.
And don't pass None as an argument. You already have it as the default.
| @pytest.mark.parametrize( | ||
| "parametrization_name, params", | ||
| [ | ||
| ("meanVar", {"mu": 2.0, "sigma": 1.5}), | ||
| ("meanPrec", {"mu": 2.0, "tau": 0.25}), | ||
| ("exponential", {"a": -0.222222, "b": 0.888889}), | ||
| ], | ||
| ) | ||
| def test_parametrization_conversions(self, parametrization_name, params): | ||
| """Test conversions between different parameterizations.""" | ||
| dist = self.normal_family(parametrization_name=parametrization_name, **params) | ||
| base_params = cast( | ||
| NormalMeanVarParametrization, self.normal_family.to_base(dist.parameters) | ||
| ) | ||
| tolerance = self.CALCULATION_PRECISION | ||
|
|
||
| def test_moments(self): | ||
| """Test moment calculations.""" | ||
| dist = self.normal_family(mu=2.0, sigma=1.5) | ||
| if parametrization_name == "meanPrec": | ||
| expected_sigma = math.sqrt(1 / params["tau"]) | ||
|
|
||
| mean_func = dist.computation_strategy.query_method("mean", dist) | ||
| var_func = dist.computation_strategy.query_method("var", dist) | ||
| skew_func = dist.computation_strategy.query_method("skewness", dist) | ||
| kurt_func = dist.computation_strategy.query_method("excess_kurtosis", dist) | ||
| elif parametrization_name == "exponential": | ||
| expected_mu = -params["b"] / (2 * params["a"]) | ||
| expected_sigma = math.sqrt(-1 / (2 * params["a"])) | ||
|
|
If you need an if-else on the parameters in parametrized tests, then you're doing something wrong. Move the expected mu and sigma into the parameters. Also, don't create a useless dist object; do something like this:
    self.normal_family.to_base(
        self.normal_family.get_parametrization(parametrization_name)(**params)
    )
And use self.CALCULATION_PRECISION directly; don't alias it via a tolerance variable.
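A sketch of what "move the expected values into the parameters" can look like, using only the standard library. The two conversion helpers are stand-ins that repeat the formulas from the diff, and `CALCULATION_PRECISION` is given an assumed value here; in the real suite each tuple in `CASES` would be one `pytest.mark.parametrize` case and the conversion would go through `self.normal_family.to_base(...)`.

```python
import math

CALCULATION_PRECISION = 1e-10  # assumed value, for illustration only


def mean_prec_to_base(mu: float, tau: float) -> tuple[float, float]:
    """Stand-in for the meanPrec -> meanVar conversion from the diff."""
    return mu, math.sqrt(1.0 / tau)


def exponential_to_base(a: float, b: float) -> tuple[float, float]:
    """Stand-in for the exponential -> meanVar conversion from the diff."""
    return -b / (2.0 * a), math.sqrt(-1.0 / (2.0 * a))


# Each case carries its own expected (mu, sigma): no if-else in the test body.
CASES = [
    (mean_prec_to_base, {"mu": 2.0, "tau": 0.25}, 2.0, 2.0),
    (exponential_to_base, {"a": -2.0 / 9.0, "b": 8.0 / 9.0}, 2.0, 1.5),
]


def check_conversion(convert, params, expected_mu, expected_sigma):
    mu, sigma = convert(**params)
    assert abs(mu - expected_mu) < CALCULATION_PRECISION
    assert abs(sigma - expected_sigma) < CALCULATION_PRECISION


for case in CASES:
    check_conversion(*case)
```

Note that the meanPrec case with tau = 0.25 expects sigma = 2.0, not 1.5, which is exactly why the expectation belongs next to the inputs rather than in a branch.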