Skip to content

Commit 9743bcd

Browse files
Merge pull request #69 from brandonwillard/really-fix-numpy-unify
Fix unify when walked values are Numpy arrays
2 parents 9b81826 + 2c6b4ee commit 9743bcd

File tree

5 files changed

+48
-15
lines changed

5 files changed

+48
-15
lines changed

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
pydocstyle>=3.0.0
2-
pytest>=4.2.0
2+
pytest>=5.0.0
33
pytest-cov>=2.6.1
44
pytest-html>=1.20.0
55
pylint>=2.3.1

symbolic_pymc/tensorflow/meta.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,14 @@ def __init__(self, *args, **kwargs):
4848
super().__init__(*args, **kwargs)
4949

5050
@classmethod
51-
def make_opdef_sig(cls, opdef):
51+
def make_opdef_sig(cls, opdef, opdef_py_func=None):
5252
"""Create a `Signature` object for an `OpDef`.
5353
5454
Annotations are include so that one can partially verify arguments.
5555
"""
5656
input_args = OrderedDict([(a.name, a.type or a.type_attr) for a in opdef.input_arg])
5757
attrs = OrderedDict([(a.name, a.type) for a in opdef.attr])
5858

59-
opdef_py_func = getattr(tf.raw_ops, opdef.name, None)
60-
6159
params = OrderedDict()
6260
if opdef_py_func:
6361
# We assume we're dealing with a function from `tf.raw_ops`.
@@ -86,6 +84,9 @@ def make_opdef_sig(cls, opdef):
8684
params[name] = new_param
8785

8886
else:
87+
# We're crafting the Operation from a low-level via `apply_op`.
88+
opdef_py_func = partial(op_def_lib.apply_op, opdef.name)
89+
8990
for i_name, i_type in input_args.items():
9091
p = Parameter(i_name, Parameter.POSITIONAL_OR_KEYWORD, annotation=i_type)
9192
params[i_name] = p
@@ -117,15 +118,17 @@ def make_opdef_sig(cls, opdef):
117118
opdef_sig = Signature(
118119
params.values(), return_annotation=[(o.name, o.type_attr) for o in opdef.output_arg]
119120
)
120-
return opdef_sig
121+
return opdef_sig, opdef_py_func
121122

122123
def add_op(self, opdef):
123124
op_info = self._ops.get(opdef.name, None)
124125
if op_info is None:
125126
super().add_op(opdef)
126127
op_info = self._ops[opdef.name]
127-
opdef_sig = self.make_opdef_sig(op_info.op_def)
128+
opdef_func = getattr(tf.raw_ops, opdef.name, None)
129+
opdef_sig, opdef_func = self.make_opdef_sig(op_info.op_def, opdef_func)
128130
op_info.opdef_sig = opdef_sig
131+
op_info.opdef_func = opdef_func
129132
return op_info
130133

131134
def get_opinfo(self, opdef):
@@ -239,7 +242,7 @@ class TFlowMetaOpDef(MetaOp, TFlowMetaSymbol):
239242
def __init__(self, obj=None):
240243
op_info = op_def_lib.add_op(obj)
241244
self.apply_func_sig = op_info.opdef_sig
242-
self.apply_func = partial(op_def_lib.apply_op, obj.name)
245+
self.apply_func = op_info.opdef_func
243246
super().__init__(obj=obj)
244247

245248
def out_meta_types(self, inputs=None):

symbolic_pymc/unify.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,23 +43,31 @@ def wrapper(*args, **kwargs):
4343
_unify._cache.clear()
4444

4545

46+
_base_unify = unify.dispatch(object, object, dict)
47+
48+
4649
def unify_numpy(u, v, s):
4750
"""Handle NumPy arrays in a special way to avoid warnings/exceptions."""
51+
u = walk(u, s)
4852
v = walk(v, s)
53+
54+
if u is v:
55+
return s
4956
if isvar(u):
5057
return assoc(s, u, v)
5158
if isvar(v):
5259
return assoc(s, v, u)
53-
# Switch the order of comparison so that `v.__eq__` is tried (in case it's
54-
# not also a NumPy array, but has logic for such comparisons)
55-
if np.array_equal(v, u):
60+
61+
if isinstance(u, np.ndarray) or isinstance(v, np.ndarray):
62+
if np.array_equal(v, u):
63+
return s
64+
elif u == v:
5665
return s
66+
5767
return _unify(u, v, s)
5868

5969

60-
unify.add((np.ndarray, object, dict), unify_numpy)
61-
unify.add((object, np.ndarray, dict), unify_numpy)
62-
unify.add((np.ndarray, np.ndarray, dict), unify_numpy)
70+
unify.add((object, object, dict), unify_numpy)
6371

6472

6573
def unify_MetaSymbol(u, v, s):

tests/tensorflow/test_meta.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def test_meta_create():
167167
with pytest.raises(TypeError):
168168
TFlowMetaTensor('float64', 'Add', name='q__')
169169

170+
170171
@pytest.mark.usefixtures("run_with_tensorflow")
171172
@run_in_graph_mode
172173
def test_meta_Op():
@@ -196,7 +197,6 @@ def test_meta_Op():
196197
assert MetaSymbol.is_meta(test_op.outputs[0])
197198

198199

199-
200200
@pytest.mark.usefixtures("run_with_tensorflow")
201201
def test_meta_lvars():
202202
"""Make sure we can use lvars as values."""
@@ -381,7 +381,7 @@ def test_opdef_sig():
381381

382382
custom_opdef_tf.attr.extend([attr1_tf, attr2_tf])
383383

384-
opdef_sig = MetaOpDefLibrary.make_opdef_sig(custom_opdef_tf)
384+
opdef_sig, opdef_func = MetaOpDefLibrary.make_opdef_sig(custom_opdef_tf)
385385

386386
import inspect
387387
# These are standard inputs
@@ -431,3 +431,13 @@ class CustomClass(object):
431431

432432
with pytest.raises(ValueError):
433433
mt(CustomClass())
434+
435+
436+
@pytest.mark.usefixtures("run_with_tensorflow")
437+
@run_in_graph_mode
438+
def test_opdef_func():
439+
sum_mt = mt.Sum([[1, 2]], [1])
440+
sum_tf = sum_mt.reify()
441+
442+
with tf.compat.v1.Session() as sess:
443+
assert sum_tf.eval() == np.r_[3]

tests/test_unify.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,15 @@ def test_numpy():
3636
s = unify([1, var('a')], np_array)
3737

3838
assert s is False
39+
40+
s = unify(var('a'), 2, {var('a'): np_array})
41+
42+
assert s is False
43+
44+
s = unify(var('a'), var('b'), {var('a'): np_array})
45+
46+
assert s[var('a')] is np_array
47+
assert s[var('b')] is np_array
48+
49+
s = unify(np_array, np_array)
50+
assert s == {}

0 commit comments

Comments
 (0)