1+ # Copyright 2024 The PyMC Labs Developers
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
114"""
215Defines generic PyMC ModelBuilder class and subclasses for
316
@@ -41,8 +54,8 @@ class ModelBuilder(pm.Model):
4154 >>> class MyToyModel(ModelBuilder):
4255 ... def build_model(self, X, y, coords):
4356 ... with self:
44- ... X_ = pm.MutableData (name="X", value=X)
45- ... y_ = pm.MutableData (name="y", value=y)
57+ ... X_ = pm.Data (name="X", value=X)
58+ ... y_ = pm.Data (name="y", value=y)
4659 ... beta = pm.Normal("beta", mu=0, sigma=1, shape=X_.shape[1])
4760 ... sigma = pm.HalfNormal("sigma", sigma=1)
4861 ... mu = pm.Deterministic("mu", pm.math.dot(X_, beta))
@@ -59,13 +72,17 @@ class ModelBuilder(pm.Model):
5972 ... }
6073 ... )
6174 >>> model.fit(X, y)
62- Inference...
75+ <BLANKLINE>
76+ <BLANKLINE>
77+ Inference data...
6378 >>> X_new = rng.normal(loc=0, scale=1, size=(20,2))
6479 >>> model.predict(X_new)
65- Inference...
66- >>> model.score(X, y) # doctest: +NUMBER
67- r2 0.3
68- r2_std 0.0
80+ <BLANKLINE>
81+ Inference data...
82+ >>> model.score(X, y)
83+ <BLANKLINE>
84+ r2 0.390344
85+ r2_std 0.081135
6986 dtype: float64
7087 """
7188
@@ -99,10 +116,7 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
99116
100117 # Ensure random_seed is used in sample_prior_predictive() and
101118 # sample_posterior_predictive() if provided in sample_kwargs.
102- if "random_seed" in self .sample_kwargs :
103- random_seed = self .sample_kwargs ["random_seed" ]
104- else :
105- random_seed = None
119+ random_seed = self .sample_kwargs .get ("random_seed" , None )
106120
107121 self .build_model (X , y , coords )
108122 with self :
@@ -124,10 +138,17 @@ def predict(self, X):
124138
125139 """
126140
141+ # Ensure random_seed is used in sample_prior_predictive() and
142+ # sample_posterior_predictive() if provided in sample_kwargs.
143+ random_seed = self .sample_kwargs .get ("random_seed" , None )
144+
127145 self ._data_setter (X )
128146 with self : # sample with new input data
129147 post_pred = pm .sample_posterior_predictive (
130- self .idata , var_names = ["y_hat" , "mu" ], progressbar = False
148+ self .idata ,
149+ var_names = ["y_hat" , "mu" ],
150+ progressbar = False ,
151+ random_seed = random_seed ,
131152 )
132153 return post_pred
133154
@@ -180,7 +201,9 @@ class WeightedSumFitter(ModelBuilder):
180201 >>> y = np.asarray(sc['actual']).reshape((sc.shape[0], 1))
181202 >>> wsf = WeightedSumFitter(sample_kwargs={"progressbar": False})
182203 >>> wsf.fit(X,y)
183- Inference ...
204+ <BLANKLINE>
205+ <BLANKLINE>
206+ Inference data...
184207 """ # noqa: W605
185208
186209 def build_model (self , X , y , coords ):
@@ -190,8 +213,8 @@ def build_model(self, X, y, coords):
190213 with self :
191214 self .add_coords (coords )
192215 n_predictors = X .shape [1 ]
193- X = pm .MutableData ("X" , X , dims = ["obs_ind" , "coeffs" ])
194- y = pm .MutableData ("y" , y [:, 0 ], dims = "obs_ind" )
216+ X = pm .Data ("X" , X , dims = ["obs_ind" , "coeffs" ])
217+ y = pm .Data ("y" , y [:, 0 ], dims = "obs_ind" )
195218 # TODO: There we should allow user-specified priors here
196219 beta = pm .Dirichlet ("beta" , a = np .ones (n_predictors ), dims = "coeffs" )
197220 # beta = pm.Dirichlet(
@@ -236,7 +259,9 @@ class LinearRegression(ModelBuilder):
236259 ... 'obs_indx': np.arange(rd.shape[0])
237260 ... },
238261 ... )
239- Inference...
262+ <BLANKLINE>
263+ <BLANKLINE>
264+ Inference data...
240265 """ # noqa: W605
241266
242267 def build_model (self , X , y , coords ):
@@ -245,8 +270,8 @@ def build_model(self, X, y, coords):
245270 """
246271 with self :
247272 self .add_coords (coords )
248- X = pm .MutableData ("X" , X , dims = ["obs_ind" , "coeffs" ])
249- y = pm .MutableData ("y" , y [:, 0 ], dims = "obs_ind" )
273+ X = pm .Data ("X" , X , dims = ["obs_ind" , "coeffs" ])
274+ y = pm .Data ("y" , y [:, 0 ], dims = "obs_ind" )
250275 beta = pm .Normal ("beta" , 0 , 50 , dims = "coeffs" )
251276 sigma = pm .HalfNormal ("sigma" , 1 )
252277 mu = pm .Deterministic ("mu" , pm .math .dot (X , beta ), dims = "obs_ind" )
@@ -288,6 +313,8 @@ class InstrumentalVariableRegression(ModelBuilder):
288313 ... "eta": 2,
289314 ... "lkj_sd": 2,
290315 ... })
316+ <BLANKLINE>
317+ <BLANKLINE>
291318 Inference data...
292319 """
293320
0 commit comments