Skip to content

Commit 34d54ed

Browse files
Merge pull request #59 from brandonwillard/efficient-mk-stream-ops
Efficient miniKanren stream functions
2 parents b504958 + f51fa23 commit 34d54ed

File tree

10 files changed

+496
-35
lines changed

10 files changed

+496
-35
lines changed

setup.cfg

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,6 @@ filterwarnings =
1313

1414
[coverage:report]
1515
exclude_lines =
16-
pragma: no cover
16+
pragma: no cover
17+
18+
raise NotImplementedError

symbolic_pymc/constraints.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import weakref
2+
3+
from abc import ABC, abstractmethod
4+
from types import MappingProxyType
5+
from collections import OrderedDict
6+
7+
from unification import unify, reify, Var
8+
from unification.core import _unify, _reify
9+
10+
11+
class KanrenConstraintStore(ABC):
12+
"""A class that enforces constraints between logic variables in a miniKanren state."""
13+
14+
__slots__ = ("mappings",)
15+
16+
def __init__(self, mappings=None):
17+
# Mappings between logic variables and their constraint values
18+
# (e.g. the values against which they cannot be unified).
19+
self.mappings = mappings if mappings is not None else dict()
20+
# TODO: We can't use this until `Var` is a factory returning unique
21+
# objects/references for a given `Var.token` value.
22+
# self.mappings = weakref.WeakKeyDictionary(mappings)
23+
24+
@abstractmethod
25+
def pre_check(self, state, key=None, value=None):
26+
"""Check a key-value pair before they're added to a KanrenState."""
27+
raise NotImplementedError()
28+
29+
@abstractmethod
30+
def post_check(self, new_state, key=None, value=None, old_state=None):
31+
"""Check a key-value pair after they're added to a KanrenState."""
32+
raise NotImplementedError()
33+
34+
@abstractmethod
35+
def update(self, *args, **kwargs):
36+
"""Add a new constraint."""
37+
raise NotImplementedError()
38+
39+
@abstractmethod
40+
def constraints_str(self, var):
41+
"""Print the constraints on a logic variable."""
42+
raise NotImplementedError()
43+
44+
45+
class KanrenState(dict):
46+
"""A miniKanren state that holds unifications of logic variables and upholds constraints on logic variables."""
47+
48+
__slots__ = ("constraints", "__weakref__")
49+
50+
def __init__(self, *s, constraints=None):
51+
super().__init__(*s)
52+
self.constraints = OrderedDict(constraints or [])
53+
54+
def pre_checks(self, key, value):
55+
return all(cstore.pre_check(self, key, value) for cstore in self.constraints.values())
56+
57+
def post_checks(self, new_state, key, value):
58+
return all(
59+
cstore.post_check(new_state, key, value, old_state=self)
60+
for cstore in self.constraints.values()
61+
)
62+
63+
def add_constraint(self, constraint):
64+
assert isinstance(constraint, KanrenConstraintStore)
65+
self.constraints[type(constraint)] = constraint
66+
67+
def __eq__(self, other):
68+
if isinstance(other, KanrenState):
69+
return super().__eq__(other)
70+
71+
# When comparing with a plain dict, disregard the constraints.
72+
if isinstance(other, dict):
73+
return super().__eq__(other)
74+
return False
75+
76+
def __repr__(self):
77+
return f"KanrenState({super().__repr__()}, {self.constraints})"
78+
79+
80+
class Disequality(KanrenConstraintStore):
81+
"""A disequality constraint (i.e. two things do not unify)."""
82+
83+
def post_check(self, new_state, key=None, value=None, old_state=None):
84+
# This implementation follows-up every addition to a `KanrenState` with
85+
# a consistency check against all the disequality constraints. It's
86+
# not particularly scalable, but it works for now.
87+
return not any(
88+
any(new_state == unify(lvar, val, new_state) for val in vals)
89+
for lvar, vals in self.mappings.items()
90+
)
91+
92+
def pre_check(self, state, key=None, value=None):
93+
return True
94+
95+
def update(self, key, value):
96+
# In this case, logic variables are mapped to a set of values against
97+
# which they cannot unify.
98+
if key not in self.mappings:
99+
self.mappings[key] = {value}
100+
else:
101+
self.mappings[key].add(value)
102+
103+
def constraints_str(self, var):
104+
if var in self.mappings:
105+
return f"=/= {self.mappings[var]}"
106+
else:
107+
return ""
108+
109+
def __repr__(self):
110+
return ",".join([f"{k} =/= {v}" for k, v in self.mappings.items()])
111+
112+
113+
def unify_KanrenState(u, v, S):
114+
if S.pre_checks(u, v):
115+
s = unify(u, v, MappingProxyType(S))
116+
if s is not False and S.post_checks(s, u, v):
117+
return KanrenState(s, constraints=S.constraints)
118+
119+
return False
120+
121+
122+
unify.add((object, object, KanrenState), unify_KanrenState)
123+
unify.add(
124+
(object, object, MappingProxyType),
125+
lambda u, v, d: unify.dispatch(type(u), type(v), dict)(u, v, d),
126+
)
127+
_unify.add(
128+
(object, object, MappingProxyType),
129+
lambda u, v, d: _unify.dispatch(type(u), type(v), dict)(u, v, d),
130+
)
131+
132+
133+
class ConstrainedVar(Var):
134+
"""A logic variable that tracks its own constraints.
135+
136+
Currently, this is only for display/reification purposes.
137+
138+
"""
139+
140+
__slots__ = ("_id", "token", "S", "var")
141+
142+
def __new__(cls, var, S):
143+
obj = super().__new__(cls, var.token)
144+
obj.S = weakref.ref(S)
145+
obj.var = weakref.ref(var)
146+
return obj
147+
148+
def __repr__(self):
149+
var = self.var()
150+
S = self.S()
151+
if var is not None and S is not None:
152+
u_constraints = ",".join([c.constraints_str(var) for c in S.constraints.values()])
153+
return f"{var}: {{{u_constraints}}}"
154+
155+
156+
def reify_KanrenState(u, S):
157+
u_res = reify(u, MappingProxyType(S))
158+
if isinstance(u_res, Var):
159+
return ConstrainedVar(u_res, S)
160+
else:
161+
return u_res
162+
163+
164+
_reify.add((tuple(p[0] for p in _reify.ordering if p[1] == dict), KanrenState), reify_KanrenState)
165+
_reify.add((object, MappingProxyType), lambda u, s: _reify.dispatch(type(u), dict)(u, s))
166+
167+
168+
def neq(u, v):
169+
"""Construct a disequality goal."""
170+
171+
def neq_goal(S):
172+
if not isinstance(S, KanrenState):
173+
S = KanrenState(S)
174+
175+
diseq_constraint = S.constraints.setdefault(Disequality, Disequality())
176+
177+
diseq_constraint.update(u, v)
178+
179+
if diseq_constraint.post_check(S):
180+
return iter([S])
181+
else:
182+
return iter([])
183+
184+
return neq_goal
Lines changed: 88 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1-
from unification import isvar
1+
from itertools import tee, chain
2+
from functools import reduce
3+
4+
from toolz import interleave
5+
6+
from unification import isvar, reify
27

38
from kanren import eq
4-
from kanren.core import lallgreedy
9+
from kanren.core import goaleval
510
from kanren.facts import Relation
611
from kanren.goals import goalify
712
from kanren.term import term, operator, arguments
@@ -20,13 +25,85 @@
2025
concat = goalify(lambda *args: "".join(args))
2126

2227

28+
def ldisj_seq(goals):
29+
"""Produce a goal that returns the appended state stream from all successful goal arguments.
30+
31+
In other words, it behaves like logical disjunction/OR for goals.
32+
"""
33+
34+
def ldisj_seq_goal(S):
35+
nonlocal goals
36+
37+
goals, _goals = tee(goals)
38+
39+
yield from interleave(goaleval(g)(S) for g in _goals)
40+
41+
return ldisj_seq_goal
42+
43+
44+
def lconj_seq(goals):
45+
"""Produce a goal that returns the appended state stream in which all goals are necessarily successful.
46+
47+
In other words, it behaves like logical conjunction/AND for goals.
48+
"""
49+
50+
def lconj_seq_goal(S):
51+
nonlocal goals
52+
53+
goals, _goals = tee(goals)
54+
55+
g0 = next(iter(_goals), None)
56+
57+
if g0 is None:
58+
return
59+
60+
z0 = goaleval(g0)(S)
61+
62+
yield from reduce(lambda z, g: chain.from_iterable(map(goaleval(g), z)), _goals, z0)
63+
64+
return lconj_seq_goal
65+
66+
67+
def ldisj(*goals):
68+
return ldisj_seq(goals)
69+
70+
71+
def lconj(*goals):
72+
return lconj_seq(goals)
73+
74+
75+
def conde(*goals):
76+
return ldisj_seq(lconj_seq(g) for g in goals)
77+
78+
79+
lall = lconj
80+
lany = ldisj
81+
82+
2383
def buildo(op, args, obj):
24-
if not isvar(obj):
25-
if not isvar(args):
26-
args = etuplize(args, shallow=True)
27-
oop, oargs = operator(obj), arguments(obj)
28-
return lallgreedy(eq(op, oop), eq(args, oargs))
29-
elif isvar(args) or isvar(op):
30-
return conso(op, args, obj)
31-
else:
32-
return eq(obj, term(op, args))
84+
"""Construct a goal that relates an object and its rand + rators decomposition.
85+
86+
This version uses etuples.
87+
88+
"""
89+
90+
def buildo_goal(S):
91+
nonlocal op, args, obj
92+
93+
op, args, obj = reify((op, args, obj), S)
94+
95+
if not isvar(obj):
96+
97+
if not isvar(args):
98+
args = etuplize(args, shallow=True)
99+
100+
oop, oargs = operator(obj), arguments(obj)
101+
102+
yield from lall(eq(op, oop), eq(args, oargs))(S)
103+
104+
elif isvar(args) or isvar(op):
105+
yield from conso(op, args, obj)(S)
106+
else:
107+
yield from eq(obj, term(op, args))(S)
108+
109+
return buildo_goal

symbolic_pymc/relations/graph.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77

88
from kanren import eq
99
from cons.core import ConsPair, ConsNull
10-
from kanren.core import conde, lall
1110
from kanren.goals import conso, fail
1211

12+
from . import conde, lall
13+
1314
from ..etuple import etuplize, ExpressionTuple
1415

1516

@@ -217,7 +218,7 @@ def graph_applyo(
217218
def preprocess_graph(x):
218219
return x
219220

220-
def _gapplyo(s):
221+
def graph_applyo_goal(s):
221222

222223
nonlocal in_graph, out_graph
223224

@@ -256,4 +257,4 @@ def _gapplyo(s):
256257
g = goaleval(g)
257258
yield from g(s)
258259

259-
return _gapplyo
260+
return graph_applyo_goal

symbolic_pymc/tensorflow/unify.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,20 @@
88
from ..meta import metatize
99
from ..unify import unify_MetaSymbol
1010
from ..etuple import etuplize
11+
from ..constraints import KanrenState
1112

1213
tf_class_abstractions = tuple(c.base for c in TFlowMetaSymbol.base_subclasses())
1314

1415
_unify.add(
15-
(TFlowMetaSymbol, tf_class_abstractions, dict),
16+
(TFlowMetaSymbol, tf_class_abstractions, (KanrenState, dict)),
1617
lambda u, v, s: unify_MetaSymbol(u, metatize(v), s),
1718
)
1819
_unify.add(
19-
(tf_class_abstractions, TFlowMetaSymbol, dict),
20+
(tf_class_abstractions, TFlowMetaSymbol, (KanrenState, dict)),
2021
lambda u, v, s: unify_MetaSymbol(metatize(u), v, s),
2122
)
2223
_unify.add(
23-
(tf_class_abstractions, tf_class_abstractions, dict),
24+
(tf_class_abstractions, tf_class_abstractions, (KanrenState, dict)),
2425
lambda u, v, s: unify_MetaSymbol(metatize(u), metatize(v), s),
2526
)
2627

@@ -30,7 +31,7 @@ def _reify_TFlowClasses(o, s):
3031
return reify(meta_obj, s)
3132

3233

33-
_reify.add((tf_class_abstractions, dict), _reify_TFlowClasses)
34+
_reify.add((tf_class_abstractions, (KanrenState, dict)), _reify_TFlowClasses)
3435

3536
operator.add((tf.Tensor,), lambda x: operator(metatize(x)))
3637

0 commit comments

Comments
 (0)