|
16 | 16 | TFlowMetaOpDef, |
17 | 17 | TFlowMetaNodeDef, |
18 | 18 | TFlowOpName, |
| 19 | + MetaOpDefLibrary, |
19 | 20 | mt) |
20 | 21 |
|
21 | 22 | from tests.tensorflow import run_in_graph_mode |
@@ -321,6 +322,47 @@ def test_inputs_remapping(): |
321 | 322 | assert z_mt.inputs[1].obj == z.op.inputs[2] |
322 | 323 |
|
323 | 324 |
|
| 325 | +@pytest.mark.usefixtures("run_with_tensorflow") |
| 326 | +def test_opdef_sig(): |
| 327 | + """Make sure we can construct an `inspect.Signature` object for a protobuf OpDef when its corresponding function isn't present in `tf.raw_ops`.""" |
| 328 | + from tensorflow.core.framework import op_def_pb2 |
| 329 | + |
| 330 | + custom_opdef_tf = op_def_pb2.OpDef() |
| 331 | + custom_opdef_tf.name = "MyOpDef" |
| 332 | + |
| 333 | + arg1_tf = op_def_pb2.OpDef.ArgDef() |
| 334 | + arg1_tf.name = "arg1" |
| 335 | + arg1_tf.type_attr = "T" |
| 336 | + |
| 337 | + arg2_tf = op_def_pb2.OpDef.ArgDef() |
| 338 | + arg2_tf.name = "arg2" |
| 339 | + arg2_tf.type_attr = "T" |
| 340 | + |
| 341 | + custom_opdef_tf.input_arg.extend([arg1_tf, arg2_tf]) |
| 342 | + |
| 343 | + attr1_tf = op_def_pb2.OpDef.AttrDef() |
| 344 | + attr1_tf.name = "T" |
| 345 | + attr1_tf.type = "type" |
| 346 | + |
| 347 | + attr2_tf = op_def_pb2.OpDef.AttrDef() |
| 348 | + attr2_tf.name = "axis" |
| 349 | + attr2_tf.type = "int" |
| 350 | + attr2_tf.default_value.i = 1 |
| 351 | + |
| 352 | + custom_opdef_tf.attr.extend([attr1_tf, attr2_tf]) |
| 353 | + |
| 354 | + opdef_sig = MetaOpDefLibrary.make_opdef_sig(custom_opdef_tf) |
| 355 | + |
| 356 | + import inspect |
| 357 | + # These are standard inputs |
| 358 | + assert opdef_sig.parameters['arg1'].default == inspect._empty |
| 359 | + assert opdef_sig.parameters['arg2'].default == inspect._empty |
| 360 | + # These are attributes that are sometimes required by the OpDef |
| 361 | + assert opdef_sig.parameters['axis'].default == inspect._empty |
| 362 | + # The obligatory tensor name parameter |
| 363 | + assert opdef_sig.parameters['name'].default is None |
| 364 | + |
| 365 | + |
324 | 366 | @pytest.mark.usefixtures("run_with_tensorflow") |
325 | 367 | @run_in_graph_mode |
326 | 368 | def test_nodedef(): |
|
0 commit comments