Skip to content

Commit f51fa23

Browse files
Add slots and use weak references in constraint objects
1 parent a83dde4 commit f51fa23

File tree

2 files changed

+37
-11
lines changed

2 files changed

+37
-11
lines changed

symbolic_pymc/constraints.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import weakref
2+
13
from abc import ABC, abstractmethod
24
from types import MappingProxyType
3-
from collections import OrderedDict, defaultdict
5+
from collections import OrderedDict
46

57
from unification import unify, reify, Var
68
from unification.core import _unify, _reify
@@ -9,6 +11,16 @@
911
class KanrenConstraintStore(ABC):
1012
"""A class that enforces constraints between logic variables in a miniKanren state."""
1113

14+
__slots__ = ("mappings",)
15+
16+
def __init__(self, mappings=None):
17+
# Mappings between logic variables and their constraint values
18+
# (e.g. the values against which they cannot be unified).
19+
self.mappings = mappings if mappings is not None else dict()
20+
# TODO: We can't use this until `Var` is a factory returning unique
21+
# objects/references for a given `Var.token` value.
22+
# self.mappings = weakref.WeakKeyDictionary(mappings)
23+
1224
@abstractmethod
1325
def pre_check(self, state, key=None, value=None):
1426
"""Check a key-value pair before they're added to a KanrenState."""
@@ -33,7 +45,7 @@ def constraints_str(self, var):
3345
class KanrenState(dict):
3446
"""A miniKanren state that holds unifications of logic variables and upholds constraints on logic variables."""
3547

36-
__slots__ = ("constraints",)
48+
__slots__ = ("constraints", "__weakref__")
3749

3850
def __init__(self, *s, constraints=None):
3951
super().__init__(*s)
@@ -68,11 +80,10 @@ def __repr__(self):
6880
class Disequality(KanrenConstraintStore):
6981
"""A disequality constraint (i.e. two things do not unify)."""
7082

71-
def __init__(self, mappings=None):
72-
# Unallowed mappings
73-
self.mappings = mappings or defaultdict(set)
74-
7583
def post_check(self, new_state, key=None, value=None, old_state=None):
84+
# This implementation follows-up every addition to a `KanrenState` with
85+
# a consistency check against all the disequality constraints. It's
86+
# not particularly scalable, but it works for now.
7687
return not any(
7788
any(new_state == unify(lvar, val, new_state) for val in vals)
7889
for lvar, vals in self.mappings.items()
@@ -82,7 +93,12 @@ def pre_check(self, state, key=None, value=None):
8293
return True
8394

8495
def update(self, key, value):
85-
self.mappings[key].add(value)
96+
# In this case, logic variables are mapped to a set of values against
97+
# which they cannot unify.
98+
if key not in self.mappings:
99+
self.mappings[key] = {value}
100+
else:
101+
self.mappings[key].add(value)
86102

87103
def constraints_str(self, var):
88104
if var in self.mappings:
@@ -121,15 +137,20 @@ class ConstrainedVar(Var):
121137
122138
"""
123139

140+
__slots__ = ("_id", "token", "S", "var")
141+
124142
def __new__(cls, var, S):
125143
obj = super().__new__(cls, var.token)
126-
obj.S = S
127-
obj.var = var
144+
obj.S = weakref.ref(S)
145+
obj.var = weakref.ref(var)
128146
return obj
129147

130148
def __repr__(self):
131-
u_constraints = ",".join([c.constraints_str(self.var) for c in self.S.constraints.values()])
132-
return f"{self.var}: {{{u_constraints}}}"
149+
var = self.var()
150+
S = self.S()
151+
if var is not None and S is not None:
152+
u_constraints = ",".join([c.constraints_str(var) for c in S.constraints.values()])
153+
return f"{var}: {{{u_constraints}}}"
133154

134155

135156
def reify_KanrenState(u, S):

tests/test_kanren.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ def test_disequality():
7575
assert isinstance(res[0], KanrenState)
7676
assert res[0].constraints[Disequality].mappings[var('a')] == {1}
7777

78+
res = list(lconj(neq(var('a'), 1), neq(var('a'), 2), neq(var('a'), 1))({}))
79+
assert len(res) == 1
80+
assert isinstance(res[0], KanrenState)
81+
assert res[0].constraints[Disequality].mappings[var('a')] == {1, 2}
82+
7883
res = list(lconj(neq(var('a'), 1), eq(var('a'), 2))({}))
7984
assert len(res) == 1
8085
assert isinstance(res[0], KanrenState)

0 commit comments

Comments
 (0)