@@ -530,6 +530,7 @@ def predict(
530530 self ,
531531 X_pred : np .ndarray | pd .DataFrame | pd .Series ,
532532 extend_idata : bool = True ,
533+ predictions : bool = False ,
533534 ** kwargs ,
534535 ) -> np .ndarray :
535536 """
@@ -559,7 +560,7 @@ def predict(
559560 """
560561
561562 posterior_predictive_samples = self .sample_posterior_predictive (
562- X_pred , extend_idata , combined = False , ** kwargs
563+ X_pred , extend_idata , predictions , combined = False , ** kwargs
563564 )
564565
565566 if self .output_var not in posterior_predictive_samples :
@@ -624,7 +625,7 @@ def sample_prior_predictive(
624625
625626 return prior_predictive_samples
626627
627- def sample_posterior_predictive (self , X_pred , extend_idata , combined , ** kwargs ):
628+ def sample_posterior_predictive (self , X_pred , extend_idata , predictions , combined , ** kwargs ):
628629 """
629630 Sample from the model's posterior predictive distribution.
630631
@@ -646,12 +647,12 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs):
646647 self ._data_setter (X_pred )
647648
648649 with self .model : # sample with new input data
649- post_pred = pm .sample_posterior_predictive (self .idata , ** kwargs )
650+ post_pred = pm .sample_posterior_predictive (self .idata , predictions = predictions , ** kwargs )
650651 if extend_idata :
651652 self .idata .extend (post_pred , join = "right" )
652653
653- # Determine the correct group dynamically
654- group_name = "predictions" if kwargs . get ( " predictions" , False ) else "posterior_predictive"
654+ # Determine the correct group
655+ group_name = "predictions" if predictions else "posterior_predictive"
655656
656657 posterior_predictive_samples = az .extract (
657658 post_pred , group_name , combined = combined
@@ -703,6 +704,7 @@ def predict_posterior(
703704 X_pred : np .ndarray | pd .DataFrame | pd .Series ,
704705 extend_idata : bool = True ,
705706 combined : bool = True ,
707+ predictions : bool = False ,
706708 ** kwargs ,
707709 ) -> xr .DataArray :
708710 """
@@ -726,7 +728,7 @@ def predict_posterior(
726728
727729 X_pred = self ._validate_data (X_pred )
728730 posterior_predictive_samples = self .sample_posterior_predictive (
729- X_pred , extend_idata , combined , ** kwargs
731+ X_pred , extend_idata , predictions , combined , ** kwargs
730732 )
731733
732734 if self .output_var not in posterior_predictive_samples :
0 commit comments