@@ -56,12 +56,11 @@ from functools import partial
5656from unification import var
5757
5858from kanren import run
59+ from kanren.graph import reduceo, walko
5960
6061from symbolic_pymc.theano.printing import tt_pprint
6162from symbolic_pymc.theano.pymc3 import model_graph
6263
63- from symbolic_pymc.relations.graph import reduceo
64- from symbolic_pymc.relations.theano import tt_graph_applyo
6564from symbolic_pymc.relations.theano.conjugates import conjugate
6665
6766theano.config.cxx = ' '
@@ -97,7 +96,7 @@ def conjugate_graph(graph):
9796 """ Apply conjugate relations throughout a graph."""
9897
9998 def fixedp_conjugate_applyo (x , y ):
100- return reduceo(partial(tt_graph_applyo , conjugate), x, y)
99+ return reduceo(partial(walko , conjugate), x, y)
101100
102101 expr_graph, = run(1 , var(' q' ),
103102 fixedp_conjugate_applyo(graph, var(' q' )))
@@ -147,13 +146,13 @@ from functools import partial
147146from unification import var
148147
149148from kanren import run
149+ from kanren.graph import reduceo
150150
151151from symbolic_pymc.theano.meta import mt
152152from symbolic_pymc.theano.pymc3 import model_graph, graph_model
153153from symbolic_pymc.theano.utils import canonicalize
154154
155- from symbolic_pymc.relations.graph import reduceo
156- from symbolic_pymc.relations.theano import non_obs_graph_applyo
155+ from symbolic_pymc.relations.theano import non_obs_walko
157156from symbolic_pymc.relations.theano.distributions import scale_loc_transform
158157
159158
@@ -189,12 +188,12 @@ def reparam_graph(graph):
189188
190189 graph_mt = mt(graph)
191190
192- def scale_loc_fixedp_applyo (x , y ):
193- return reduceo(partial(non_obs_graph_applyo , scale_loc_transform), x, y)
191+ def scale_loc_fixedp_walko (x , y ):
192+ return reduceo(partial(non_obs_walko , scale_loc_transform), x, y)
194193
195194 expr_graph = run(0 , var(' q' ),
196195 # Apply our transforms to unobserved RVs only
197- scale_loc_fixedp_applyo (graph_mt, var(' q' )))
196+ scale_loc_fixedp_walko (graph_mt, var(' q' )))
198197
199198 expr_graph = expr_graph[0 ]
200199 opt_graph_tt = expr_graph.reify()
0 commit comments