Skip to content

Commit fe370f5

Browse files
Add tracking information and tests to meta_parts_unequal
1 parent 8d4a3e4 commit fe370f5

File tree

2 files changed

+248
-25
lines changed

2 files changed

+248
-25
lines changed

symbolic_pymc/utils.py

Lines changed: 83 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
import numpy as np
22

3-
import symbolic_pymc as sp
3+
from operator import ne, attrgetter, itemgetter
4+
from collections import namedtuple
5+
from collections.abc import Hashable, Sequence, Mapping
6+
7+
from unification import isvar, Var
48

5-
from collections import Hashable
9+
from toolz import compose
10+
11+
import symbolic_pymc as sp
612

713

814
class HashableNDArray(np.ndarray, Hashable):
@@ -31,40 +37,92 @@ def __ne__(self, other):
3137
return NotImplemented
3238

3339

34-
def meta_parts_unequal(x, y, pdb=False): # pragma: no cover
35-
"""Traverse meta objects and return the first pair of elements that are not equal."""
40+
UnequalMetaParts = namedtuple("UnequalMetaParts", ["path", "reason", "objects"])
41+
42+
43+
def meta_diff_seq(x, y, loc, path, is_map=False, **kwargs):
44+
if len(x) != len(y):
45+
return (path, f"{loc} len", (x, y))
46+
else:
47+
for i, (a, b) in enumerate(zip(x, y)):
48+
if is_map:
49+
if a[0] != b[0]:
50+
return (path, "map keys", (x, y))
51+
this_path = compose(itemgetter(a[0]), path)
52+
a, b = a[1], b[1]
53+
else:
54+
this_path = compose(itemgetter(i), path)
55+
56+
z = meta_diff(a, b, path=this_path, **kwargs)
57+
if z is not None:
58+
return z
59+
break
60+
61+
62+
def meta_diff(x, y, pdb=False, ne_fn=ne, cmp_types=True, path=compose()):
63+
"""Traverse meta objects and return information about the first pair of elements that are not equal.
64+
65+
Returns a `UnequalMetaParts` object containing the object path, reason for
66+
being unequal, and the unequal object pair; otherwise, `None`.
67+
"""
3668
res = None
37-
if type(x) != type(y):
38-
print("unequal types")
39-
res = (x, y)
69+
if cmp_types and ne_fn(type(x), type(y)) is True:
70+
res = (path, "types", (x, y))
4071
elif isinstance(x, sp.meta.MetaSymbol):
41-
if x.base != y.base:
42-
print("unequal bases")
43-
res = (x.base, y.base)
72+
if ne_fn(x.base, y.base) is True:
73+
res = (path, "bases", (x.base, y.base))
4474
else:
4575
try:
4676
x_rands = x.rands
4777
y_rands = y.rands
4878
except NotImplementedError:
4979
pass
5080
else:
51-
for a, b in zip(x_rands, y_rands):
52-
z = meta_parts_unequal(a, b, pdb=pdb)
53-
if z is not None:
54-
res = z
55-
break
56-
elif isinstance(x, (tuple, list)):
57-
for a, b in zip(x, y):
58-
z = meta_parts_unequal(a, b, pdb=pdb)
59-
if z is not None:
60-
res = z
61-
break
62-
elif not x == y:
63-
res = (x, y)
81+
82+
path = compose(attrgetter("rands"), path)
83+
84+
res = meta_diff_seq(
85+
x_rands, y_rands, "rands", path, pdb=pdb, ne_fn=ne_fn, cmp_types=cmp_types
86+
)
87+
88+
elif isinstance(x, Mapping) and isinstance(y, Mapping):
89+
90+
x_ = sorted(x.items(), key=itemgetter(0))
91+
y_ = sorted(y.items(), key=itemgetter(0))
92+
93+
res = meta_diff_seq(
94+
x_, y_, "map", path, is_map=True, pdb=pdb, ne_fn=ne_fn, cmp_types=cmp_types
95+
)
96+
97+
elif (
98+
isinstance(x, Sequence)
99+
and isinstance(y, Sequence)
100+
and not isinstance(x, str)
101+
and not isinstance(y, str)
102+
):
103+
104+
res = meta_diff_seq(x, y, "seq", path, pdb=pdb, ne_fn=ne_fn, cmp_types=cmp_types)
105+
106+
elif ne_fn(x, y) is True:
107+
res = (path, "ne_fn", (x, y))
64108

65109
if res is not None:
66-
if pdb:
110+
if pdb: # pragma: no cover
67111
import pdb
68112

69113
pdb.set_trace()
70-
return res
114+
return UnequalMetaParts(*res)
115+
116+
117+
def lvar_ignore_ne(x, y):
118+
if (isvar(x) and isvar(y)) or (
119+
isinstance(x, type) and isinstance(y, type) and issubclass(x, Var) and issubclass(y, Var)
120+
):
121+
return False
122+
else:
123+
return ne(x, y)
124+
125+
126+
def eq_lvar(x, y):
127+
"""Perform an equality check that considers all logic variables equal."""
128+
return meta_diff(x, y, ne_fn=lvar_ignore_ne) is None

tests/test_utils.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
from unification import var
2+
3+
from symbolic_pymc.meta import MetaSymbol, MetaOp
4+
from symbolic_pymc.utils import meta_diff, eq_lvar
5+
6+
7+
class SomeOp(object):
8+
def __repr__(self):
9+
return "<SomeOp>"
10+
11+
12+
class SomeType(object):
13+
def __init__(self, field1, field2):
14+
self.field1 = field1
15+
self.field2 = field2
16+
17+
def __repr__(self):
18+
return f"SomeType({self.field1}, {self.field2})"
19+
20+
def __str__(self):
21+
return f"SomeType<{self.field1}, {self.field2}>"
22+
23+
24+
class SomeMetaSymbol(MetaSymbol):
25+
__slots__ = ("field1", "field2", "_blah")
26+
base = SomeType
27+
28+
def __init__(self, obj=None):
29+
super().__init__(obj)
30+
self.field1 = 1
31+
self.field2 = 2
32+
self._blah = "a"
33+
34+
35+
class SomeMetaOp(MetaOp):
36+
__slots__ = ()
37+
base = SomeOp
38+
39+
def output_meta_types(self):
40+
return [SomeMetaSymbol]
41+
42+
def __call__(self, *args, **kwargs):
43+
return SomeMetaSymbol(*args, **kwargs)
44+
45+
46+
class SomeOtherMetaSymbol(MetaSymbol):
47+
__slots__ = ("field1", "field2")
48+
base = SomeType
49+
50+
def __init__(self, field1, field2, obj=None):
51+
super().__init__(obj)
52+
self.field1 = field1
53+
self.field2 = field2
54+
55+
56+
class SomeOtherOp(object):
57+
def __repr__(self):
58+
return "<SomeOp>"
59+
60+
61+
class SomeOtherMetaOp(SomeMetaOp):
62+
base = SomeOtherOp
63+
64+
65+
def test_parts_unequal():
66+
s0 = SomeMetaSymbol()
67+
s1 = SomeOtherMetaSymbol(1, 2)
68+
69+
res = meta_diff(s0, s1)
70+
assert res.reason == "types"
71+
assert res.path(s0) is s0
72+
assert res.objects == (s0, s1)
73+
74+
res = meta_diff(s0, s1, cmp_types=False)
75+
assert res is None
76+
77+
s2 = SomeOtherMetaSymbol(1, 3)
78+
res = meta_diff(s0, s2, cmp_types=False)
79+
assert res.path(s2) == 3
80+
assert res.path(s1) == 2
81+
assert res.reason == "ne_fn"
82+
assert res.objects == (2, 3)
83+
84+
res = meta_diff(SomeMetaOp(), SomeMetaOp())
85+
assert res is None
86+
87+
op1 = SomeMetaOp()
88+
op2 = SomeOtherMetaOp()
89+
res = meta_diff(op1, op2, cmp_types=False)
90+
assert res.path(op1) is op1
91+
assert res.reason == "bases"
92+
assert res.objects == (op1.base, op2.base)
93+
94+
a = SomeOtherMetaSymbol(1, [2, SomeOtherMetaSymbol(3, 4)])
95+
b = SomeOtherMetaSymbol(1, [2, SomeOtherMetaSymbol(3, 5)])
96+
res = meta_diff(a, b)
97+
98+
assert res.path(a) == 4
99+
assert res.path(b) == 5
100+
assert res.reason == "ne_fn"
101+
assert res.objects == (4, 5)
102+
103+
a = SomeOtherMetaSymbol(1, [2, SomeOtherMetaSymbol(3, 4)])
104+
b = SomeOtherMetaSymbol(1, [2, SomeOtherMetaSymbol(3, 4)])
105+
res = meta_diff(a, b)
106+
assert res is None
107+
108+
a = SomeOtherMetaSymbol(1, [2, SomeOtherMetaSymbol(3, 4), 5])
109+
b = SomeOtherMetaSymbol(1, [2, SomeOtherMetaSymbol(3, 4)])
110+
res = meta_diff(a, b)
111+
assert res is not None
112+
assert res.reason == "seq len"
113+
114+
a = SomeOtherMetaSymbol(1, ["a", "b"])
115+
b = SomeOtherMetaSymbol(1, 2)
116+
res = meta_diff(a, b, cmp_types=False)
117+
assert res is not None
118+
assert res.reason == "ne_fn"
119+
120+
a = SomeOtherMetaSymbol(1, ["a", "b"])
121+
b = SomeOtherMetaSymbol(1, "ab")
122+
res = meta_diff(a, b, cmp_types=False)
123+
assert res is not None
124+
125+
a = SomeOtherMetaSymbol(1, {"a": 1, "b": 2})
126+
b = SomeOtherMetaSymbol(1, {"b": 2, "a": 1})
127+
res = meta_diff(a, b)
128+
assert res is None
129+
130+
a = SomeOtherMetaSymbol(1, {"a": 1, "b": 2})
131+
b = SomeOtherMetaSymbol(1, {"b": 3, "a": 1})
132+
res = meta_diff(a, b)
133+
assert res.reason == "ne_fn"
134+
assert res.objects == (2, 3)
135+
assert res.path(a) == 2
136+
assert res.path(b) == 3
137+
138+
a = SomeOtherMetaSymbol(1, {"a": 1, "b": 2})
139+
b = SomeOtherMetaSymbol(1, {"a": 1, "c": 2})
140+
res = meta_diff(a, b)
141+
assert res.reason == "map keys"
142+
assert res.path(a) == {"a": 1, "b": 2}
143+
assert res.objects == ([("a", 1), ("b", 2)], [("a", 1), ("c", 2)])
144+
145+
146+
def test_eq_lvar():
147+
a = SomeOtherMetaSymbol(1, [2, SomeOtherMetaSymbol(3, 4)])
148+
b = SomeOtherMetaSymbol(1, [2, SomeOtherMetaSymbol(3, 4)])
149+
assert eq_lvar(a, b) is True
150+
151+
a = SomeOtherMetaSymbol(1, [2, SomeOtherMetaSymbol(3, 4)])
152+
b = SomeOtherMetaSymbol(1, [2, var()])
153+
assert eq_lvar(a, b) is False
154+
155+
a = SomeOtherMetaSymbol(1, [2, var()])
156+
b = SomeOtherMetaSymbol(1, [2, var()])
157+
assert eq_lvar(a, b) is True
158+
159+
a = SomeOtherMetaSymbol(1, [2, {"a": var()}])
160+
b = SomeOtherMetaSymbol(1, [2, {"a": var()}])
161+
assert eq_lvar(a, b) is True
162+
163+
a = SomeOtherMetaSymbol(1, [3, var()])
164+
b = SomeOtherMetaSymbol(1, [2, var()])
165+
assert eq_lvar(a, b) is False

0 commit comments

Comments
 (0)