@@ -72,13 +72,17 @@ class ModelBuilder(pm.Model):
7272 ... }
7373 ... )
7474 >>> model.fit(X, y)
75- Inference...
75+ <BLANKLINE>
76+ <BLANKLINE>
77+ Inference data...
7678 >>> X_new = rng.normal(loc=0, scale=1, size=(20,2))
7779 >>> model.predict(X_new)
78- Inference...
79- >>> model.score(X, y) # doctest: +NUMBER
80- r2 0.3
81- r2_std 0.0
80+ <BLANKLINE>
81+ Inference data...
82+ >>> model.score(X, y)
83+ <BLANKLINE>
84+ r2 0.390344
85+ r2_std 0.081135
8286 dtype: float64
8387 """
8488
@@ -112,10 +116,7 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
112116
113117 # Ensure random_seed is used in sample_prior_predictive() and
114118 # sample_posterior_predictive() if provided in sample_kwargs.
115- if "random_seed" in self .sample_kwargs :
116- random_seed = self .sample_kwargs ["random_seed" ]
117- else :
118- random_seed = None
119+ random_seed = self .sample_kwargs .get ("random_seed" , None )
119120
120121 self .build_model (X , y , coords )
121122 with self :
@@ -137,10 +138,17 @@ def predict(self, X):
137138
138139 """
139140
141+ # Ensure random_seed is used in sample_prior_predictive() and
142+ # sample_posterior_predictive() if provided in sample_kwargs.
143+ random_seed = self .sample_kwargs .get ("random_seed" , None )
144+
140145 self ._data_setter (X )
141146 with self : # sample with new input data
142147 post_pred = pm .sample_posterior_predictive (
143- self .idata , var_names = ["y_hat" , "mu" ], progressbar = False
148+ self .idata ,
149+ var_names = ["y_hat" , "mu" ],
150+ progressbar = False ,
151+ random_seed = random_seed ,
144152 )
145153 return post_pred
146154
@@ -193,7 +201,9 @@ class WeightedSumFitter(ModelBuilder):
193201 >>> y = np.asarray(sc['actual']).reshape((sc.shape[0], 1))
194202 >>> wsf = WeightedSumFitter(sample_kwargs={"progressbar": False})
195203 >>> wsf.fit(X,y)
196- Inference ...
204+ <BLANKLINE>
205+ <BLANKLINE>
206+ Inference data...
197207 """ # noqa: W605
198208
199209 def build_model (self , X , y , coords ):
@@ -249,7 +259,9 @@ class LinearRegression(ModelBuilder):
249259 ... 'obs_indx': np.arange(rd.shape[0])
250260 ... },
251261 ... )
252- Inference...
262+ <BLANKLINE>
263+ <BLANKLINE>
264+ Inference data...
253265 """ # noqa: W605
254266
255267 def build_model (self , X , y , coords ):
@@ -301,6 +313,8 @@ class InstrumentalVariableRegression(ModelBuilder):
301313 ... "eta": 2,
302314 ... "lkj_sd": 2,
303315 ... })
316+ <BLANKLINE>
317+ <BLANKLINE>
304318 Inference data...
305319 """
306320
0 commit comments