@@ -110,7 +110,7 @@ def _data_setter(self, X) -> None:
110110 pm .set_data ({"X" : X })
111111
112112 def fit (self , X , y , coords : Optional [Dict [str , Any ]] = None ) -> None :
113- """Draw samples fromposterior , prior predictive, and posterior predictive
113+ """Draw samples from posterior , prior predictive, and posterior predictive
114114 distributions, placing them in the model's idata attribute.
115115 """
116116
@@ -380,12 +380,19 @@ def fit(self, X, Z, y, t, coords, priors):
380380 """Draw samples from posterior, prior predictive, and posterior predictive
381381 distributions.
382382 """
383+
384+ # Ensure random_seed is used in sample_prior_predictive() and
385+ # sample_posterior_predictive() if provided in sample_kwargs.
386+ random_seed = self .sample_kwargs .get ("random_seed" , None )
387+
383388 self .build_model (X , Z , y , t , coords , priors )
384389 with self :
385390 self .idata = pm .sample (** self .sample_kwargs )
386- self .idata .extend (pm .sample_prior_predictive ())
391+ self .idata .extend (pm .sample_prior_predictive (random_seed = random_seed ))
387392 self .idata .extend (
388- pm .sample_posterior_predictive (self .idata , progressbar = False )
393+ pm .sample_posterior_predictive (
394+ self .idata , progressbar = False , random_seed = random_seed
395+ )
389396 )
390397 return self .idata
391398
0 commit comments