Skip to content

Commit e7a4c59

Browse files
Merge pull request #99 from brandonwillard/add-beta-binom-pymc3
Add basic Theano and PyMC3 distributions
2 parents 278f625 + 1e8790c commit e7a4c59

File tree

8 files changed

+270
-42
lines changed

8 files changed

+270
-42
lines changed

setup.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,15 @@ def get_long_description():
3232
author_email=AUTHOR_EMAIL,
3333
url=URL,
3434
install_requires=[
35-
"scipy>=1.2.0",
35+
"scipy>=1.4.0",
3636
"Theano>=1.0.4",
37-
"tf-nightly-2.0-preview==2.0.0.dev20191002",
38-
"tf-nightly==2.1.0.dev20191003",
39-
"tf-estimator-nightly==2.0.0.dev2019100301",
40-
"tensorflow-estimator-2.0-preview==1.14.0.dev2019090801",
41-
"tfp-nightly==0.9.0.dev20191003",
37+
"tf-estimator-nightly==2.1.0.dev2020012309",
38+
"tf-nightly==2.2.0.dev20200201",
39+
"tfp-nightly==0.10.0.dev20200201",
4240
"multipledispatch>=0.6.0",
4341
"logical-unification>=0.4.3",
4442
"miniKanren>=1.0.1",
45-
"etuples>=0.3.1",
43+
"etuples>=0.3.2",
4644
"cons>=0.4.0",
4745
"toolz>=0.9.0",
4846
"cachetools",

symbolic_pymc/meta.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
from functools import partial
1010
from collections import OrderedDict
1111
from contextlib import contextmanager
12-
from collections.abc import Iterator, Mapping
12+
from collections.abc import Iterator, Mapping, Sequence
1313

1414
from unification import isvar, Var
1515

16+
from etuples.core import ExpressionTuple
17+
1618
from .utils import HashableNDArray
1719

1820
from multipledispatch import dispatch
@@ -74,20 +76,20 @@ def metatize(obj):
7476
return _metatize(obj)
7577

7678

77-
@dispatch((type(None), types.FunctionType, partial, str, dict))
79+
@dispatch((type(None), types.FunctionType, partial, str, Mapping))
7880
def _metatize(obj):
7981
return obj
8082

8183

82-
@_metatize.register((set, tuple))
84+
@_metatize.register((frozenset, tuple, ExpressionTuple))
8385
@cached(metatize_cache)
84-
def _metatize_set_tuple(obj):
86+
def _metatize_hashable_Sequence(obj):
8587
"""Convert elements of an iterable to meta objects."""
8688
return type(obj)([metatize(o) for o in obj])
8789

8890

89-
@_metatize.register(list)
90-
def _metatize_list(obj):
91+
@_metatize.register(Sequence)
92+
def _metatize_Sequence(obj):
9193
"""Convert elements of an iterable to meta objects."""
9294
return type(obj)([metatize(o) for o in obj])
9395

symbolic_pymc/relations/theano/conjugates.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,24 +32,24 @@ def _create_normal_normal_goals():
3232
#
3333
# Create the pattern/form of the prior normal distribution
3434
#
35-
beta_name_lv = var("beta_name")
36-
beta_size_lv = var("beta_size")
37-
beta_rng_lv = var("beta_rng")
38-
a_lv = var("a")
39-
R_lv = var("R")
35+
beta_name_lv = var()
36+
beta_size_lv = var()
37+
beta_rng_lv = var()
38+
a_lv = var()
39+
R_lv = var()
4040
beta_prior_mt = mt.MvNormalRV(a_lv, R_lv, size=beta_size_lv, rng=beta_rng_lv, name=beta_name_lv)
4141

42-
y_name_lv = var("y_name")
43-
y_size_lv = var("y_size")
44-
y_rng_lv = var("y_rng")
45-
F_t_lv = var("f")
46-
V_lv = var("V")
42+
y_name_lv = var()
43+
y_size_lv = var()
44+
y_rng_lv = var()
45+
F_t_lv = var()
46+
V_lv = var()
4747
E_y_mt = mt.dot(F_t_lv, beta_prior_mt)
4848
Y_mt = mt.MvNormalRV(E_y_mt, V_lv, size=y_size_lv, rng=y_rng_lv, name=y_name_lv)
4949

5050
# The variable specifying the fixed sample value of the random variable
5151
# given by `Y_mt`
52-
obs_sample_mt = var("obs_sample")
52+
obs_sample_mt = var()
5353

5454
Y_obs_mt = mt.observed(obs_sample_mt, Y_mt)
5555

@@ -80,21 +80,21 @@ def _create_normal_normal_goals():
8080
def _create_normal_wishart_goals(): # pragma: no cover
8181
"""TODO."""
8282
# Create the pattern/form of the prior normal distribution
83-
Sigma_name_lv = var("Sigma_name")
84-
Sigma_size_lv = var("Sigma_size")
85-
Sigma_rng_lv = var("Sigma_rng")
86-
V_lv = var("V")
87-
n_lv = var("n")
83+
Sigma_name_lv = var()
84+
Sigma_size_lv = var()
85+
Sigma_rng_lv = var()
86+
V_lv = var()
87+
n_lv = var()
8888
Sigma_prior_mt = mt.WishartRV(V_lv, n_lv, Sigma_size_lv, Sigma_rng_lv, name=Sigma_name_lv)
8989

90-
y_name_lv = var("y_name")
91-
y_size_lv = var("y_size")
92-
y_rng_lv = var("y_rng")
93-
V_lv = var("V")
94-
f_mt = var("f")
90+
y_name_lv = var()
91+
y_size_lv = var()
92+
y_rng_lv = var()
93+
V_lv = var()
94+
f_mt = var()
9595
Y_mt = mt.MvNormalRV(f_mt, V_lv, y_size_lv, y_rng_lv, name=y_name_lv)
9696

97-
y_mt = var("y")
97+
y_mt = var()
9898
Y_obs_mt = mt.observed(y_mt, Y_mt)
9999

100100
n_post_mt = etuple(mt.add, n_lv, etuple(mt.Shape, Y_obs_mt))

symbolic_pymc/tensorflow/meta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ class TFlowMetaOp(TFlowMetaSymbol):
434434
@classmethod
435435
def _metatize(cls, obj):
436436
"""Reformat inputs to match the OpDef."""
437-
new_input = obj._reconstruct_sequence_inputs(obj.op_def, obj.inputs, obj.node_def.attr)
437+
new_input = ops._reconstruct_sequence_inputs(obj.op_def, obj.inputs, obj.node_def.attr)
438438
new_args = [
439439
getattr(obj, s) if s != "inputs" else new_input for s in getattr(cls, "__props__", [])
440440
]

symbolic_pymc/theano/pymc3.py

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,22 @@
3737
CauchyRVType,
3838
HalfCauchyRV,
3939
HalfCauchyRVType,
40+
BetaRV,
41+
BetaRVType,
42+
BinomialRV,
43+
BinomialRVType,
44+
PoissonRV,
45+
PoissonRVType,
46+
DirichletRV,
47+
DirichletRVType,
48+
BernoulliRV,
49+
BernoulliRVType,
50+
BetaBinomialRV,
51+
BetaBinomialRVType,
52+
CategoricalRV,
53+
CategoricalRVType,
54+
MultinomialRV,
55+
MultinomialRVType,
4056
)
4157
from .opt import FunctionGraph
4258
from .ops import RandomVariable
@@ -197,6 +213,110 @@ def _convert_rv_to_dist_HalfCauchy(op, rv):
197213
return pm.HalfCauchy, params
198214

199215

216+
@convert_dist_to_rv.register(pm.Beta, object)
217+
def convert_dist_to_rv_Beta(dist, rng):
218+
size = dist.shape.astype(int)[BetaRV.ndim_supp :]
219+
res = BetaRV(dist.alpha, dist.beta, size=size, rng=rng)
220+
return res
221+
222+
223+
@_convert_rv_to_dist.register(BetaRVType, Apply)
224+
def _convert_rv_to_dist_Beta(op, rv):
225+
params = {"alpha": rv.inputs[0], "beta": rv.inputs[1]}
226+
return pm.Beta, params
227+
228+
229+
@convert_dist_to_rv.register(pm.Binomial, object)
230+
def convert_dist_to_rv_Binomial(dist, rng):
231+
size = dist.shape.astype(int)[BinomialRV.ndim_supp :]
232+
res = BinomialRV(dist.n, dist.p, size=size, rng=rng)
233+
return res
234+
235+
236+
@_convert_rv_to_dist.register(BinomialRVType, Apply)
237+
def _convert_rv_to_dist_Binomial(op, rv):
238+
params = {"n": rv.inputs[0], "p": rv.inputs[1]}
239+
return pm.Binomial, params
240+
241+
242+
@convert_dist_to_rv.register(pm.Poisson, object)
243+
def convert_dist_to_rv_Poisson(dist, rng):
244+
size = dist.shape.astype(int)[PoissonRV.ndim_supp :]
245+
res = PoissonRV(dist.mu, size=size, rng=rng)
246+
return res
247+
248+
249+
@_convert_rv_to_dist.register(PoissonRVType, Apply)
250+
def _convert_rv_to_dist_Poisson(op, rv):
251+
params = {"mu": rv.inputs[0]}
252+
return pm.Poisson, params
253+
254+
255+
@convert_dist_to_rv.register(pm.Dirichlet, object)
256+
def convert_dist_to_rv_Dirichlet(dist, rng):
257+
size = dist.shape.astype(int)[DirichletRV.ndim_supp :]
258+
res = DirichletRV(dist.a, size=size, rng=rng)
259+
return res
260+
261+
262+
@_convert_rv_to_dist.register(DirichletRVType, Apply)
263+
def _convert_rv_to_dist_Dirichlet(op, rv):
264+
params = {"a": rv.inputs[0]}
265+
return pm.Dirichlet, params
266+
267+
268+
@convert_dist_to_rv.register(pm.Bernoulli, object)
269+
def convert_dist_to_rv_Bernoulli(dist, rng):
270+
size = dist.shape.astype(int)[BernoulliRV.ndim_supp :]
271+
res = BernoulliRV(dist.p, size=size, rng=rng)
272+
return res
273+
274+
275+
@_convert_rv_to_dist.register(BernoulliRVType, Apply)
276+
def _convert_rv_to_dist_Bernoulli(op, rv):
277+
params = {"p": rv.inputs[0]}
278+
return pm.Bernoulli, params
279+
280+
281+
@convert_dist_to_rv.register(pm.BetaBinomial, object)
282+
def convert_dist_to_rv_BetaBinomial(dist, rng):
283+
size = dist.shape.astype(int)[BetaBinomialRV.ndim_supp :]
284+
res = BetaBinomialRV(dist.n, dist.alpha, dist.beta, size=size, rng=rng)
285+
return res
286+
287+
288+
@_convert_rv_to_dist.register(BetaBinomialRVType, Apply)
289+
def _convert_rv_to_dist_BetaBinomial(op, rv):
290+
params = {"n": rv.inputs[0], "alpha": rv.inputs[1], "beta": rv.inputs[2]}
291+
return pm.BetaBinomial, params
292+
293+
294+
@convert_dist_to_rv.register(pm.Categorical, object)
295+
def convert_dist_to_rv_Categorical(dist, rng):
296+
size = dist.shape.astype(int)[CategoricalRV.ndim_supp :]
297+
res = CategoricalRV(dist.p, size=size, rng=rng)
298+
return res
299+
300+
301+
@_convert_rv_to_dist.register(CategoricalRVType, Apply)
302+
def _convert_rv_to_dist_Categorical(op, rv):
303+
params = {"p": rv.inputs[0]}
304+
return pm.Categorical, params
305+
306+
307+
@convert_dist_to_rv.register(pm.Multinomial, object)
308+
def convert_dist_to_rv_Multinomial(dist, rng):
309+
size = dist.shape.astype(int)[MultinomialRV.ndim_supp :]
310+
res = MultinomialRV(dist.n, dist.p, size=size, rng=rng)
311+
return res
312+
313+
314+
@_convert_rv_to_dist.register(MultinomialRVType, Apply)
315+
def _convert_rv_to_dist_Multinomial(op, rv):
316+
params = {"n": rv.inputs[0], "p": rv.inputs[1]}
317+
return pm.Multinomial, params
318+
319+
200320
# TODO: More RV conversions!
201321

202322

@@ -207,9 +327,17 @@ def pymc3_var_to_rv(pm_var, rand_state=None):
207327
new_rv.name = pm_var.name
208328

209329
if isinstance(pm_var, pm.model.ObservedRV):
210-
obs = tt.as_tensor_variable(pm_var.observations)
330+
obs = pm_var.observations
331+
# For some reason, the observations can be float when the RV's dtype is
332+
# not.
333+
if obs.dtype != pm_var.dtype:
334+
obs = obs.astype(pm_var.dtype)
335+
336+
obs = tt.as_tensor_variable(obs)
337+
211338
if getattr(obs, "cached", False):
212339
obs = obs.clone()
340+
213341
new_rv = observed(obs, new_rv)
214342

215343
# Let's attempt to fix the PyMC3 broadcastable dims "oracle" issue,

0 commit comments

Comments
 (0)