Skip to content

Commit 278f625

Browse files
Merge pull request #98 from brandonwillard/updates-for-kanren-v1
Updates for miniKanren package v1
2 parents bf9f823 + 352dc7d commit 278f625

File tree

10 files changed

+251
-202
lines changed

10 files changed

+251
-202
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
-e ./
2+
sympy>=1.3
23
coveralls
34
pydocstyle>=3.0.0
45
pytest>=5.0.0

setup.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,15 @@ def get_long_description():
3636
"Theano>=1.0.4",
3737
"tf-nightly-2.0-preview==2.0.0.dev20191002",
3838
"tf-nightly==2.1.0.dev20191003",
39-
"tensorflow-estimator-2.0-preview>=1.14.0.dev2019090801",
39+
"tf-estimator-nightly==2.0.0.dev2019100301",
40+
"tensorflow-estimator-2.0-preview==1.14.0.dev2019090801",
4041
"tfp-nightly==0.9.0.dev20191003",
4142
"multipledispatch>=0.6.0",
42-
"logical-unification>=0.2.2",
43-
"miniKanren>=0.4.0",
44-
"cons>=0.1.3",
43+
"logical-unification>=0.4.3",
44+
"miniKanren>=1.0.1",
45+
"etuples>=0.3.1",
46+
"cons>=0.4.0",
4547
"toolz>=0.9.0",
46-
"sympy>=1.3",
4748
"cachetools",
4849
"pymc3>=3.6",
4950
"pymc4 @ git+https://github.com/pymc-devs/pymc4.git@master#egg=pymc4-0.0.1",
Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,3 @@
1-
from itertools import tee, chain
2-
from functools import reduce
3-
4-
from toolz import interleave
5-
6-
from kanren.core import goaleval
71
from kanren.facts import Relation
82

93
from unification import unify, reify, Var
@@ -34,58 +28,3 @@ def concat_goal(S):
3428
yield S
3529

3630
return concat_goal
37-
38-
39-
def ldisj_seq(goals):
40-
"""Produce a goal that returns the appended state stream from all successful goal arguments.
41-
42-
In other words, it behaves like logical disjunction/OR for goals.
43-
"""
44-
45-
def ldisj_seq_goal(S):
46-
nonlocal goals
47-
48-
goals, _goals = tee(goals)
49-
50-
yield from interleave(goaleval(g)(S) for g in _goals)
51-
52-
return ldisj_seq_goal
53-
54-
55-
def lconj_seq(goals):
56-
"""Produce a goal that returns the appended state stream in which all goals are necessarily successful.
57-
58-
In other words, it behaves like logical conjunction/AND for goals.
59-
"""
60-
61-
def lconj_seq_goal(S):
62-
nonlocal goals
63-
64-
goals, _goals = tee(goals)
65-
66-
g0 = next(iter(_goals), None)
67-
68-
if g0 is None:
69-
return
70-
71-
z0 = goaleval(g0)(S)
72-
73-
yield from reduce(lambda z, g: chain.from_iterable(map(goaleval(g), z)), _goals, z0)
74-
75-
return lconj_seq_goal
76-
77-
78-
def ldisj(*goals):
79-
return ldisj_seq(goals)
80-
81-
82-
def lconj(*goals):
83-
return lconj_seq(goals)
84-
85-
86-
def conde(*goals):
87-
return ldisj_seq(lconj_seq(g) for g in goals)
88-
89-
90-
lall = lconj
91-
lany = ldisj

symbolic_pymc/theano/printing.py

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,19 @@
1010

1111
from theano import gof
1212

13-
from sympy import Array as SympyArray
14-
from sympy.printing import latex as sympy_latex
13+
try:
14+
from sympy import Array as SympyArray
15+
from sympy.printing import latex as sympy_latex
16+
17+
def latex_print_array(data): # pragma: no cover
18+
return sympy_latex(SympyArray(data))
19+
20+
21+
except ImportError: # pragma: no cover
22+
23+
def latex_print_array(data):
24+
return data
25+
1526

1627
from .opt import FunctionGraph
1728
from .ops import RandomVariable
@@ -60,16 +71,16 @@ def process_param(self, idx, sform, pstate):
6071
The printer state.
6172
6273
"""
63-
return sform
74+
return sform # pragma: no cover
6475

6576
def process(self, output, pstate):
6677
if output in pstate.memo:
6778
return pstate.memo[output]
6879

6980
pprinter = pstate.pprinter
70-
node = output.owner
81+
node = getattr(output, "owner", None)
7182

72-
if node is None or not isinstance(node.op, RandomVariable):
83+
if node is None or not isinstance(node.op, RandomVariable): # pragma: no cover
7384
raise TypeError(
7485
"Function %s cannot represent a variable that is "
7586
"not the result of a RandomVariable operation" % self.name
@@ -78,7 +89,7 @@ def process(self, output, pstate):
7889
op_name = self.name or getattr(node.op, "print_name", None)
7990
op_name = op_name or getattr(node.op, "name", None)
8091

81-
if op_name is None:
92+
if op_name is None: # pragma: no cover
8293
raise ValueError(f"Could not find a name for {node.op}")
8394

8495
# Allow `Op`s to specify their ascii and LaTeX formats (in a tuple/list
@@ -144,7 +155,7 @@ def process(self, output, pstate):
144155

145156
class GenericSubtensorPrinter(object):
146157
def process(self, r, pstate):
147-
if r.owner is None:
158+
if getattr(r, "owner", None) is None: # pragma: no cover
148159
raise TypeError("Can only print Subtensor.")
149160

150161
output_latex = getattr(pstate, "latex", False)
@@ -161,13 +172,13 @@ def process(self, r, pstate):
161172
if isinstance(entry, slice):
162173
s_parts = [""] * 2
163174
if entry.start is not None:
164-
s_parts[0] = entry.start
175+
s_parts[0] = pstate.pprinter.process(inputs.pop())
165176

166177
if entry.stop is not None:
167-
s_parts[1] = entry.stop
178+
s_parts[1] = pstate.pprinter.process(inputs.pop())
168179

169180
if entry.step is not None:
170-
s_parts.append(entry.stop)
181+
s_parts.append(pstate.pprinter.process(inputs.pop()))
171182

172183
sidxs.append(":".join(s_parts))
173184
else:
@@ -215,16 +226,22 @@ def process(cls, output, pstate):
215226
using_latex = getattr(pstate, "latex", False)
216227
# Crude--but effective--means of stopping print-outs for large
217228
# arrays.
218-
constant = isinstance(output, tt.TensorConstant)
229+
constant = isinstance(output, (tt.TensorConstant, theano.scalar.basic.ScalarConstant))
219230
too_large = constant and (output.data.size > cls.max_line_width * cls.max_line_height)
220231

221232
if constant and not too_large:
222233
# Print constants that aren't too large
223234
if using_latex and output.ndim > 0:
224-
out_name = sympy_latex(SympyArray(output.data))
235+
out_name = latex_print_array(output.data)
225236
else:
226237
out_name = str(output.data)
227-
elif isinstance(output, tt.TensorVariable) or constant:
238+
elif (
239+
isinstance(
240+
output,
241+
(tt.TensorVariable, theano.scalar.basic.Scalar, theano.scalar.basic.ScalarVariable),
242+
)
243+
or constant
244+
):
228245
# Process name and shape
229246

230247
# Attempt to get the original variable, in case this is a cloned
@@ -238,7 +255,7 @@ def process(cls, output, pstate):
238255

239256
shape_strings = pstate.preamble_dict.setdefault("shape_strings", OrderedDict())
240257
shape_strings[output] = shape_info
241-
else:
258+
else: # pragma: no cover
242259
raise TypeError(f"Type {type(output)} not handled by variable printer")
243260

244261
pstate.memo[output] = out_name
@@ -268,7 +285,7 @@ def process_variable_name(cls, output, pstate):
268285
_ = [available_names.pop(v.name, None) for v in fgraph.variables]
269286
setattr(pstate, "available_names", available_names)
270287

271-
if output.name:
288+
if getattr(output, "name", None):
272289
# Observed an existing name; remove it.
273290
out_name = output.name
274291
available_names.pop(out_name, None)
@@ -524,11 +541,18 @@ def __call__(self, *args, latex_env="equation", latex_label=None):
524541

525542
# The order here is important!
526543
tt_pprint.printers.insert(
527-
0, (lambda pstate, r: isinstance(r, tt.Variable), VariableWithShapePrinter)
544+
0,
545+
(
546+
lambda pstate, r: isinstance(r, (theano.scalar.basic.Scalar, tt.Variable)),
547+
VariableWithShapePrinter,
548+
),
528549
)
529550
tt_pprint.printers.insert(
530551
0,
531-
(lambda pstate, r: r.owner and isinstance(r.owner.op, RandomVariable), RandomVariablePrinter()),
552+
(
553+
lambda pstate, r: getattr(r, "owner", None) and isinstance(r.owner.op, RandomVariable),
554+
RandomVariablePrinter(),
555+
),
532556
)
533557

534558

@@ -538,9 +562,9 @@ def process(self, output, pstate):
538562
return pstate.memo[output]
539563

540564
pprinter = pstate.pprinter
541-
node = output.owner
565+
node = getattr(output, "owner", None)
542566

543-
if node is None or not isinstance(node.op, Observed):
567+
if node is None or not isinstance(node.op, Observed): # pragma: no cover
544568
raise TypeError(f"Node Op is not of type `Observed`: {node.op}")
545569

546570
val = node.inputs[0]

symbolic_pymc/unify.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55

66
from kanren.term import arguments, operator
77

8-
from unification.more import unify
98
from unification.variable import Var
10-
from unification.core import _reify, _unify, reify
9+
from unification.core import _reify, _unify, reify, unify
1110

1211
from etuples import etuple
1312

tests/test_relations.py

Lines changed: 2 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from unification import var
22

3-
from kanren import eq, run
3+
from kanren import run
44

5-
from symbolic_pymc.relations import lconj, lconj_seq, ldisj, ldisj_seq, conde, concat
5+
from symbolic_pymc.relations import concat
66

77

88
def test_concat():
@@ -11,81 +11,3 @@ def test_concat():
1111
assert not run(0, q, concat("a", "b", "bc"))
1212
assert not run(0, q, concat(1, "b", "bc"))
1313
assert run(0, q, concat(q, "b", "bc")) == (q,)
14-
15-
16-
def test_lconj_basics():
17-
18-
res = list(lconj(eq(1, var("a")), eq(2, var("b")))({}))
19-
assert res == [{var("a"): 1, var("b"): 2}]
20-
21-
res = list(lconj(eq(1, var("a")))({}))
22-
assert res == [{var("a"): 1}]
23-
24-
res = list(lconj_seq([])({}))
25-
assert res == []
26-
27-
res = list(lconj(eq(1, var("a")), eq(2, var("a")))({}))
28-
assert res == []
29-
30-
res = list(lconj(eq(1, 2))({}))
31-
assert res == []
32-
33-
res = list(lconj(eq(1, 1))({}))
34-
assert res == [{}]
35-
36-
37-
def test_ldisj_basics():
38-
39-
res = list(ldisj(eq(1, var("a")))({}))
40-
assert res == [{var("a"): 1}]
41-
42-
res = list(ldisj(eq(1, 2))({}))
43-
assert res == []
44-
45-
res = list(ldisj(eq(1, 1))({}))
46-
assert res == [{}]
47-
48-
res = list(ldisj(eq(1, var("a")), eq(1, var("a")))({}))
49-
assert res == [{var("a"): 1}, {var("a"): 1}]
50-
51-
res = list(ldisj(eq(1, var("a")), eq(2, var("a")))({}))
52-
assert res == [{var("a"): 1}, {var("a"): 2}]
53-
54-
res = list(ldisj_seq([])({}))
55-
assert res == []
56-
57-
58-
def test_conde_basics():
59-
60-
res = list(conde([eq(1, var("a")), eq(2, var("b"))], [eq(1, var("b")), eq(2, var("a"))])({}))
61-
assert res == [{var("a"): 1, var("b"): 2}, {var("b"): 1, var("a"): 2}]
62-
63-
res = list(conde([eq(1, var("a")), eq(2, 1)], [eq(1, var("b")), eq(2, var("a"))])({}))
64-
assert res == [{var("b"): 1, var("a"): 2}]
65-
66-
res = list(
67-
conde(
68-
[eq(1, var("a")), conde([eq(11, var("aa"))], [eq(12, var("ab"))])],
69-
[
70-
eq(1, var("b")),
71-
conde([eq(111, var("ba")), eq(112, var("bb"))], [eq(121, var("bc"))]),
72-
],
73-
)({})
74-
)
75-
assert res == [
76-
{var("a"): 1, var("aa"): 11},
77-
{var("b"): 1, var("ba"): 111, var("bb"): 112},
78-
{var("a"): 1, var("ab"): 12},
79-
{var("b"): 1, var("bc"): 121},
80-
]
81-
82-
res = list(conde([eq(1, 2)], [eq(1, 1)])({}))
83-
assert res == [{}]
84-
85-
assert list(lconj(eq(1, 1))({})) == [{}]
86-
87-
res = list(lconj(conde([eq(1, 2)], [eq(1, 1)]))({}))
88-
assert res == [{}]
89-
90-
res = list(lconj(conde([eq(1, 2)], [eq(1, 1)]), conde([eq(1, 2)], [eq(1, 1)]))({}))
91-
assert res == [{}]

tests/theano/test_kanren.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_kanren():
7474
Y_mt = mt.MvNormalRV(E_y_mt, V_lv, y_size_lv, y_rng_lv, name=y_name_lv)
7575

7676
with variables(Y_mt):
77-
(res,) = run(0, Y_mt, (eq, Y_rv, Y_mt))
77+
(res,) = run(0, Y_mt, eq(Y_rv, Y_mt))
7878
assert res.reify() == Y_rv
7979

8080

tests/theano/test_opt.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from unification import var
55

66
from kanren import eq
7-
from kanren.core import lallgreedy
7+
from kanren.core import lall
88

99
from etuples import etuple, etuplize
1010

@@ -36,13 +36,11 @@ def test_kanren_opt():
3636
assert isinstance(fgraph.outputs[0].owner.op, tt.Dot)
3737

3838
def distributes(in_lv, out_lv):
39-
return (
40-
lallgreedy,
39+
return lall(
4140
# lhs == A * (x + b)
42-
(eq, etuple(mt.dot, var("A"), etuple(mt.add, var("x"), var("b"))), etuplize(in_lv)),
41+
eq(etuple(mt.dot, var("A"), etuple(mt.add, var("x"), var("b"))), etuplize(in_lv)),
4342
# rhs == A * x + A * b
44-
(
45-
eq,
43+
eq(
4644
etuple(
4745
mt.add, etuple(mt.dot, var("A"), var("x")), etuple(mt.dot, var("A"), var("b"))
4846
),

0 commit comments

Comments
 (0)