Skip to content

Commit fc27d30

Browse files
committed
chore(tests): use dynamic tolerances based on np.finfo for pdf, lpdf of MixtureModel
1 parent 35df16f commit fc27d30

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

rework_tests/unit/core/test_mixture.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,10 @@ def test_init_with_specified_weights(self, exp_components: tuple[Exponential, Ex
6868
model = MixtureModel(components=exp_components, weights=weights, dtype=dtype)
6969
expected_n_components = 2
7070
assert model.n_components == expected_n_components
71-
if dtype == np.float64:
72-
np.testing.assert_allclose(model.weights, weights)
73-
np.testing.assert_allclose(np.exp(model.log_weights), weights, atol=1e-9)
71+
72+
atol = np.finfo(dtype).eps
73+
np.testing.assert_allclose(model.weights, weights, rtol=atol)
74+
np.testing.assert_allclose(np.exp(model.log_weights), weights, rtol=atol)
7475

7576
@pytest.mark.parametrize(
7677
"invalid_weights, error_msg",
@@ -134,11 +135,11 @@ def test_init_casts_component_dtypes(self, dtype):
134135
def test_init_does_not_recreate_components_with_correct_dtype(self, dtype):
135136
"""Tests that components with the correct dtype are not recreated."""
136137

137-
comp_f32 = Exponential(loc=0.0, rate=1.0, dtype=dtype)
138+
comp = Exponential(loc=0.0, rate=1.0, dtype=dtype)
138139

139-
original_id = id(comp_f32)
140+
original_id = id(comp)
140141

141-
mixture = MixtureModel(components=[comp_f32], dtype=dtype)
142+
mixture = MixtureModel(components=[comp], dtype=dtype)
142143

143144
assert id(mixture.components[0]) == original_id
144145

@@ -304,24 +305,26 @@ def test_pdf_calculation(self, mixture_model: MixtureModel, X):
304305
expected_pdf = w1 * c1.pdf(X) + w2 * c2.pdf(X)
305306
calculated_pdf = mixture_model.pdf(X)
306307

307-
assert isinstance(calculated_pdf, np.ndarray)
308308
assert calculated_pdf.dtype == dtype
309-
if dtype == np.float64:
310-
np.testing.assert_allclose(calculated_pdf, expected_pdf)
309+
if not np.isscalar(X):
310+
assert isinstance(calculated_pdf, np.ndarray)
311+
312+
np.testing.assert_allclose(calculated_pdf, expected_pdf, rtol=np.finfo(dtype).eps)
311313

312314
@pytest.mark.parametrize("X", [1.5, [1.5], np.array([1.0, 1.5, 6.0])])
313315
def test_lpdf_calculation(self, mixture_model: MixtureModel, X):
314316
"""Tests the LPDF calculation against the definition."""
315317

316318
dtype = mixture_model.dtype
319+
c1, c2 = mixture_model.components
320+
w1, w2 = mixture_model.weights
317321

318-
expected_lpdf = np.log(mixture_model.pdf(X))
322+
expected_lpdf = np.log(w1 * c1.pdf(X) + w2 * c2.pdf(X))
319323
calculated_lpdf = mixture_model.lpdf(X)
320324

321325
assert isinstance(calculated_lpdf, np.ndarray)
322326
assert calculated_lpdf.dtype == dtype
323-
if dtype == np.float64:
324-
np.testing.assert_allclose(calculated_lpdf, expected_lpdf)
327+
np.testing.assert_allclose(calculated_lpdf, expected_lpdf, rtol=np.finfo(dtype).eps)
325328

326329
def test_loglikelihood_calculation(self, mixture_model: MixtureModel):
327330
"""Tests that loglikelihood is the sum of LPDF values."""
@@ -353,7 +356,11 @@ def test_generate_returns_correct_size(self, mixture_model: MixtureModel):
353356
def test_generate_with_size_zero(self, mixture_model):
354357
"""Tests that generating with size = 0 returns an empty array."""
355358

356-
assert len(mixture_model.generate(0)) == 0
359+
dtype = mixture_model.dtype
360+
361+
samples = mixture_model.generate(0)
362+
assert len(samples) == 0
363+
assert samples.dtype == dtype
357364

358365
@pytest.mark.parametrize("size", [-1, -10])
359366
def test_generate_with_negative_size(self, mixture_model: MixtureModel, size: int):

0 commit comments

Comments
 (0)