Skip to content

Commit df89b8a

Browse files
authored
Update nv_generator.py
1 parent 16dabfa commit df89b8a

File tree

1 file changed

+3
-22
lines changed

1 file changed

+3
-22
lines changed

src/generators/nv_generator.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
class NVGenerator(AbstractGenerator):
1010

1111
@staticmethod
12-
def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
12+
def generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
1313
"""Generate a sample of given size. Classical form of NVM
1414
1515
Args:
@@ -27,25 +27,6 @@ def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
2727
raise ValueError("Mixture must be NormalMeanMixtures")
2828
mixing_values = mixture.params.distribution.rvs(size=size)
2929
normal_values = scipy.stats.norm.rvs(size=size)
30+
if mixture.mixture_form == "canonical":
31+
return mixture.params.alpha + (mixing_values ** 0.5) * normal_values
3032
return mixture.params.alpha + mixture.params.gamma * (mixing_values**0.5) * normal_values
31-
32-
@staticmethod
33-
def canonical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray:
34-
"""Generate a sample of given size. Canonical form of NVM
35-
36-
Args:
37-
mixture: Normal Variance Mixtures
38-
size: length of sample
39-
40-
Returns: sample of given size
41-
42-
Raises:
43-
ValueError: If mixture type is not Normal Variance Mixtures
44-
45-
"""
46-
47-
if not isinstance(mixture, NormalVarianceMixtures):
48-
raise ValueError("Mixture must be NormalMeanMixtures")
49-
mixing_values = mixture.params.distribution.rvs(size=size)
50-
normal_values = scipy.stats.norm.rvs(size=size)
51-
return mixture.params.alpha + (mixing_values**0.5) * normal_values

0 commit comments

Comments
 (0)