1616
1717from __future__ import annotations
1818
19+ __author__ = "Fedor Myznikov"
20+ __copyright__ = "Copyright (c) 2025 PySATL project"
21+ __license__ = "SPDX-License-Identifier: MIT"
22+
1923import math
2024from dataclasses import dataclass
21- from typing import Any , cast
25+ from functools import lru_cache
26+ from typing import TYPE_CHECKING , Any , cast
2227
2328import numpy as np
2429import numpy .typing as npt
2530from scipy .special import erf , erfinv
2631
27- from pysatl_core .distributions import DefaultSamplingUnivariateStrategy
32+ from pysatl_core .distributions . strategies import DefaultSamplingUnivariateStrategy
2833from pysatl_core .families .parametric_family import ParametricFamily
2934from pysatl_core .families .parametrizations import (
3035 Parametrization ,
3439from pysatl_core .families .registry import ParametricFamilyRegister
3540from pysatl_core .types import UnivariateContinuous
3641
37- __author__ = "Fedor Myznikov"
38- __copyright__ = "Copyright (c) 2025 PySATL project"
39- __license__ = "SPDX-License-Identifier: MIT"
40-
41-
42- def configure_family_register () -> None :
43- """
44- Configure and register all distribution families in the global registry.
45-
46- This function initializes all parametric families with their respective
47- parameterizations, characteristics, and sampling strategies. It should be
48- called during application startup to make distributions available.
49- """
50- _configure_normal_family ()
42+ if TYPE_CHECKING :
43+ from typing import Any
5144
5245
5346PDF = "pdf"
@@ -61,8 +54,26 @@ def configure_family_register() -> None:
6154EXKURT = "excess_kurtosis"
6255
6356
57+ @lru_cache (maxsize = 1 )
58+ def configure_family_register () -> ParametricFamilyRegister :
59+ """
60+ Configure and register all distribution families in the global registry.
61+
62+ This function initializes all parametric families with their respective
63+ parameterizations, characteristics, and sampling strategies. It should be
64+ called during application startup to make distributions available.
65+
66+ Returns
67+ -------
68+ ParametricFamilyRegister
69+ The global registry of parametric families.
70+ """
71+ _configure_normal_family ()
72+ return ParametricFamilyRegister ()
73+
74+
6475@dataclass
65- class MeanVarParametrization (Parametrization ):
76+ class NormalMeanVarParametrization (Parametrization ):
6677 """
6778 Mean-variance parametrization of normal distribution.
6879
@@ -84,7 +95,7 @@ def check_sigma_positive(self) -> bool:
8495
8596
8697@dataclass
87- class MeanPrecParametrization (Parametrization ):
98+ class NormalMeanPrecParametrization (Parametrization ):
8899 """
89100 Mean-precision parametrization of normal distribution.
90101
@@ -114,11 +125,11 @@ def transform_to_base_parametrization(self) -> Parametrization:
114125 Mean-variance parametrization instance
115126 """
116127 sigma = math .sqrt (1 / self .tau )
117- return MeanVarParametrization (mu = self .mu , sigma = sigma )
128+ return NormalMeanVarParametrization (mu = self .mu , sigma = sigma )
118129
119130
120131@dataclass
121- class ExpParametrization (Parametrization ):
132+ class NormalExpParametrization (Parametrization ):
122133 """
123134 Exponential family parametrization of normal distribution.
124135 Uses the form: y = exp(a*x² + b*x + c)
@@ -161,7 +172,7 @@ def transform_to_base_parametrization(self) -> Parametrization:
161172 """
162173 mu = - self .b / (2 * self .a )
163174 sigma = math .sqrt (- 1 / (2 * self .a ))
164- return MeanVarParametrization (mu = mu , sigma = sigma )
175+ return NormalMeanVarParametrization (mu = mu , sigma = sigma )
165176
166177
167178def _configure_normal_family () -> None :
@@ -199,7 +210,7 @@ def normal_pdf(
199210 npt.NDArray[np.float64]
200211 Probability density values at points x
201212 """
202- parameters = cast (MeanVarParametrization , parameters )
213+ parameters = cast (NormalMeanVarParametrization , parameters )
203214
204215 sigma = parameters .sigma
205216 mu = parameters .mu
@@ -229,7 +240,7 @@ def normal_cdf(
229240 npt.NDArray[np.float64]
230241 Probabilities P(X ≤ x) for each point x
231242 """
232- parameters = cast (MeanVarParametrization , parameters )
243+ parameters = cast (NormalMeanVarParametrization , parameters )
233244
234245 z = (x - parameters .mu ) / (parameters .sigma * np .sqrt (2 ))
235246 return cast (npt .NDArray [np .float64 ], 0.5 * (1 + erf (z )))
@@ -262,7 +273,7 @@ def normal_ppf(
262273 if np .any ((p < 0 ) | (p > 1 )):
263274 raise ValueError ("Probability must be in [0, 1]" )
264275
265- parameters = cast (MeanVarParametrization , parameters )
276+ parameters = cast (NormalMeanVarParametrization , parameters )
266277
267278 return cast (
268279 npt .NDArray [np .float64 ],
@@ -289,31 +300,31 @@ def normal_char_func(
289300 npt.NDArray[np.complex128]
290301 Characteristic function values at points x
291302 """
292- parameters = cast (MeanVarParametrization , parameters )
303+ parameters = cast (NormalMeanVarParametrization , parameters )
293304
294305 sigma = parameters .sigma
295306 mu = parameters .mu
296307 return cast (npt .NDArray [np .complex128 ], np .exp (1j * mu * x - 0.5 * (sigma * x ) ** 2 ))
297308
298- def mean_func (parameters : Parametrization , __ : Any = None ) -> float :
309+ def mean_func (parameters : Parametrization , _ : Any = None ) -> float :
299310 """Mean of normal distribution."""
300- parameters = cast (MeanVarParametrization , parameters )
311+ parameters = cast (NormalMeanVarParametrization , parameters )
301312 return parameters .mu
302313
303- def var_func (parameters : Parametrization , __ : Any = None ) -> float :
314+ def var_func (parameters : Parametrization , _ : Any = None ) -> float :
304315 """Variance of normal distribution."""
305- parameters = cast (MeanVarParametrization , parameters )
316+ parameters = cast (NormalMeanVarParametrization , parameters )
306317 return parameters .sigma ** 2
307318
308- def skew_func (_ : Parametrization , __ : Any = None ) -> int :
319+ def skew_func (_1 : Parametrization , _2 : Any = None ) -> int :
309320 """Skewness of normal distribution (always 0)."""
310321 return 0
311322
312- def raw_kurt_func (_ : Parametrization , __ : Any = None ) -> int :
323+ def raw_kurt_func (_1 : Parametrization , _2 : Any = None ) -> int :
313324 """Raw kurtosis of normal distribution (always 3)."""
314325 return 3
315326
316- def ex_kurt_func (_ : Parametrization , __ : Any ) -> int :
327+ def ex_kurt_func (_1 : Parametrization , _2 : Any = None ) -> int :
317328 """Excess kurtosis of normal distribution (always 0)."""
318329 return 0
319330
@@ -336,8 +347,8 @@ def ex_kurt_func(_: Parametrization, __: Any) -> int:
336347 )
337348 Normal .__doc__ = NORMAL_DOC
338349
339- parametrization (family = Normal , name = "meanVar" )(MeanVarParametrization )
340- parametrization (family = Normal , name = "meanPrec" )(MeanPrecParametrization )
341- parametrization (family = Normal , name = "exponential" )(ExpParametrization )
350+ parametrization (family = Normal , name = "meanVar" )(NormalMeanVarParametrization )
351+ parametrization (family = Normal , name = "meanPrec" )(NormalMeanPrecParametrization )
352+ parametrization (family = Normal , name = "exponential" )(NormalExpParametrization )
342353
343354 ParametricFamilyRegister .register (Normal )
0 commit comments