Skip to content

Commit 13d09c4

Browse files
author
TheodorDM
committed
refactor: improve normal family configuration with standards compliance and readability increase
1 parent 62f39f1 commit 13d09c4

File tree

4 files changed

+170
-157
lines changed

4 files changed

+170
-157
lines changed

src/pysatl_core/families/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
__license__ = "SPDX-License-Identifier: MIT"
1313

1414

15-
from pysatl_core.families.config import configure_family_register
15+
from pysatl_core.families.configuration import configure_family_register
1616
from pysatl_core.families.distribution import ParametricFamilyDistribution
1717
from pysatl_core.families.parametric_family import ParametricFamily
1818
from pysatl_core.families.parametrizations import (

src/pysatl_core/families/config.py renamed to src/pysatl_core/families/configuration.py

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,20 @@
1616

1717
from __future__ import annotations
1818

19+
__author__ = "Fedor Myznikov"
20+
__copyright__ = "Copyright (c) 2025 PySATL project"
21+
__license__ = "SPDX-License-Identifier: MIT"
22+
1923
import math
2024
from dataclasses import dataclass
21-
from typing import Any, cast
25+
from functools import lru_cache
26+
from typing import TYPE_CHECKING, Any, cast
2227

2328
import numpy as np
2429
import numpy.typing as npt
2530
from scipy.special import erf, erfinv
2631

27-
from pysatl_core.distributions import DefaultSamplingUnivariateStrategy
32+
from pysatl_core.distributions.strategies import DefaultSamplingUnivariateStrategy
2833
from pysatl_core.families.parametric_family import ParametricFamily
2934
from pysatl_core.families.parametrizations import (
3035
Parametrization,
@@ -34,20 +39,8 @@
3439
from pysatl_core.families.registry import ParametricFamilyRegister
3540
from pysatl_core.types import UnivariateContinuous
3641

37-
__author__ = "Fedor Myznikov"
38-
__copyright__ = "Copyright (c) 2025 PySATL project"
39-
__license__ = "SPDX-License-Identifier: MIT"
40-
41-
42-
def configure_family_register() -> None:
43-
"""
44-
Configure and register all distribution families in the global registry.
45-
46-
This function initializes all parametric families with their respective
47-
parameterizations, characteristics, and sampling strategies. It should be
48-
called during application startup to make distributions available.
49-
"""
50-
_configure_normal_family()
42+
if TYPE_CHECKING:
43+
from typing import Any
5144

5245

5346
PDF = "pdf"
@@ -61,8 +54,26 @@ def configure_family_register() -> None:
6154
EXKURT = "excess_kurtosis"
6255

6356

57+
@lru_cache(maxsize=1)
58+
def configure_family_register() -> ParametricFamilyRegister:
59+
"""
60+
Configure and register all distribution families in the global registry.
61+
62+
This function initializes all parametric families with their respective
63+
parameterizations, characteristics, and sampling strategies. It should be
64+
called during application startup to make distributions available.
65+
66+
Returns
67+
-------
68+
ParametricFamilyRegister
69+
The global registry of parametric families.
70+
"""
71+
_configure_normal_family()
72+
return ParametricFamilyRegister()
73+
74+
6475
@dataclass
65-
class MeanVarParametrization(Parametrization):
76+
class NormalMeanVarParametrization(Parametrization):
6677
"""
6778
Mean-variance parametrization of normal distribution.
6879
@@ -84,7 +95,7 @@ def check_sigma_positive(self) -> bool:
8495

8596

8697
@dataclass
87-
class MeanPrecParametrization(Parametrization):
98+
class NormalMeanPrecParametrization(Parametrization):
8899
"""
89100
Mean-precision parametrization of normal distribution.
90101
@@ -114,11 +125,11 @@ def transform_to_base_parametrization(self) -> Parametrization:
114125
Mean-variance parametrization instance
115126
"""
116127
sigma = math.sqrt(1 / self.tau)
117-
return MeanVarParametrization(mu=self.mu, sigma=sigma)
128+
return NormalMeanVarParametrization(mu=self.mu, sigma=sigma)
118129

119130

120131
@dataclass
121-
class ExpParametrization(Parametrization):
132+
class NormalExpParametrization(Parametrization):
122133
"""
123134
Exponential family parametrization of normal distribution.
124135
Uses the form: y = exp(a*x² + b*x + c)
@@ -161,7 +172,7 @@ def transform_to_base_parametrization(self) -> Parametrization:
161172
"""
162173
mu = -self.b / (2 * self.a)
163174
sigma = math.sqrt(-1 / (2 * self.a))
164-
return MeanVarParametrization(mu=mu, sigma=sigma)
175+
return NormalMeanVarParametrization(mu=mu, sigma=sigma)
165176

166177

167178
def _configure_normal_family() -> None:
@@ -199,7 +210,7 @@ def normal_pdf(
199210
npt.NDArray[np.float64]
200211
Probability density values at points x
201212
"""
202-
parameters = cast(MeanVarParametrization, parameters)
213+
parameters = cast(NormalMeanVarParametrization, parameters)
203214

204215
sigma = parameters.sigma
205216
mu = parameters.mu
@@ -229,7 +240,7 @@ def normal_cdf(
229240
npt.NDArray[np.float64]
230241
Probabilities P(X ≤ x) for each point x
231242
"""
232-
parameters = cast(MeanVarParametrization, parameters)
243+
parameters = cast(NormalMeanVarParametrization, parameters)
233244

234245
z = (x - parameters.mu) / (parameters.sigma * np.sqrt(2))
235246
return cast(npt.NDArray[np.float64], 0.5 * (1 + erf(z)))
@@ -262,7 +273,7 @@ def normal_ppf(
262273
if np.any((p < 0) | (p > 1)):
263274
raise ValueError("Probability must be in [0, 1]")
264275

265-
parameters = cast(MeanVarParametrization, parameters)
276+
parameters = cast(NormalMeanVarParametrization, parameters)
266277

267278
return cast(
268279
npt.NDArray[np.float64],
@@ -289,31 +300,31 @@ def normal_char_func(
289300
npt.NDArray[np.complex128]
290301
Characteristic function values at points x
291302
"""
292-
parameters = cast(MeanVarParametrization, parameters)
303+
parameters = cast(NormalMeanVarParametrization, parameters)
293304

294305
sigma = parameters.sigma
295306
mu = parameters.mu
296307
return cast(npt.NDArray[np.complex128], np.exp(1j * mu * x - 0.5 * (sigma * x) ** 2))
297308

298-
def mean_func(parameters: Parametrization, __: Any = None) -> float:
309+
def mean_func(parameters: Parametrization, _: Any = None) -> float:
299310
"""Mean of normal distribution."""
300-
parameters = cast(MeanVarParametrization, parameters)
311+
parameters = cast(NormalMeanVarParametrization, parameters)
301312
return parameters.mu
302313

303-
def var_func(parameters: Parametrization, __: Any = None) -> float:
314+
def var_func(parameters: Parametrization, _: Any = None) -> float:
304315
"""Variance of normal distribution."""
305-
parameters = cast(MeanVarParametrization, parameters)
316+
parameters = cast(NormalMeanVarParametrization, parameters)
306317
return parameters.sigma**2
307318

308-
def skew_func(_: Parametrization, __: Any = None) -> int:
319+
def skew_func(_1: Parametrization, _2: Any = None) -> int:
309320
"""Skewness of normal distribution (always 0)."""
310321
return 0
311322

312-
def raw_kurt_func(_: Parametrization, __: Any = None) -> int:
323+
def raw_kurt_func(_1: Parametrization, _2: Any = None) -> int:
313324
"""Raw kurtosis of normal distribution (always 3)."""
314325
return 3
315326

316-
def ex_kurt_func(_: Parametrization, __: Any) -> int:
327+
def ex_kurt_func(_1: Parametrization, _2: Any = None) -> int:
317328
"""Excess kurtosis of normal distribution (always 0)."""
318329
return 0
319330

@@ -336,8 +347,8 @@ def ex_kurt_func(_: Parametrization, __: Any) -> int:
336347
)
337348
Normal.__doc__ = NORMAL_DOC
338349

339-
parametrization(family=Normal, name="meanVar")(MeanVarParametrization)
340-
parametrization(family=Normal, name="meanPrec")(MeanPrecParametrization)
341-
parametrization(family=Normal, name="exponential")(ExpParametrization)
350+
parametrization(family=Normal, name="meanVar")(NormalMeanVarParametrization)
351+
parametrization(family=Normal, name="meanPrec")(NormalMeanPrecParametrization)
352+
parametrization(family=Normal, name="exponential")(NormalExpParametrization)
342353

343354
ParametricFamilyRegister.register(Normal)

src/pysatl_core/families/registry.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from __future__ import annotations
1010

11-
__author__ = "Leonid Elkin, Mikhail, Mikhailov"
11+
__author__ = "Leonid Elkin, Mikhail, Mikhailov, Fedor Myznikov"
1212
__copyright__ = "Copyright (c) 2025 PySATL project"
1313
__license__ = "SPDX-License-Identifier: MIT"
1414

@@ -93,7 +93,26 @@ def register(cls, family: ParametricFamily) -> None:
9393
raise ValueError(f"Family {family.name} already found in register")
9494
self._registered_families[family.name] = family
9595

96+
@classmethod
97+
def clear(cls) -> None:
98+
"""
99+
Clear the registry (for testing purposes).
100+
101+
This method removes all registered families and resets the singleton instance.
102+
It should only be used in tests.
103+
"""
104+
if cls._instance is not None:
105+
cls._instance._registered_families.clear()
106+
cls._instance = None
107+
96108

97109
def _reset_families_register_for_tests() -> None:
98110
"""Reset the cached distribution type register (test helper)."""
99-
ParametricFamilyRegister._instance = None
111+
ParametricFamilyRegister.clear()
112+
113+
try:
114+
from pysatl_core.families.configuration import configure_family_register
115+
116+
configure_family_register.cache_clear()
117+
except ImportError:
118+
pass

0 commit comments

Comments
 (0)