Skip to content

Commit a83dde4

Browse files
Introduce unification constraints and a disequality implementation
1 parent b07fbda commit a83dde4

File tree

7 files changed

+293
-48
lines changed

7 files changed

+293
-48
lines changed

setup.cfg

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,6 @@ filterwarnings =
1313

1414
[coverage:report]
1515
exclude_lines =
16-
pragma: no cover
16+
pragma: no cover
17+
18+
raise NotImplementedError

symbolic_pymc/constraints.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
from abc import ABC, abstractmethod
2+
from types import MappingProxyType
3+
from collections import OrderedDict, defaultdict
4+
5+
from unification import unify, reify, Var
6+
from unification.core import _unify, _reify
7+
8+
9+
class KanrenConstraintStore(ABC):
10+
"""A class that enforces constraints between logic variables in a miniKanren state."""
11+
12+
@abstractmethod
13+
def pre_check(self, state, key=None, value=None):
14+
"""Check a key-value pair before they're added to a KanrenState."""
15+
raise NotImplementedError()
16+
17+
@abstractmethod
18+
def post_check(self, new_state, key=None, value=None, old_state=None):
19+
"""Check a key-value pair after they're added to a KanrenState."""
20+
raise NotImplementedError()
21+
22+
@abstractmethod
23+
def update(self, *args, **kwargs):
24+
"""Add a new constraint."""
25+
raise NotImplementedError()
26+
27+
@abstractmethod
28+
def constraints_str(self, var):
29+
"""Print the constraints on a logic variable."""
30+
raise NotImplementedError()
31+
32+
33+
class KanrenState(dict):
34+
"""A miniKanren state that holds unifications of logic variables and upholds constraints on logic variables."""
35+
36+
__slots__ = ("constraints",)
37+
38+
def __init__(self, *s, constraints=None):
39+
super().__init__(*s)
40+
self.constraints = OrderedDict(constraints or [])
41+
42+
def pre_checks(self, key, value):
43+
return all(cstore.pre_check(self, key, value) for cstore in self.constraints.values())
44+
45+
def post_checks(self, new_state, key, value):
46+
return all(
47+
cstore.post_check(new_state, key, value, old_state=self)
48+
for cstore in self.constraints.values()
49+
)
50+
51+
def add_constraint(self, constraint):
52+
assert isinstance(constraint, KanrenConstraintStore)
53+
self.constraints[type(constraint)] = constraint
54+
55+
def __eq__(self, other):
56+
if isinstance(other, KanrenState):
57+
return super().__eq__(other)
58+
59+
# When comparing with a plain dict, disregard the constraints.
60+
if isinstance(other, dict):
61+
return super().__eq__(other)
62+
return False
63+
64+
def __repr__(self):
65+
return f"KanrenState({super().__repr__()}, {self.constraints})"
66+
67+
68+
class Disequality(KanrenConstraintStore):
69+
"""A disequality constraint (i.e. two things do not unify)."""
70+
71+
def __init__(self, mappings=None):
72+
# Unallowed mappings
73+
self.mappings = mappings or defaultdict(set)
74+
75+
def post_check(self, new_state, key=None, value=None, old_state=None):
76+
return not any(
77+
any(new_state == unify(lvar, val, new_state) for val in vals)
78+
for lvar, vals in self.mappings.items()
79+
)
80+
81+
def pre_check(self, state, key=None, value=None):
82+
return True
83+
84+
def update(self, key, value):
85+
self.mappings[key].add(value)
86+
87+
def constraints_str(self, var):
88+
if var in self.mappings:
89+
return f"=/= {self.mappings[var]}"
90+
else:
91+
return ""
92+
93+
def __repr__(self):
94+
return ",".join([f"{k} =/= {v}" for k, v in self.mappings.items()])
95+
96+
97+
def unify_KanrenState(u, v, S):
98+
if S.pre_checks(u, v):
99+
s = unify(u, v, MappingProxyType(S))
100+
if s is not False and S.post_checks(s, u, v):
101+
return KanrenState(s, constraints=S.constraints)
102+
103+
return False
104+
105+
106+
unify.add((object, object, KanrenState), unify_KanrenState)
107+
unify.add(
108+
(object, object, MappingProxyType),
109+
lambda u, v, d: unify.dispatch(type(u), type(v), dict)(u, v, d),
110+
)
111+
_unify.add(
112+
(object, object, MappingProxyType),
113+
lambda u, v, d: _unify.dispatch(type(u), type(v), dict)(u, v, d),
114+
)
115+
116+
117+
class ConstrainedVar(Var):
118+
"""A logic variable that tracks its own constraints.
119+
120+
Currently, this is only for display/reification purposes.
121+
122+
"""
123+
124+
def __new__(cls, var, S):
125+
obj = super().__new__(cls, var.token)
126+
obj.S = S
127+
obj.var = var
128+
return obj
129+
130+
def __repr__(self):
131+
u_constraints = ",".join([c.constraints_str(self.var) for c in self.S.constraints.values()])
132+
return f"{self.var}: {{{u_constraints}}}"
133+
134+
135+
def reify_KanrenState(u, S):
136+
u_res = reify(u, MappingProxyType(S))
137+
if isinstance(u_res, Var):
138+
return ConstrainedVar(u_res, S)
139+
else:
140+
return u_res
141+
142+
143+
_reify.add((tuple(p[0] for p in _reify.ordering if p[1] == dict), KanrenState), reify_KanrenState)
144+
_reify.add((object, MappingProxyType), lambda u, s: _reify.dispatch(type(u), dict)(u, s))
145+
146+
147+
def neq(u, v):
148+
"""Construct a disequality goal."""
149+
150+
def neq_goal(S):
151+
if not isinstance(S, KanrenState):
152+
S = KanrenState(S)
153+
154+
diseq_constraint = S.constraints.setdefault(Disequality, Disequality())
155+
156+
diseq_constraint.update(u, v)
157+
158+
if diseq_constraint.post_check(S):
159+
return iter([S])
160+
else:
161+
return iter([])
162+
163+
return neq_goal

symbolic_pymc/tensorflow/unify.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,20 @@
88
from ..meta import metatize
99
from ..unify import unify_MetaSymbol
1010
from ..etuple import etuplize
11+
from ..constraints import KanrenState
1112

1213
tf_class_abstractions = tuple(c.base for c in TFlowMetaSymbol.base_subclasses())
1314

1415
_unify.add(
15-
(TFlowMetaSymbol, tf_class_abstractions, dict),
16+
(TFlowMetaSymbol, tf_class_abstractions, (KanrenState, dict)),
1617
lambda u, v, s: unify_MetaSymbol(u, metatize(v), s),
1718
)
1819
_unify.add(
19-
(tf_class_abstractions, TFlowMetaSymbol, dict),
20+
(tf_class_abstractions, TFlowMetaSymbol, (KanrenState, dict)),
2021
lambda u, v, s: unify_MetaSymbol(metatize(u), v, s),
2122
)
2223
_unify.add(
23-
(tf_class_abstractions, tf_class_abstractions, dict),
24+
(tf_class_abstractions, tf_class_abstractions, (KanrenState, dict)),
2425
lambda u, v, s: unify_MetaSymbol(metatize(u), metatize(v), s),
2526
)
2627

@@ -30,7 +31,7 @@ def _reify_TFlowClasses(o, s):
3031
return reify(meta_obj, s)
3132

3233

33-
_reify.add((tf_class_abstractions, dict), _reify_TFlowClasses)
34+
_reify.add((tf_class_abstractions, (KanrenState, dict)), _reify_TFlowClasses)
3435

3536
operator.add((tf.Tensor,), lambda x: operator(metatize(x)))
3637

symbolic_pymc/theano/unify.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,25 @@
44

55
from unification.core import _reify, _unify, reify
66

7+
from .meta import TheanoMetaSymbol
78
from ..meta import metatize
89
from ..unify import unify_MetaSymbol
910
from ..etuple import ExpressionTuple, etuplize
10-
from .meta import TheanoMetaSymbol
11+
from ..constraints import KanrenState
1112

1213

1314
tt_class_abstractions = tuple(c.base for c in TheanoMetaSymbol.base_subclasses())
1415

1516
_unify.add(
16-
(TheanoMetaSymbol, tt_class_abstractions, dict),
17+
(TheanoMetaSymbol, tt_class_abstractions, (KanrenState, dict)),
1718
lambda u, v, s: unify_MetaSymbol(u, metatize(v), s),
1819
)
1920
_unify.add(
20-
(tt_class_abstractions, TheanoMetaSymbol, dict),
21+
(tt_class_abstractions, TheanoMetaSymbol, (KanrenState, dict)),
2122
lambda u, v, s: unify_MetaSymbol(metatize(u), v, s),
2223
)
2324
_unify.add(
24-
(tt_class_abstractions, tt_class_abstractions, dict),
25+
(tt_class_abstractions, tt_class_abstractions, (KanrenState, dict)),
2526
lambda u, v, s: unify_MetaSymbol(metatize(u), metatize(v), s),
2627
)
2728

@@ -31,7 +32,7 @@ def _reify_TheanoClasses(o, s):
3132
return reify(meta_obj, s)
3233

3334

34-
_reify.add((tt_class_abstractions, dict), _reify_TheanoClasses)
35+
_reify.add((tt_class_abstractions, (KanrenState, dict)), _reify_TheanoClasses)
3536

3637
operator.add((tt.Variable,), lambda x: operator(metatize(x)))
3738

symbolic_pymc/unify.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
from .etuple import etuple, ExpressionTuple
1414

15+
from .constraints import KanrenState
16+
1517

1618
class UnificationFailure(Exception):
1719
pass
@@ -67,14 +69,16 @@ def unify_MetaSymbol(u, v, s):
6769
return s
6870

6971

70-
_unify.add((MetaSymbol, MetaSymbol, dict), unify_MetaSymbol)
71-
72-
_tuple__unify = _unify.dispatch(tuple, tuple, dict)
72+
_unify.add((MetaSymbol, MetaSymbol, (KanrenState, dict)), unify_MetaSymbol)
7373

7474
_unify.add(
75-
(ExpressionTuple, (tuple, ExpressionTuple), dict), lambda x, y, s: _tuple__unify(x, y, s)
75+
(ExpressionTuple, (tuple, ExpressionTuple), (KanrenState, dict)),
76+
lambda x, y, s: _unify.dispatch(tuple, tuple, type(s))(x, y, s),
77+
)
78+
_unify.add(
79+
(tuple, ExpressionTuple, (KanrenState, dict)),
80+
lambda x, y, s: _unify.dispatch(tuple, tuple, type(s))(x, y, s),
7681
)
77-
_unify.add((tuple, ExpressionTuple, dict), lambda x, y, s: _tuple__unify(x, y, s))
7882

7983

8084
def _reify_MetaSymbol(o, s):
@@ -106,11 +110,11 @@ def _reify_MetaSymbol(o, s):
106110
return newobj
107111

108112

109-
_reify.add((MetaSymbol, dict), _reify_MetaSymbol)
110-
111-
_tuple__reify = _reify.dispatch(tuple, dict)
113+
_reify.add((MetaSymbol, (KanrenState, dict)), _reify_MetaSymbol)
112114

113-
_reify.add((ExpressionTuple, dict), lambda x, s: _tuple__reify(x, s))
115+
_reify.add(
116+
(ExpressionTuple, (KanrenState, dict)), lambda x, s: _reify.dispatch(tuple, type(s))(x, s)
117+
)
114118

115119

116120
_isvar = isvar.dispatch(object)
@@ -175,7 +179,7 @@ def _term_ExpressionTuple(rand, rators):
175179
term.add((object, ExpressionTuple), _term_ExpressionTuple)
176180

177181

178-
@_reify.register(ExpressionTuple, dict)
182+
@_reify.register(ExpressionTuple, (KanrenState, dict))
179183
def _reify_ExpressionTuple(t, s):
180184
"""When `kanren` reifies `etuple`s, we don't want them to turn into regular `tuple`s.
181185

tests/test_goals.py

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from unification import var
22

3-
from kanren import eq # , run, lall
4-
# from kanren.core import goaleval
3+
from kanren import eq
54

65
from symbolic_pymc.relations import lconj, lconj_seq, ldisj, ldisj_seq, conde
76

@@ -82,29 +81,3 @@ def test_conde_basics():
8281
res = list(lconj(conde([eq(1, 2)], [eq(1, 1)]),
8382
conde([eq(1, 2)], [eq(1, 1)]))({}))
8483
assert res == [{}]
85-
86-
# def test_short_circuit_lconj():
87-
#
88-
# def one_bad_goal(goal_nums, good_goals=10, _eq=eq):
89-
# for i in goal_nums:
90-
# if i == good_goals:
91-
# def _g(S, i=i):
92-
# print('{} bad'.format(i))
93-
# yield from _eq(1, 2)(S)
94-
#
95-
# else:
96-
# def _g(S, i=i):
97-
# print('{} good'.format(i))
98-
# yield from _eq(1, 1)(S)
99-
#
100-
# yield _g
101-
#
102-
# goal_nums = iter(range(20))
103-
# run(0, var('q'), lall(*one_bad_goal(goal_nums)))
104-
#
105-
# # `kanren`'s `lall` will necessarily exhaust the generator.
106-
# next(goal_nums, None)
107-
#
108-
# goal_nums = iter(range(20))
109-
# run(0, var('q'), lconj(one_bad_goal(goal_nums)))
110-
# next(goal_nums, None)

0 commit comments

Comments
 (0)