Skip to content

Commit b07fbda

Browse files
Introduce more efficient replacements for lall, lany and conde
1 parent b504958 commit b07fbda

File tree

4 files changed

+205
-15
lines changed

4 files changed

+205
-15
lines changed
Lines changed: 88 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1-
from unification import isvar
1+
from itertools import tee, chain
2+
from functools import reduce
3+
4+
from toolz import interleave
5+
6+
from unification import isvar, reify
27

38
from kanren import eq
4-
from kanren.core import lallgreedy
9+
from kanren.core import goaleval
510
from kanren.facts import Relation
611
from kanren.goals import goalify
712
from kanren.term import term, operator, arguments
@@ -20,13 +25,85 @@
2025
concat = goalify(lambda *args: "".join(args))
2126

2227

28+
def ldisj_seq(goals):
29+
"""Produce a goal that returns the appended state stream from all successful goal arguments.
30+
31+
In other words, it behaves like logical disjunction/OR for goals.
32+
"""
33+
34+
def ldisj_seq_goal(S):
35+
nonlocal goals
36+
37+
goals, _goals = tee(goals)
38+
39+
yield from interleave(goaleval(g)(S) for g in _goals)
40+
41+
return ldisj_seq_goal
42+
43+
44+
def lconj_seq(goals):
45+
"""Produce a goal that returns the appended state stream in which all goals are necessarily successful.
46+
47+
In other words, it behaves like logical conjunction/AND for goals.
48+
"""
49+
50+
def lconj_seq_goal(S):
51+
nonlocal goals
52+
53+
goals, _goals = tee(goals)
54+
55+
g0 = next(iter(_goals), None)
56+
57+
if g0 is None:
58+
return
59+
60+
z0 = goaleval(g0)(S)
61+
62+
yield from reduce(lambda z, g: chain.from_iterable(map(goaleval(g), z)), _goals, z0)
63+
64+
return lconj_seq_goal
65+
66+
67+
def ldisj(*goals):
68+
return ldisj_seq(goals)
69+
70+
71+
def lconj(*goals):
72+
return lconj_seq(goals)
73+
74+
75+
def conde(*goals):
76+
return ldisj_seq(lconj_seq(g) for g in goals)
77+
78+
79+
lall = lconj
80+
lany = ldisj
81+
82+
2383
def buildo(op, args, obj):
24-
if not isvar(obj):
25-
if not isvar(args):
26-
args = etuplize(args, shallow=True)
27-
oop, oargs = operator(obj), arguments(obj)
28-
return lallgreedy(eq(op, oop), eq(args, oargs))
29-
elif isvar(args) or isvar(op):
30-
return conso(op, args, obj)
31-
else:
32-
return eq(obj, term(op, args))
84+
"""Construct a goal that relates an object and its rand + rators decomposition.
85+
86+
This version uses etuples.
87+
88+
"""
89+
90+
def buildo_goal(S):
91+
nonlocal op, args, obj
92+
93+
op, args, obj = reify((op, args, obj), S)
94+
95+
if not isvar(obj):
96+
97+
if not isvar(args):
98+
args = etuplize(args, shallow=True)
99+
100+
oop, oargs = operator(obj), arguments(obj)
101+
102+
yield from lall(eq(op, oop), eq(args, oargs))(S)
103+
104+
elif isvar(args) or isvar(op):
105+
yield from conso(op, args, obj)(S)
106+
else:
107+
yield from eq(obj, term(op, args))(S)
108+
109+
return buildo_goal

symbolic_pymc/relations/graph.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77

88
from kanren import eq
99
from cons.core import ConsPair, ConsNull
10-
from kanren.core import conde, lall
1110
from kanren.goals import conso, fail
1211

12+
from . import conde, lall
13+
1314
from ..etuple import etuplize, ExpressionTuple
1415

1516

@@ -217,7 +218,7 @@ def graph_applyo(
217218
def preprocess_graph(x):
218219
return x
219220

220-
def _gapplyo(s):
221+
def graph_applyo_goal(s):
221222

222223
nonlocal in_graph, out_graph
223224

@@ -256,4 +257,4 @@ def _gapplyo(s):
256257
g = goaleval(g)
257258
yield from g(s)
258259

259-
return _gapplyo
260+
return graph_applyo_goal

tests/test_goals.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from unification import var
2+
3+
from kanren import eq # , run, lall
4+
# from kanren.core import goaleval
5+
6+
from symbolic_pymc.relations import lconj, lconj_seq, ldisj, ldisj_seq, conde
7+
8+
9+
def test_lconj_basics():
10+
11+
res = list(lconj(eq(1, var('a')), eq(2, var('b')))({}))
12+
assert res == [{var('a'): 1, var('b'): 2}]
13+
14+
res = list(lconj(eq(1, var('a')))({}))
15+
assert res == [{var('a'): 1}]
16+
17+
res = list(lconj_seq([])({}))
18+
assert res == []
19+
20+
res = list(lconj(eq(1, var('a')), eq(2, var('a')))({}))
21+
assert res == []
22+
23+
res = list(lconj(eq(1, 2))({}))
24+
assert res == []
25+
26+
res = list(lconj(eq(1, 1))({}))
27+
assert res == [{}]
28+
29+
30+
def test_ldisj_basics():
31+
32+
res = list(ldisj(eq(1, var('a')))({}))
33+
assert res == [{var('a'): 1}]
34+
35+
res = list(ldisj(eq(1, 2))({}))
36+
assert res == []
37+
38+
res = list(ldisj(eq(1, 1))({}))
39+
assert res == [{}]
40+
41+
res = list(ldisj(eq(1, var('a')), eq(1, var('a')))({}))
42+
assert res == [{var('a'): 1}, {var('a'): 1}]
43+
44+
res = list(ldisj(eq(1, var('a')), eq(2, var('a')))({}))
45+
assert res == [{var('a'): 1}, {var('a'): 2}]
46+
47+
res = list(ldisj_seq([])({}))
48+
assert res == []
49+
50+
51+
def test_conde_basics():
52+
53+
res = list(conde([eq(1, var('a')), eq(2, var('b'))],
54+
[eq(1, var('b')), eq(2, var('a'))])({}))
55+
assert res == [{var('a'): 1, var('b'): 2},
56+
{var('b'): 1, var('a'): 2}]
57+
58+
res = list(conde([eq(1, var('a')), eq(2, 1)],
59+
[eq(1, var('b')), eq(2, var('a'))])({}))
60+
assert res == [{var('b'): 1, var('a'): 2}]
61+
62+
res = list(conde([eq(1, var('a')),
63+
conde([eq(11, var('aa'))],
64+
[eq(12, var('ab'))])],
65+
[eq(1, var('b')),
66+
conde([eq(111, var('ba')),
67+
eq(112, var('bb'))],
68+
[eq(121, var('bc'))])])({}))
69+
assert res == [{var('a'): 1, var('aa'): 11},
70+
{var('b'): 1, var('ba'): 111, var('bb'): 112},
71+
{var('a'): 1, var('ab'): 12},
72+
{var('b'): 1, var('bc'): 121}]
73+
74+
res = list(conde([eq(1, 2)], [eq(1, 1)])({}))
75+
assert res == [{}]
76+
77+
assert list(lconj(eq(1, 1))({})) == [{}]
78+
79+
res = list(lconj(conde([eq(1, 2)], [eq(1, 1)]))({}))
80+
assert res == [{}]
81+
82+
res = list(lconj(conde([eq(1, 2)], [eq(1, 1)]),
83+
conde([eq(1, 2)], [eq(1, 1)]))({}))
84+
assert res == [{}]
85+
86+
# def test_short_circuit_lconj():
87+
#
88+
# def one_bad_goal(goal_nums, good_goals=10, _eq=eq):
89+
# for i in goal_nums:
90+
# if i == good_goals:
91+
# def _g(S, i=i):
92+
# print('{} bad'.format(i))
93+
# yield from _eq(1, 2)(S)
94+
#
95+
# else:
96+
# def _g(S, i=i):
97+
# print('{} good'.format(i))
98+
# yield from _eq(1, 1)(S)
99+
#
100+
# yield _g
101+
#
102+
# goal_nums = iter(range(20))
103+
# run(0, var('q'), lall(*one_bad_goal(goal_nums)))
104+
#
105+
# # `kanren`'s `lall` will necessarily exhaust the generator.
106+
# next(goal_nums, None)
107+
#
108+
# goal_nums = iter(range(20))
109+
# run(0, var('q'), lconj(one_bad_goal(goal_nums)))
110+
# next(goal_nums, None)

tests/test_graph.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,10 @@ def test_graph_applyo_reverse():
269269
assert test_res == (
270270
# Expansion of the term's root
271271
etuple(add, 5, 5),
272+
# Expansion in the term's arguments
273+
etuple(mul, etuple(log, etuple(exp, 2)), etuple(log, etuple(exp, 5))),
272274
# Two step expansion at the root
273-
etuple(log, etuple(exp, etuple(add, 5, 5))),
275+
# etuple(log, etuple(exp, etuple(add, 5, 5))),
274276
# Expansion into a sub-term
275277
# etuple(mul, 2, etuple(log, etuple(exp, 5)))
276278
)

0 commit comments

Comments
 (0)