Skip to content

Commit 8cf3e81

Browse files
Use the correct MetaOpDef call signatures, function combination
1 parent 1e286a1 commit 8cf3e81

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

symbolic_pymc/tensorflow/meta.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,14 @@ def __init__(self, *args, **kwargs):
4848
super().__init__(*args, **kwargs)
4949

5050
@classmethod
51-
def make_opdef_sig(cls, opdef):
51+
def make_opdef_sig(cls, opdef, opdef_py_func=None):
5252
"""Create a `Signature` object for an `OpDef`.
5353
5454
Annotations are include so that one can partially verify arguments.
5555
"""
5656
input_args = OrderedDict([(a.name, a.type or a.type_attr) for a in opdef.input_arg])
5757
attrs = OrderedDict([(a.name, a.type) for a in opdef.attr])
5858

59-
opdef_py_func = getattr(tf.raw_ops, opdef.name, None)
60-
6159
params = OrderedDict()
6260
if opdef_py_func:
6361
# We assume we're dealing with a function from `tf.raw_ops`.
@@ -86,6 +84,9 @@ def make_opdef_sig(cls, opdef):
8684
params[name] = new_param
8785

8886
else:
87+
# We're crafting the Operation from a low-level via `apply_op`.
88+
opdef_py_func = partial(op_def_lib.apply_op, opdef.name)
89+
8990
for i_name, i_type in input_args.items():
9091
p = Parameter(i_name, Parameter.POSITIONAL_OR_KEYWORD, annotation=i_type)
9192
params[i_name] = p
@@ -117,15 +118,17 @@ def make_opdef_sig(cls, opdef):
117118
opdef_sig = Signature(
118119
params.values(), return_annotation=[(o.name, o.type_attr) for o in opdef.output_arg]
119120
)
120-
return opdef_sig
121+
return opdef_sig, opdef_py_func
121122

122123
def add_op(self, opdef):
123124
op_info = self._ops.get(opdef.name, None)
124125
if op_info is None:
125126
super().add_op(opdef)
126127
op_info = self._ops[opdef.name]
127-
opdef_sig = self.make_opdef_sig(op_info.op_def)
128+
opdef_func = getattr(tf.raw_ops, opdef.name, None)
129+
opdef_sig, opdef_func = self.make_opdef_sig(op_info.op_def, opdef_func)
128130
op_info.opdef_sig = opdef_sig
131+
op_info.opdef_func = opdef_func
129132
return op_info
130133

131134
def get_opinfo(self, opdef):
@@ -239,7 +242,7 @@ class TFlowMetaOpDef(MetaOp, TFlowMetaSymbol):
239242
def __init__(self, obj=None):
240243
op_info = op_def_lib.add_op(obj)
241244
self.apply_func_sig = op_info.opdef_sig
242-
self.apply_func = partial(op_def_lib.apply_op, obj.name)
245+
self.apply_func = op_info.opdef_func
243246
super().__init__(obj=obj)
244247

245248
def out_meta_types(self, inputs=None):

tests/tensorflow/test_meta.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def test_meta_create():
167167
with pytest.raises(TypeError):
168168
TFlowMetaTensor('float64', 'Add', name='q__')
169169

170+
170171
@pytest.mark.usefixtures("run_with_tensorflow")
171172
@run_in_graph_mode
172173
def test_meta_Op():
@@ -196,7 +197,6 @@ def test_meta_Op():
196197
assert MetaSymbol.is_meta(test_op.outputs[0])
197198

198199

199-
200200
@pytest.mark.usefixtures("run_with_tensorflow")
201201
def test_meta_lvars():
202202
"""Make sure we can use lvars as values."""
@@ -381,7 +381,7 @@ def test_opdef_sig():
381381

382382
custom_opdef_tf.attr.extend([attr1_tf, attr2_tf])
383383

384-
opdef_sig = MetaOpDefLibrary.make_opdef_sig(custom_opdef_tf)
384+
opdef_sig, opdef_func = MetaOpDefLibrary.make_opdef_sig(custom_opdef_tf)
385385

386386
import inspect
387387
# These are standard inputs
@@ -431,3 +431,13 @@ class CustomClass(object):
431431

432432
with pytest.raises(ValueError):
433433
mt(CustomClass())
434+
435+
436+
@pytest.mark.usefixtures("run_with_tensorflow")
437+
@run_in_graph_mode
438+
def test_opdef_func():
439+
sum_mt = mt.Sum([[1, 2]], [1])
440+
sum_tf = sum_mt.reify()
441+
442+
with tf.compat.v1.Session() as sess:
443+
assert sum_tf.eval() == np.r_[3]

0 commit comments

Comments
 (0)