Skip to content

Commit 64a65b1

Browse files
Merge pull request #71 from brandonwillard/refactor-graph-relation
Factor out fixed-point logic from graph relation
2 parents d75a754 + 661b501 commit 64a65b1

File tree

5 files changed

+171
-99
lines changed

5 files changed

+171
-99
lines changed

README.md

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,16 @@ import theano.tensor as tt
5151

5252
import pymc3 as pm
5353

54+
from functools import partial
55+
5456
from unification import var
5557

5658
from kanren import run
5759

5860
from symbolic_pymc.theano.printing import tt_pprint
5961
from symbolic_pymc.theano.pymc3 import model_graph
6062

63+
from symbolic_pymc.relations.graph import reduceo
6164
from symbolic_pymc.relations.theano import tt_graph_applyo
6265
from symbolic_pymc.relations.theano.conjugates import conjugate
6366

@@ -92,8 +95,12 @@ fgraph = model_graph(model, output_vars=[Y_rv])
9295

9396
def conjugate_graph(graph):
9497
"""Apply conjugate relations throughout a graph."""
98+
99+
def fixedp_conjugate_applyo(x, y):
100+
return reduceo(partial(tt_graph_applyo, conjugate), x, y)
101+
95102
expr_graph, = run(1, var('q'),
96-
(tt_graph_applyo, conjugate, graph, var('q')))
103+
fixedp_conjugate_applyo(graph, var('q')))
97104

98105
fgraph_opt = expr_graph.eval_obj
99106
fgraph_opt_tt = fgraph_opt.reify()
@@ -135,6 +142,8 @@ import pymc3 as pm
135142
import theano
136143
import theano.tensor as tt
137144

145+
from functools import partial
146+
138147
from unification import var
139148

140149
from kanren import run
@@ -143,7 +152,8 @@ from symbolic_pymc.theano.meta import mt
143152
from symbolic_pymc.theano.pymc3 import model_graph, graph_model
144153
from symbolic_pymc.theano.utils import canonicalize
145154

146-
from symbolic_pymc.relations.theano import tt_graph_applyo, non_obs_graph_applyo
155+
from symbolic_pymc.relations.graph import reduceo
156+
from symbolic_pymc.relations.theano import non_obs_graph_applyo
147157
from symbolic_pymc.relations.theano.distributions import scale_loc_transform
148158

149159

@@ -179,11 +189,12 @@ def reparam_graph(graph):
179189

180190
graph_mt = mt(graph)
181191

192+
def scale_loc_fixedp_applyo(x, y):
193+
return reduceo(partial(non_obs_graph_applyo, scale_loc_transform), x, y)
194+
182195
expr_graph = run(0, var('q'),
183196
# Apply our transforms to unobserved RVs only
184-
non_obs_graph_applyo(
185-
lambda x, y: tt_graph_applyo(scale_loc_transform, x, y),
186-
graph_mt, var('q')))
197+
scale_loc_fixedp_applyo(graph_mt, var('q')))
187198

188199
expr_graph = expr_graph[0]
189200
opt_graph_tt = expr_graph.reify()

symbolic_pymc/relations/graph.py

Lines changed: 77 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from functools import partial
2+
from operator import length_hint
23
from unification import var, isvar
34

45
from unification import reify
@@ -9,7 +10,7 @@
910
from kanren.core import conde, lall
1011
from kanren.goals import conso, fail
1112

12-
from ..etuple import etuplize, etuple, ExpressionTuple
13+
from ..etuple import etuplize, ExpressionTuple
1314

1415

1516
def lapply_anyo(relation, l_in, l_out, null_type=False, skip_op=True):
@@ -109,28 +110,69 @@ def goal(s):
109110
return goal
110111

111112

112-
def reduceo(relation, in_expr, out_expr):
113+
def reduceo(relation, in_term, out_term):
113114
"""Relate a term and the fixed-point of that term under a given relation.
114115
115116
This includes the "identity" relation.
116117
"""
117-
expr_rdcd = var()
118-
return conde(
119-
# The fixed-point is another reduction step out.
120-
[(relation, in_expr, expr_rdcd), (reduceo, relation, expr_rdcd, out_expr)],
121-
# The fixed-point is a single-step reduction.
122-
[(relation, in_expr, out_expr)],
123-
)
118+
119+
def reduceo_goal(s):
120+
121+
nonlocal in_term, out_term
122+
123+
in_term_rf, out_term_rf = reify((in_term, out_term), s)
124+
125+
# The result of reducing the input graph once
126+
term_rdcd = var()
127+
128+
# Are we working "backward" and (potentially) "expanding" a graph
129+
# (e.g. when the relation is a reduction rule)?
130+
is_expanding = isvar(in_term_rf)
131+
132+
# One application of the relation assigned to `term_rdcd`
133+
single_apply_g = (relation, in_term, term_rdcd)
134+
135+
# Assign/equate (unify, really) the result of a single application to
136+
# the "output" term.
137+
single_res_g = eq(term_rdcd, out_term)
138+
139+
# Recurse into applications of the relation (well, produce a goal that
140+
# will do that)
141+
another_apply_g = reduceo(relation, term_rdcd, out_term)
142+
143+
# We want the fixed-point value to show up in the stream output
144+
# *first*, but that requires some checks.
145+
if is_expanding:
146+
# When an un-reduced term is a logic variable (e.g. we're
147+
# "expanding"), we can't go depth first.
148+
# We need to draw the association between (i.e. unify) the reduced
149+
# and expanded terms ASAP, in order to produce finite
150+
# expanded graphs first and yield results.
151+
#
152+
# In other words, there's no fixed-point to produce in this
153+
# situation. Instead, for example, we have to produce an infinite
154+
# stream of terms that have `out_term` as a fixed point.
155+
# g = conde([single_res_g, single_apply_g],
156+
# [another_apply_g, single_apply_g])
157+
g = lall(conde([single_res_g], [another_apply_g]), single_apply_g)
158+
else:
159+
# Run the recursion step first, so that we get the fixed-point as
160+
# the first result
161+
g = lall(single_apply_g, conde([another_apply_g], [single_res_g]))
162+
163+
g = goaleval(g)
164+
yield from g(s)
165+
166+
return reduceo_goal
124167

125168

126169
def graph_applyo(
127170
relation,
128171
in_graph,
129172
out_graph,
130173
preprocess_graph=partial(etuplize, shallow=True, return_bad_args=True),
131-
inside=False,
132174
):
133-
"""Relate the fixed-points of two term-graphs under a given relation.
175+
"""Apply a relation to a graph and its subgraphs.
134176
135177
Parameters
136178
----------
@@ -144,8 +186,6 @@ def graph_applyo(
144186
A unary function that produces an iterable upon which `lapply_anyo`
145187
can be applied in order to traverse a graph's subgraphs. The default
146188
function converts the graph to expression-tuple form.
147-
inside: boolean (optional)
148-
Process the graph or sub-graphs first.
149189
"""
150190

151191
if preprocess_graph in (False, None):
@@ -155,62 +195,39 @@ def preprocess_graph(x):
155195

156196
def _gapplyo(s):
157197

158-
nonlocal in_graph, out_graph, inside
159-
160-
in_rdc = var()
198+
nonlocal in_graph, out_graph
161199

162200
in_graph_rf, out_graph_rf = reify((in_graph, out_graph), s)
163201

164-
expanding = isvar(in_graph_rf)
202+
_gapply = partial(graph_applyo, relation, preprocess_graph=preprocess_graph)
165203

166-
_gapply = partial(
167-
graph_applyo,
168-
relation,
169-
preprocess_graph=preprocess_graph,
170-
inside=inside, # expanding and (True ^ inside)
171-
)
204+
graph_reduce_gl = (relation, in_graph_rf, out_graph_rf)
172205

173-
# This goal reduces the entire graph
174-
graph_reduce_gl = (relation, in_graph_rf, in_rdc)
206+
# We need to get the sub-graphs/children of the input graph/node
207+
if not isvar(in_graph_rf):
208+
in_subgraphs = preprocess_graph(in_graph_rf)
209+
in_subgraphs = None if length_hint(in_subgraphs, 0) == 0 else in_subgraphs
210+
else:
211+
in_subgraphs = in_graph_rf
175212

176-
# This goal reduces children/arguments of the graph
177-
subgraphs_reduce_gl = lapply_anyo(
178-
_gapply,
179-
preprocess_graph(in_graph_rf),
180-
in_rdc,
181-
null_type=etuple() if expanding else False,
182-
)
213+
if not isvar(out_graph_rf):
214+
out_subgraphs = preprocess_graph(out_graph_rf)
215+
out_subgraphs = None if length_hint(out_subgraphs, 0) == 0 else out_subgraphs
216+
else:
217+
out_subgraphs = out_graph_rf
183218

184-
# Take only one step (e.g. reduce the entire graph and/or its
185-
# arguments)
186-
reduce_once_gl = eq(in_rdc, out_graph_rf)
219+
conde_args = ([graph_reduce_gl],)
187220

188-
# Take another reduction step on top of the one(s) we already did
189-
# (i.e. recurse)
190-
reduce_again_gl = _gapply(in_rdc, out_graph_rf)
221+
# This goal reduces sub-graphs/children of the graph.
222+
if in_subgraphs is not None and out_subgraphs is not None:
223+
# We will only include it when there actually are children, or when
224+
# we're dealing with a logic variable (e.g. and "generating"
225+
# children).
226+
subgraphs_reduce_gl = lapply_anyo(_gapply, in_subgraphs, out_subgraphs)
191227

192-
# We want the fixed-point value first, but that requires
193-
# some checks.
194-
if expanding:
195-
# When the un-reduced expression is a logic variable (i.e. we're
196-
# "expanding" expressions), we can't go depth first.
197-
# We need to draw the association between (i.e. unify) the reduced
198-
# and expanded expressions ASAP, in order to produce finite
199-
# expanded graphs first and yield results.
200-
g = conde(
201-
[reduce_once_gl, graph_reduce_gl],
202-
[reduce_again_gl, graph_reduce_gl],
203-
[reduce_once_gl, subgraphs_reduce_gl],
204-
[reduce_again_gl, subgraphs_reduce_gl],
205-
)
206-
else:
207-
# TODO: With an explicit simplification order, could we determine
208-
# whether or not simplifying the sub-expressions or the expression
209-
# itself is more efficient?
210-
g = lall(
211-
conde([graph_reduce_gl], [subgraphs_reduce_gl]),
212-
conde([reduce_again_gl], [reduce_once_gl]),
213-
)
228+
conde_args += ([subgraphs_reduce_gl],)
229+
230+
g = conde(*conde_args)
214231

215232
g = goaleval(g)
216233
yield from g(s)

symbolic_pymc/relations/theano/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,7 @@ def non_obs_graph_applyo(relation, a, b):
6969
# Deconstruct the observed random variable
7070
(buildo, rv_op_lv, rv_args_lv, obs_rv_lv),
7171
# Apply relation to the RV's inputs
72-
lapply_anyo(
73-
lambda x, y: tt_graph_applyo(relation, x, y), rv_args_lv, new_rv_args_lv, skip_op=False
74-
),
72+
lapply_anyo(partial(tt_graph_applyo, relation), rv_args_lv, new_rv_args_lv, skip_op=False),
7573
# Reconstruct the random variable
7674
(buildo, rv_op_lv, new_rv_args_lv, new_obs_rv_lv),
7775
# Reconstruct the observation

0 commit comments

Comments
 (0)