@@ -313,7 +313,8 @@ def test_id():
313313
314314
315315@pytest .mark .parametrize ("predictions" , [True , False ])
316- def test_predict_respects_predictions_flag (fitted_model_instance , predictions ):
316+ @pytest .mark .parametrize ("predict_method" , ["predict" , "predict_posterior" ])
317+ def test_predict_method_respects_predictions_flag (fitted_model_instance , predictions , predict_method ):
317318 x_pred = np .random .uniform (0 , 1 , 100 )
318319 prediction_data = pd .DataFrame ({"input" : x_pred })
319320 output_var = fitted_model_instance .output_var
@@ -325,43 +326,18 @@ def test_predict_respects_predictions_flag(fitted_model_instance, predictions):
325326 assert "predictions" not in fitted_model_instance .idata .groups ()
326327
327328 # Run prediction with predictions=True or False
328- fitted_model_instance .predict (
329- X_pred = prediction_data [["input" ]],
330- extend_idata = True ,
331- predictions = predictions ,
332- )
333-
334- pp_after = fitted_model_instance .idata .posterior_predictive [output_var ].values
335-
336- # Check predictions group presence
337- if predictions :
338- assert "predictions" in fitted_model_instance .idata .groups ()
339- # Posterior predictive should remain unchanged
340- np .testing .assert_array_equal (pp_before , pp_after )
341- else :
342- assert "predictions" not in fitted_model_instance .idata .groups ()
343- # Posterior predictive should be updated
344- assert not np .array_equal (pp_before , pp_after )
345-
346- @pytest .mark .parametrize ("predictions" , [True , False ])
347- def test_predict_posterior_respects_predictions_flag (fitted_model_instance , predictions ):
348- x_pred = np .random .uniform (0 , 1 , 100 )
349- prediction_data = pd .DataFrame ({"input" : x_pred })
350- output_var = fitted_model_instance .output_var
351-
352- # Snapshot the original posterior_predictive values
353- pp_before = fitted_model_instance .idata .posterior_predictive [output_var ].values .copy ()
354-
355- # Ensure 'predictions' group is not present initially
356- assert "predictions" not in fitted_model_instance .idata .groups ()
357-
358- # Run prediction with predictions=True or False
359- fitted_model_instance .predict_posterior (
360- X_pred = prediction_data [["input" ]],
361- extend_idata = True ,
362- combined = True ,
363- predictions = predictions ,
364- )
329+ if predict_method == "predict" :
330+ fitted_model_instance .predict (
331+ X_pred = prediction_data [["input" ]],
332+ extend_idata = True ,
333+ predictions = predictions ,
334+ )
335+ else :# predict_method == "predict_posterior":
336+ fitted_model_instance .predict_posterior (
337+ X_pred = prediction_data [["input" ]],
338+ extend_idata = True ,
339+ predictions = predictions ,
340+ )
365341
366342 pp_after = fitted_model_instance .idata .posterior_predictive [output_var ].values
367343
@@ -374,3 +350,4 @@ def test_predict_posterior_respects_predictions_flag(fitted_model_instance, pred
374350 assert "predictions" not in fitted_model_instance .idata .groups ()
375351 # Posterior predictive should be updated
376352 assert not np .array_equal (pp_before , pp_after )
353+
0 commit comments