3737 CauchyRVType ,
3838 HalfCauchyRV ,
3939 HalfCauchyRVType ,
40+ BetaRV ,
41+ BetaRVType ,
42+ BinomialRV ,
43+ BinomialRVType ,
44+ PoissonRV ,
45+ PoissonRVType ,
46+ DirichletRV ,
47+ DirichletRVType ,
48+ BernoulliRV ,
49+ BernoulliRVType ,
50+ BetaBinomialRV ,
51+ BetaBinomialRVType ,
52+ CategoricalRV ,
53+ CategoricalRVType ,
54+ MultinomialRV ,
55+ MultinomialRVType ,
4056)
4157from .opt import FunctionGraph
4258from .ops import RandomVariable
@@ -197,6 +213,110 @@ def _convert_rv_to_dist_HalfCauchy(op, rv):
197213 return pm .HalfCauchy , params
198214
199215
216+ @convert_dist_to_rv .register (pm .Beta , object )
217+ def convert_dist_to_rv_Beta (dist , rng ):
218+ size = dist .shape .astype (int )[BetaRV .ndim_supp :]
219+ res = BetaRV (dist .alpha , dist .beta , size = size , rng = rng )
220+ return res
221+
222+
223+ @_convert_rv_to_dist .register (BetaRVType , Apply )
224+ def _convert_rv_to_dist_Beta (op , rv ):
225+ params = {"alpha" : rv .inputs [0 ], "beta" : rv .inputs [1 ]}
226+ return pm .Beta , params
227+
228+
229+ @convert_dist_to_rv .register (pm .Binomial , object )
230+ def convert_dist_to_rv_Binomial (dist , rng ):
231+ size = dist .shape .astype (int )[BinomialRV .ndim_supp :]
232+ res = BinomialRV (dist .n , dist .p , size = size , rng = rng )
233+ return res
234+
235+
236+ @_convert_rv_to_dist .register (BinomialRVType , Apply )
237+ def _convert_rv_to_dist_Binomial (op , rv ):
238+ params = {"n" : rv .inputs [0 ], "p" : rv .inputs [1 ]}
239+ return pm .Binomial , params
240+
241+
242+ @convert_dist_to_rv .register (pm .Poisson , object )
243+ def convert_dist_to_rv_Poisson (dist , rng ):
244+ size = dist .shape .astype (int )[PoissonRV .ndim_supp :]
245+ res = PoissonRV (dist .mu , size = size , rng = rng )
246+ return res
247+
248+
249+ @_convert_rv_to_dist .register (PoissonRVType , Apply )
250+ def _convert_rv_to_dist_Poisson (op , rv ):
251+ params = {"mu" : rv .inputs [0 ]}
252+ return pm .Poisson , params
253+
254+
255+ @convert_dist_to_rv .register (pm .Dirichlet , object )
256+ def convert_dist_to_rv_Dirichlet (dist , rng ):
257+ size = dist .shape .astype (int )[DirichletRV .ndim_supp :]
258+ res = DirichletRV (dist .a , size = size , rng = rng )
259+ return res
260+
261+
262+ @_convert_rv_to_dist .register (DirichletRVType , Apply )
263+ def _convert_rv_to_dist_Dirichlet (op , rv ):
264+ params = {"a" : rv .inputs [0 ]}
265+ return pm .Dirichlet , params
266+
267+
268+ @convert_dist_to_rv .register (pm .Bernoulli , object )
269+ def convert_dist_to_rv_Bernoulli (dist , rng ):
270+ size = dist .shape .astype (int )[BernoulliRV .ndim_supp :]
271+ res = BernoulliRV (dist .p , size = size , rng = rng )
272+ return res
273+
274+
275+ @_convert_rv_to_dist .register (BernoulliRVType , Apply )
276+ def _convert_rv_to_dist_Bernoulli (op , rv ):
277+ params = {"p" : rv .inputs [0 ]}
278+ return pm .Bernoulli , params
279+
280+
281+ @convert_dist_to_rv .register (pm .BetaBinomial , object )
282+ def convert_dist_to_rv_BetaBinomial (dist , rng ):
283+ size = dist .shape .astype (int )[BetaBinomialRV .ndim_supp :]
284+ res = BetaBinomialRV (dist .n , dist .alpha , dist .beta , size = size , rng = rng )
285+ return res
286+
287+
288+ @_convert_rv_to_dist .register (BetaBinomialRVType , Apply )
289+ def _convert_rv_to_dist_BetaBinomial (op , rv ):
290+ params = {"n" : rv .inputs [0 ], "alpha" : rv .inputs [1 ], "beta" : rv .inputs [2 ]}
291+ return pm .BetaBinomial , params
292+
293+
294+ @convert_dist_to_rv .register (pm .Categorical , object )
295+ def convert_dist_to_rv_Categorical (dist , rng ):
296+ size = dist .shape .astype (int )[CategoricalRV .ndim_supp :]
297+ res = CategoricalRV (dist .p , size = size , rng = rng )
298+ return res
299+
300+
301+ @_convert_rv_to_dist .register (CategoricalRVType , Apply )
302+ def _convert_rv_to_dist_Categorical (op , rv ):
303+ params = {"p" : rv .inputs [0 ]}
304+ return pm .Categorical , params
305+
306+
307+ @convert_dist_to_rv .register (pm .Multinomial , object )
308+ def convert_dist_to_rv_Multinomial (dist , rng ):
309+ size = dist .shape .astype (int )[MultinomialRV .ndim_supp :]
310+ res = MultinomialRV (dist .n , dist .p , size = size , rng = rng )
311+ return res
312+
313+
314+ @_convert_rv_to_dist .register (MultinomialRVType , Apply )
315+ def _convert_rv_to_dist_Multinomial (op , rv ):
316+ params = {"n" : rv .inputs [0 ], "p" : rv .inputs [1 ]}
317+ return pm .Multinomial , params
318+
319+
200320# TODO: More RV conversions!
201321
202322
@@ -207,9 +327,17 @@ def pymc3_var_to_rv(pm_var, rand_state=None):
207327 new_rv .name = pm_var .name
208328
209329 if isinstance (pm_var , pm .model .ObservedRV ):
210- obs = tt .as_tensor_variable (pm_var .observations )
330+ obs = pm_var .observations
331+ # For some reason, the observations can be float when the RV's dtype is
332+ # not.
333+ if obs .dtype != pm_var .dtype :
334+ obs = obs .astype (pm_var .dtype )
335+
336+ obs = tt .as_tensor_variable (obs )
337+
211338 if getattr (obs , "cached" , False ):
212339 obs = obs .clone ()
340+
213341 new_rv = observed (obs , new_rv )
214342
215343 # Let's attempt to fix the PyMC3 broadcastable dims "oracle" issue,
0 commit comments