File tree Expand file tree Collapse file tree 2 files changed +26
-6
lines changed Expand file tree Collapse file tree 2 files changed +26
-6
lines changed Original file line number Diff line number Diff line change @@ -43,23 +43,31 @@ def wrapper(*args, **kwargs):
4343 _unify ._cache .clear ()
4444
4545
46+ _base_unify = unify .dispatch (object , object , dict )
47+
48+
4649def unify_numpy (u , v , s ):
4750 """Handle NumPy arrays in a special way to avoid warnings/exceptions."""
51+ u = walk (u , s )
4852 v = walk (v , s )
53+
54+ if u is v :
55+ return s
4956 if isvar (u ):
5057 return assoc (s , u , v )
5158 if isvar (v ):
5259 return assoc (s , v , u )
53- # Switch the order of comparison so that `v.__eq__` is tried (in case it's
54- # not also a NumPy array, but has logic for such comparisons)
55- if np .array_equal (v , u ):
60+
61+ if isinstance (u , np .ndarray ) or isinstance (v , np .ndarray ):
62+ if np .array_equal (v , u ):
63+ return s
64+ elif u == v :
5665 return s
66+
5767 return _unify (u , v , s )
5868
5969
60- unify .add ((np .ndarray , object , dict ), unify_numpy )
61- unify .add ((object , np .ndarray , dict ), unify_numpy )
62- unify .add ((np .ndarray , np .ndarray , dict ), unify_numpy )
70+ unify .add ((object , object , dict ), unify_numpy )
6371
6472
6573def unify_MetaSymbol (u , v , s ):
Original file line number Diff line number Diff line change @@ -36,3 +36,15 @@ def test_numpy():
3636 s = unify ([1 , var ('a' )], np_array )
3737
3838 assert s is False
39+
40+ s = unify (var ('a' ), 2 , {var ('a' ): np_array })
41+
42+ assert s is False
43+
44+ s = unify (var ('a' ), var ('b' ), {var ('a' ): np_array })
45+
46+ assert s [var ('a' )] is np_array
47+ assert s [var ('b' )] is np_array
48+
49+ s = unify (np_array , np_array )
50+ assert s == {}
You can’t perform that action at this time.
0 commit comments