Skip to content

Commit ac51c08

Browse files
Rename and lapply_anyo and make it return max goal conjunctions first
The relation/goal `lapply_anyo` has been renamed to `seq_apply_anyo`. It has also been made to return the state with the maximum number of simultaneously successful goals--for the given relation applied across both sequences--first.
1 parent 64a65b1 commit ac51c08

File tree

3 files changed

+91
-47
lines changed

3 files changed

+91
-47
lines changed

symbolic_pymc/relations/graph.py

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ..etuple import etuplize, ExpressionTuple
1414

1515

16-
def lapply_anyo(relation, l_in, l_out, null_type=False, skip_op=True):
16+
def seq_apply_anyo(relation, l_in, l_out, null_type=False, skip_op=True):
1717
"""Apply a relation to at least one pair of corresponding elements in two sequences.
1818
1919
Parameters
@@ -27,8 +27,10 @@ def lapply_anyo(relation, l_in, l_out, null_type=False, skip_op=True):
2727
will not be applied to the operators (i.e. the cars) of the inputs.
2828
"""
2929

30-
def _lapply_anyo(relation, l_in, l_out, i_any, null_type, skip_cars=False):
31-
def _goal(s):
30+
# This is a customized (based on the initial call arguments) goal
31+
# constructor
32+
def _seq_apply_anyo(relation, l_in, l_out, i_any, null_type, skip_cars=False):
33+
def seq_apply_anyo_sub_goal(s):
3234

3335
nonlocal i_any, null_type
3436

@@ -38,35 +40,57 @@ def _goal(s):
3840
o_car, o_cdr = var(), var()
3941

4042
conde_branches = []
43+
4144
if i_any or (isvar(l_in_rf) and isvar(l_out_rf)):
45+
# Consider terminating the sequences when we've had at least
46+
# one successful goal or when both sequences are logic variables.
4247
conde_branches.append([eq(l_in_rf, null_type), eq(l_in_rf, l_out_rf)])
4348

44-
descend_branch = [
49+
# Extract the CAR and CDR of each argument sequence; this is how we
50+
# iterate through elements of the two sequences.
51+
cons_parts_branch = [
4552
goaleval(conso(i_car, i_cdr, l_in_rf)),
4653
goaleval(conso(o_car, o_cdr, l_out_rf)),
4754
]
4855

49-
conde_branches.append(descend_branch)
56+
conde_branches.append(cons_parts_branch)
5057

51-
conde_2_branches = [
52-
[eq(i_car, o_car), _lapply_anyo(relation, i_cdr, o_cdr, i_any, null_type)]
53-
]
58+
conde_relation_branches = []
59+
60+
relation_branch = None
5461

5562
if not skip_cars:
56-
conde_2_branches.append(
57-
[relation(i_car, o_car), _lapply_anyo(relation, i_cdr, o_cdr, True, null_type)]
58-
)
63+
relation_branch = [
64+
# This case tries the relation continues on.
65+
relation(i_car, o_car),
66+
# In this conde clause, we can tell future calls to
67+
# seq_apply_anyo that we've had at least one successful
68+
# application of the relation (otherwise, this clause
69+
# would fail due to the above goal).
70+
_seq_apply_anyo(relation, i_cdr, o_cdr, True, null_type),
71+
]
72+
73+
conde_relation_branches.append(relation_branch)
74+
75+
base_branch = [
76+
# This is the "base" case; it is used when, for example,
77+
# the given relation isn't satisfied.
78+
eq(i_car, o_car),
79+
_seq_apply_anyo(relation, i_cdr, o_cdr, i_any, null_type),
80+
]
81+
82+
conde_relation_branches.append(base_branch)
5983

60-
descend_branch.append(conde(*conde_2_branches))
84+
cons_parts_branch.append(conde(*conde_relation_branches))
6185

6286
g = conde(*conde_branches)
6387
g = goaleval(g)
6488

6589
yield from g(s)
6690

67-
return _goal
91+
return seq_apply_anyo_sub_goal
6892

69-
def goal(s):
93+
def seq_apply_anyo_init_goal(s):
7094

7195
nonlocal null_type, skip_op
7296

@@ -95,7 +119,7 @@ def goal(s):
95119
else []
96120
)
97121

98-
g = _lapply_anyo(
122+
g = _seq_apply_anyo(
99123
relation,
100124
l_in,
101125
l_out,
@@ -107,7 +131,7 @@ def goal(s):
107131

108132
yield from g(s)
109133

110-
return goal
134+
return seq_apply_anyo_init_goal
111135

112136

113137
def reduceo(relation, in_term, out_term):
@@ -183,7 +207,7 @@ def graph_applyo(
183207
out_graph: object
184208
The graph for which the right-hand side of a binary relation holds.
185209
preprocess_graph: callable (optional)
186-
A unary function that produces an iterable upon which `lapply_anyo`
210+
A unary function that produces an iterable upon which `seq_apply_anyo`
187211
can be applied in order to traverse a graph's subgraphs. The default
188212
function converts the graph to expression-tuple form.
189213
"""
@@ -223,7 +247,7 @@ def _gapplyo(s):
223247
# We will only include it when there actually are children, or when
224248
# we're dealing with a logic variable (e.g. and "generating"
225249
# children).
226-
subgraphs_reduce_gl = lapply_anyo(_gapply, in_subgraphs, out_subgraphs)
250+
subgraphs_reduce_gl = seq_apply_anyo(_gapply, in_subgraphs, out_subgraphs)
227251

228252
conde_args += ([subgraphs_reduce_gl],)
229253

symbolic_pymc/relations/theano/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from kanren.core import lall
77

88
from .linalg import buildo
9-
from ..graph import graph_applyo, lapply_anyo
9+
from ..graph import graph_applyo, seq_apply_anyo
1010
from ...etuple import etuplize, etuple
1111
from ...theano.meta import mt
1212

@@ -69,7 +69,9 @@ def non_obs_graph_applyo(relation, a, b):
6969
# Deconstruct the observed random variable
7070
(buildo, rv_op_lv, rv_args_lv, obs_rv_lv),
7171
# Apply relation to the RV's inputs
72-
lapply_anyo(partial(tt_graph_applyo, relation), rv_args_lv, new_rv_args_lv, skip_op=False),
72+
seq_apply_anyo(
73+
partial(tt_graph_applyo, relation), rv_args_lv, new_rv_args_lv, skip_op=False
74+
),
7375
# Reconstruct the random variable
7476
(buildo, rv_op_lv, new_rv_args_lv, new_obs_rv_lv),
7577
# Reconstruct the observation

tests/test_graph.py

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from kanren.goals import isinstanceo
1111

1212
from symbolic_pymc.etuple import etuple, ExpressionTuple
13-
from symbolic_pymc.relations.graph import reduceo, lapply_anyo, graph_applyo
13+
from symbolic_pymc.relations.graph import reduceo, seq_apply_anyo, graph_applyo
1414

1515

1616
class OrderedFunction(object):
@@ -88,52 +88,61 @@ def test_reduceo():
8888
assert res[1] == etuple(log, etuple(exp, etuple(log, etuple(exp, 1))))
8989

9090

91-
def test_lapply_anyo_types():
91+
def test_seq_apply_anyo_types():
9292
"""Make sure that `applyo` preserves the types between its arguments."""
9393
q_lv = var()
94-
res = run(1, q_lv, lapply_anyo(lambda x, y: eq(x, y), [1], q_lv))
94+
res = run(1, q_lv, seq_apply_anyo(lambda x, y: eq(x, y), [1], q_lv))
9595
assert res[0] == [1]
96-
res = run(1, q_lv, lapply_anyo(lambda x, y: eq(x, y), (1,), q_lv))
96+
res = run(1, q_lv, seq_apply_anyo(lambda x, y: eq(x, y), (1,), q_lv))
9797
assert res[0] == (1,)
98-
res = run(1, q_lv, lapply_anyo(lambda x, y: eq(x, y), etuple(1,), q_lv, skip_op=False))
98+
res = run(1, q_lv, seq_apply_anyo(lambda x, y: eq(x, y), etuple(1,), q_lv, skip_op=False))
9999
assert res[0] == etuple(1,)
100-
res = run(1, q_lv, lapply_anyo(lambda x, y: eq(x, y), q_lv, (1,)))
100+
res = run(1, q_lv, seq_apply_anyo(lambda x, y: eq(x, y), q_lv, (1,)))
101101
assert res[0] == (1,)
102-
res = run(1, q_lv, lapply_anyo(lambda x, y: eq(x, y), q_lv, [1]))
102+
res = run(1, q_lv, seq_apply_anyo(lambda x, y: eq(x, y), q_lv, [1]))
103103
assert res[0] == [1]
104-
res = run(1, q_lv, lapply_anyo(lambda x, y: eq(x, y), q_lv, etuple(1), skip_op=False))
104+
res = run(1, q_lv, seq_apply_anyo(lambda x, y: eq(x, y), q_lv, etuple(1), skip_op=False))
105105
assert res[0] == etuple(1)
106-
res = run(1, q_lv, lapply_anyo(lambda x, y: eq(x, y), [1, 2], [1, 2]))
106+
res = run(1, q_lv, seq_apply_anyo(lambda x, y: eq(x, y), [1, 2], [1, 2]))
107107
assert len(res) == 1
108-
res = run(1, q_lv, lapply_anyo(lambda x, y: eq(x, y), [1, 2], [1, 3]))
108+
res = run(1, q_lv, seq_apply_anyo(lambda x, y: eq(x, y), [1, 2], [1, 3]))
109109
assert len(res) == 0
110-
res = run(1, q_lv, lapply_anyo(lambda x, y: eq(x, y), [1, 2], (1, 2)))
110+
res = run(1, q_lv, seq_apply_anyo(lambda x, y: eq(x, y), [1, 2], (1, 2)))
111111
assert len(res) == 0
112-
res = run(0, q_lv, lapply_anyo(lambda x, y: eq(y, etuple(mul, 2, x)),
112+
res = run(0, q_lv, seq_apply_anyo(lambda x, y: eq(y, etuple(mul, 2, x)),
113113
etuple(add, 1, 2), q_lv, skip_op=True))
114114
assert len(res) == 3
115115
assert all(r[0] == add for r in res)
116116

117117

118-
def test_lapply_misc():
118+
def test_seq_apply_anyo_misc():
119119
q_lv = var('q')
120120

121-
assert len(run(0, q_lv, lapply_anyo(eq, [1, 2, 3], [1, 2, 3]))) == 1
121+
assert len(run(0, q_lv, seq_apply_anyo(eq, [1, 2, 3], [1, 2, 3]))) == 1
122122

123-
assert len(run(0, q_lv, lapply_anyo(eq, [1, 2, 3], [1, 3, 3]))) == 0
123+
assert len(run(0, q_lv, seq_apply_anyo(eq, [1, 2, 3], [1, 3, 3]))) == 0
124124

125-
assert len(run(4, q_lv, lapply_anyo(math_reduceo, [etuple(mul, 2, var('x'))], q_lv))) == 0
125+
def one_to_threeo(x, y):
126+
return conde([eq(x, 1), eq(y, 3)])
126127

127-
test_res = run(4, q_lv, lapply_anyo(math_reduceo, [etuple(add, 2, 2), 1], q_lv))
128+
res = run(0, q_lv, seq_apply_anyo(one_to_threeo,
129+
[1, 2, 4, 1, 4, 1, 1],
130+
q_lv))
131+
132+
assert res[0] == [3, 2, 4, 3, 4, 3, 3]
133+
134+
assert len(run(4, q_lv, seq_apply_anyo(math_reduceo, [etuple(mul, 2, var('x'))], q_lv))) == 0
135+
136+
test_res = run(4, q_lv, seq_apply_anyo(math_reduceo, [etuple(add, 2, 2), 1], q_lv))
128137
assert test_res == ([etuple(mul, 2, 2), 1],)
129138

130-
test_res = run(4, q_lv, lapply_anyo(math_reduceo, [1, etuple(add, 2, 2)], q_lv))
139+
test_res = run(4, q_lv, seq_apply_anyo(math_reduceo, [1, etuple(add, 2, 2)], q_lv))
131140
assert test_res == ([1, etuple(mul, 2, 2)],)
132141

133-
test_res = run(4, q_lv, lapply_anyo(math_reduceo, q_lv, var('z')))
142+
test_res = run(4, q_lv, seq_apply_anyo(math_reduceo, q_lv, var('z')))
134143
assert all(isinstance(r, list) for r in test_res)
135144

136-
test_res = run(4, q_lv, lapply_anyo(math_reduceo, q_lv, var('z'), tuple()))
145+
test_res = run(4, q_lv, seq_apply_anyo(math_reduceo, q_lv, var('z'), tuple()))
137146
assert all(isinstance(r, tuple) for r in test_res)
138147

139148

@@ -153,11 +162,11 @@ def test_lapply_misc():
153162
([etuple(mul, 2, 1), 5],
154163
[etuple(add, 1, 1), 5],
155164
[etuple(mul, 2, 1), etuple(log, etuple(exp, 5))]))])
156-
def test_lapply_anyo(test_input, test_output):
157-
"""Test `lapply_anyo` with fully ground terms (i.e. no logic variables)."""
165+
def test_seq_apply_anyo(test_input, test_output):
166+
"""Test `seq_apply_anyo` with fully ground terms (i.e. no logic variables)."""
158167
q_lv = var()
159168
test_res = run(0, q_lv,
160-
(lapply_anyo, full_math_reduceo, test_input, q_lv))
169+
(seq_apply_anyo, full_math_reduceo, test_input, q_lv))
161170

162171
assert len(test_res) == len(test_output)
163172

@@ -175,18 +184,18 @@ def test_lapply_anyo(test_input, test_output):
175184
assert test_res == test_output
176185

177186

178-
def test_lapply_anyo_reverse():
179-
"""Test `lapply_anyo` in "reverse" (i.e. specify the reduced form and generate the un-reduced form)."""
187+
def test_seq_apply_anyo_reverse():
188+
"""Test `seq_apply_anyo` in "reverse" (i.e. specify the reduced form and generate the un-reduced form)."""
180189
# Unbounded reverse
181190
q_lv = var()
182191
rev_input = [etuple(mul, 2, 1)]
183-
test_res = run(4, q_lv, (lapply_anyo, math_reduceo, q_lv, rev_input))
192+
test_res = run(4, q_lv, (seq_apply_anyo, math_reduceo, q_lv, rev_input))
184193
assert test_res == ([etuple(add, 1, 1)],
185194
[etuple(log, etuple(exp, etuple(mul, 2, 1)))])
186195

187196
# Guided reverse
188197
test_res = run(4, q_lv,
189-
(lapply_anyo, math_reduceo,
198+
(seq_apply_anyo, math_reduceo,
190199
[etuple(add, q_lv, 1)],
191200
[etuple(mul, 2, 1)]))
192201

@@ -203,6 +212,15 @@ def test_graph_applyo_misc():
203212

204213
assert len(run(0, q_lv, graph_applyo(eq, etuple(), etuple(), preprocess_graph=None))) == 1
205214

215+
def one_to_threeo(x, y):
216+
return conde([eq(x, 1), eq(y, 3)])
217+
218+
res = run(0, q_lv, graph_applyo(one_to_threeo,
219+
[1, [1, 2, 4], 2, [[4, 1, 1]], 1],
220+
q_lv, preprocess_graph=None))
221+
222+
assert res[0] == [3, [3, 2, 4], 2, [[4, 3, 3]], 3]
223+
206224

207225
@pytest.mark.parametrize(
208226
'test_input, test_output',

0 commit comments

Comments
 (0)