Skip to content

Commit ff60cbd

Browse files
Add missing test for OpDef Signature functionality
1 parent ef25173 commit ff60cbd

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

symbolic_pymc/tensorflow/meta.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,25 +86,26 @@ def make_opdef_sig(cls, opdef):
8686
params[name] = new_param
8787

8888
else:
89-
params = []
90-
for i_name, i_type in input_args:
89+
for i_name, i_type in input_args.items():
9190
p = Parameter(i_name, Parameter.POSITIONAL_OR_KEYWORD, annotation=i_type)
9291
params[i_name] = p
9392

9493
# These are the ambiguities we're attempting to overcome
9594
# with the `tf.raw_ops` functions above.
96-
for a_name, a_type in attrs:
95+
for a_name, a_type in attrs.items():
96+
9797
if a_name == "T":
9898
# This is a type value that will most likely be inferred
9999
# from/by the inputs.
100100
# TODO: We could check for an `allowed_values` attribute.
101101
continue
102+
102103
p = Parameter(
103104
a_name,
104105
Parameter.POSITIONAL_OR_KEYWORD,
105106
# TODO: We could use the `default_value`
106107
# attribute.
107-
default=None,
108+
default=Parameter.empty,
108109
annotation=a_type,
109110
)
110111
params[a_name] = p

tests/tensorflow/test_meta.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
TFlowMetaOpDef,
1717
TFlowMetaNodeDef,
1818
TFlowOpName,
19+
MetaOpDefLibrary,
1920
mt)
2021

2122
from tests.tensorflow import run_in_graph_mode
@@ -321,6 +322,47 @@ def test_inputs_remapping():
321322
assert z_mt.inputs[1].obj == z.op.inputs[2]
322323

323324

325+
@pytest.mark.usefixtures("run_with_tensorflow")
326+
def test_opdef_sig():
327+
"""Make sure we can construct an `inspect.Signature` object for a protobuf OpDef when its corresponding function isn't present in `tf.raw_ops`."""
328+
from tensorflow.core.framework import op_def_pb2
329+
330+
custom_opdef_tf = op_def_pb2.OpDef()
331+
custom_opdef_tf.name = "MyOpDef"
332+
333+
arg1_tf = op_def_pb2.OpDef.ArgDef()
334+
arg1_tf.name = "arg1"
335+
arg1_tf.type_attr = "T"
336+
337+
arg2_tf = op_def_pb2.OpDef.ArgDef()
338+
arg2_tf.name = "arg2"
339+
arg2_tf.type_attr = "T"
340+
341+
custom_opdef_tf.input_arg.extend([arg1_tf, arg2_tf])
342+
343+
attr1_tf = op_def_pb2.OpDef.AttrDef()
344+
attr1_tf.name = "T"
345+
attr1_tf.type = "type"
346+
347+
attr2_tf = op_def_pb2.OpDef.AttrDef()
348+
attr2_tf.name = "axis"
349+
attr2_tf.type = "int"
350+
attr2_tf.default_value.i = 1
351+
352+
custom_opdef_tf.attr.extend([attr1_tf, attr2_tf])
353+
354+
opdef_sig = MetaOpDefLibrary.make_opdef_sig(custom_opdef_tf)
355+
356+
import inspect
357+
# These are standard inputs
358+
assert opdef_sig.parameters['arg1'].default == inspect._empty
359+
assert opdef_sig.parameters['arg2'].default == inspect._empty
360+
# These are attributes that are sometimes required by the OpDef
361+
assert opdef_sig.parameters['axis'].default == inspect._empty
362+
# The obligatory tensor name parameter
363+
assert opdef_sig.parameters['name'].default is None
364+
365+
324366
@pytest.mark.usefixtures("run_with_tensorflow")
325367
@run_in_graph_mode
326368
def test_nodedef():

0 commit comments

Comments
 (0)