Skip to content

Commit 1e286a1

Browse files
Fix unify when walked values are Numpy arrays
1 parent 9b81826 commit 1e286a1

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

symbolic_pymc/unify.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,23 +43,31 @@ def wrapper(*args, **kwargs):
4343
_unify._cache.clear()
4444

4545

46+
_base_unify = unify.dispatch(object, object, dict)
47+
48+
4649
def unify_numpy(u, v, s):
4750
"""Handle NumPy arrays in a special way to avoid warnings/exceptions."""
51+
u = walk(u, s)
4852
v = walk(v, s)
53+
54+
if u is v:
55+
return s
4956
if isvar(u):
5057
return assoc(s, u, v)
5158
if isvar(v):
5259
return assoc(s, v, u)
53-
# Switch the order of comparison so that `v.__eq__` is tried (in case it's
54-
# not also a NumPy array, but has logic for such comparisons)
55-
if np.array_equal(v, u):
60+
61+
if isinstance(u, np.ndarray) or isinstance(v, np.ndarray):
62+
if np.array_equal(v, u):
63+
return s
64+
elif u == v:
5665
return s
66+
5767
return _unify(u, v, s)
5868

5969

60-
unify.add((np.ndarray, object, dict), unify_numpy)
61-
unify.add((object, np.ndarray, dict), unify_numpy)
62-
unify.add((np.ndarray, np.ndarray, dict), unify_numpy)
70+
unify.add((object, object, dict), unify_numpy)
6371

6472

6573
def unify_MetaSymbol(u, v, s):

tests/test_unify.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,15 @@ def test_numpy():
3636
s = unify([1, var('a')], np_array)
3737

3838
assert s is False
39+
40+
s = unify(var('a'), 2, {var('a'): np_array})
41+
42+
assert s is False
43+
44+
s = unify(var('a'), var('b'), {var('a'): np_array})
45+
46+
assert s[var('a')] is np_array
47+
assert s[var('b')] is np_array
48+
49+
s = unify(np_array, np_array)
50+
assert s == {}

0 commit comments

Comments
 (0)