@@ -124,11 +124,18 @@ def _save_input_params(self, idata):
124124 def output_var (self ):
125125 return "output"
126126
127- def _data_setter (self , x : pd .Series , y : pd .Series = None ):
127+ def _data_setter (self , X : pd .Series | np .ndarray , y : pd .Series | np .ndarray = None ):
128+
128129 with self .model :
129- pm .set_data ({"x" : x .values })
130+
131+ X = X .values if isinstance (X , pd .Series ) else X .ravel ()
132+
133+ pm .set_data ({"x" : X })
134+
130135 if y is not None :
131- pm .set_data ({"y_data" : y .values })
136+ y = y .values if isinstance (y , pd .Series ) else y .ravel ()
137+
138+ pm .set_data ({"y_data" : y })
132139
133140 @property
134141 def _serializable_model_config (self ):
@@ -177,8 +184,8 @@ def test_save_load(fitted_model_instance):
177184 assert fitted_model_instance .id == test_builder2 .id
178185 x_pred = np .random .uniform (low = 0 , high = 1 , size = 100 )
179186 prediction_data = pd .DataFrame ({"input" : x_pred })
180- pred1 = fitted_model_instance .predict (prediction_data ["input" ])
181- pred2 = test_builder2 .predict (prediction_data ["input" ])
187+ pred1 = fitted_model_instance .predict (prediction_data [[ "input" ] ])
188+ pred2 = test_builder2 .predict (prediction_data [[ "input" ] ])
182189 assert pred1 .shape == pred2 .shape
183190 temp .close ()
184191
@@ -205,7 +212,7 @@ def test_empty_sampler_config_fit(toy_X, toy_y):
205212
206213def test_fit (fitted_model_instance ):
207214 prediction_data = pd .DataFrame ({"input" : np .random .uniform (low = 0 , high = 1 , size = 100 )})
208- pred = fitted_model_instance .predict (prediction_data ["input" ])
215+ pred = fitted_model_instance .predict (prediction_data [[ "input" ] ])
209216 post_pred = fitted_model_instance .sample_posterior_predictive (
210217 prediction_data ["input" ], extend_idata = True , combined = True
211218 )
@@ -223,7 +230,7 @@ def test_fit_no_y(toy_X):
223230def test_predict (fitted_model_instance ):
224231 x_pred = np .random .uniform (low = 0 , high = 1 , size = 100 )
225232 prediction_data = pd .DataFrame ({"input" : x_pred })
226- pred = fitted_model_instance .predict (prediction_data ["input" ])
233+ pred = fitted_model_instance .predict (prediction_data [[ "input" ] ])
227234 # Perform elementwise comparison using numpy
228235 assert isinstance (pred , np .ndarray )
229236 assert len (pred ) > 0
@@ -256,13 +263,12 @@ def test_sample_xxx_extend_idata_param(fitted_model_instance, group, extend_idat
256263
257264 prediction_data = pd .DataFrame ({"input" : x_pred })
258265 if group == "prior_predictive" :
259- prediction_method = fitted_model_instance .sample_prior_predictive
266+ pred = fitted_model_instance .sample_prior_predictive ( prediction_data [ "input" ], combined = False , extend_idata = extend_idata )
260267 else : # group == "posterior_predictive":
261- prediction_method = fitted_model_instance .sample_posterior_predictive
262-
263- pred = prediction_method (prediction_data ["input" ], combined = False , extend_idata = extend_idata )
268+ pred = fitted_model_instance .sample_posterior_predictive (prediction_data ["input" ], combined = False , predictions = False , extend_idata = extend_idata )
264269
265270 pred_unstacked = pred [output_var ].values
271+
266272 idata_now = fitted_model_instance .idata [group ][output_var ].values
267273
268274 if extend_idata :
@@ -320,9 +326,40 @@ def test_predict_respects_predictions_flag(fitted_model_instance, predictions):
320326
321327 # Run prediction with predictions=True or False
322328 fitted_model_instance .predict (
323- prediction_data ["input" ],
329+ X_pred = prediction_data [["input" ]],
330+ extend_idata = True ,
331+ predictions = predictions ,
332+ )
333+
334+ pp_after = fitted_model_instance .idata .posterior_predictive [output_var ].values
335+
336+ # Check predictions group presence
337+ if predictions :
338+ assert "predictions" in fitted_model_instance .idata .groups ()
339+ # Posterior predictive should remain unchanged
340+ np .testing .assert_array_equal (pp_before , pp_after )
341+ else :
342+ assert "predictions" not in fitted_model_instance .idata .groups ()
343+ # Posterior predictive should be updated
344+ assert not np .array_equal (pp_before , pp_after )
345+
346+ @pytest .mark .parametrize ("predictions" , [True , False ])
347+ def test_predict_posterior_respects_predictions_flag (fitted_model_instance , predictions ):
348+ x_pred = np .random .uniform (0 , 1 , 100 )
349+ prediction_data = pd .DataFrame ({"input" : x_pred })
350+ output_var = fitted_model_instance .output_var
351+
352+ # Snapshot the original posterior_predictive values
353+ pp_before = fitted_model_instance .idata .posterior_predictive [output_var ].values .copy ()
354+
355+ # Ensure 'predictions' group is not present initially
356+ assert "predictions" not in fitted_model_instance .idata .groups ()
357+
358+ # Run prediction with predictions=True or False
359+ fitted_model_instance .predict_posterior (
360+ X_pred = prediction_data [["input" ]],
324361 extend_idata = True ,
325- combined = False ,
362+ combined = True ,
326363 predictions = predictions ,
327364 )
328365
@@ -336,4 +373,4 @@ def test_predict_respects_predictions_flag(fitted_model_instance, predictions):
336373 else :
337374 assert "predictions" not in fitted_model_instance .idata .groups ()
338375 # Posterior predictive should be updated
339- np .testing . assert_array_not_equal (pp_before , pp_after )
376+ assert not np .array_equal (pp_before , pp_after )
0 commit comments