@@ -68,9 +68,10 @@ def test_init_with_specified_weights(self, exp_components: tuple[Exponential, Ex
6868 model = MixtureModel (components = exp_components , weights = weights , dtype = dtype )
6969 expected_n_components = 2
7070 assert model .n_components == expected_n_components
71- if dtype == np .float64 :
72- np .testing .assert_allclose (model .weights , weights )
73- np .testing .assert_allclose (np .exp (model .log_weights ), weights , atol = 1e-9 )
71+
72+ atol = np .finfo (dtype ).eps
73+ np .testing .assert_allclose (model .weights , weights , rtol = atol )
74+ np .testing .assert_allclose (np .exp (model .log_weights ), weights , rtol = atol )
7475
7576 @pytest .mark .parametrize (
7677 "invalid_weights, error_msg" ,
@@ -134,11 +135,11 @@ def test_init_casts_component_dtypes(self, dtype):
134135 def test_init_does_not_recreate_components_with_correct_dtype (self , dtype ):
135136 """Tests that components with the correct dtype are not recreated."""
136137
137- comp_f32 = Exponential (loc = 0.0 , rate = 1.0 , dtype = dtype )
138+ comp = Exponential (loc = 0.0 , rate = 1.0 , dtype = dtype )
138139
139- original_id = id (comp_f32 )
140+ original_id = id (comp )
140141
141- mixture = MixtureModel (components = [comp_f32 ], dtype = dtype )
142+ mixture = MixtureModel (components = [comp ], dtype = dtype )
142143
143144 assert id (mixture .components [0 ]) == original_id
144145
@@ -304,24 +305,26 @@ def test_pdf_calculation(self, mixture_model: MixtureModel, X):
304305 expected_pdf = w1 * c1 .pdf (X ) + w2 * c2 .pdf (X )
305306 calculated_pdf = mixture_model .pdf (X )
306307
307- assert isinstance (calculated_pdf , np .ndarray )
308308 assert calculated_pdf .dtype == dtype
309- if dtype == np .float64 :
310- np .testing .assert_allclose (calculated_pdf , expected_pdf )
309+ if not np .isscalar (X ):
310+ assert isinstance (calculated_pdf , np .ndarray )
311+
312+ np .testing .assert_allclose (calculated_pdf , expected_pdf , rtol = np .finfo (dtype ).eps )
311313
312314 @pytest .mark .parametrize ("X" , [1.5 , [1.5 ], np .array ([1.0 , 1.5 , 6.0 ])])
313315 def test_lpdf_calculation (self , mixture_model : MixtureModel , X ):
314316 """Tests the LPDF calculation against the definition."""
315317
316318 dtype = mixture_model .dtype
319+ c1 , c2 = mixture_model .components
320+ w1 , w2 = mixture_model .weights
317321
318- expected_lpdf = np .log (mixture_model .pdf (X ))
322+ expected_lpdf = np .log (w1 * c1 . pdf ( X ) + w2 * c2 .pdf (X ))
319323 calculated_lpdf = mixture_model .lpdf (X )
320324
321325 assert isinstance (calculated_lpdf , np .ndarray )
322326 assert calculated_lpdf .dtype == dtype
323- if dtype == np .float64 :
324- np .testing .assert_allclose (calculated_lpdf , expected_lpdf )
327+ np .testing .assert_allclose (calculated_lpdf , expected_lpdf , rtol = np .finfo (dtype ).eps )
325328
326329 def test_loglikelihood_calculation (self , mixture_model : MixtureModel ):
327330 """Tests that loglikelihood is the sum of LPDF values."""
@@ -353,7 +356,11 @@ def test_generate_returns_correct_size(self, mixture_model: MixtureModel):
353356 def test_generate_with_size_zero (self , mixture_model ):
354357 """Tests that generating with size = 0 returns an empty array."""
355358
356- assert len (mixture_model .generate (0 )) == 0
359+ dtype = mixture_model .dtype
360+
361+ samples = mixture_model .generate (0 )
362+ assert len (samples ) == 0
363+ assert samples .dtype == dtype
357364
358365 @pytest .mark .parametrize ("size" , [- 1 , - 10 ])
359366 def test_generate_with_negative_size (self , mixture_model : MixtureModel , size : int ):
0 commit comments