1616 TFlowMetaOpDef ,
1717 TFlowMetaNodeDef ,
1818 TFlowOpName ,
19+ MetaOpDefLibrary ,
1920 mt )
2021
2122from tests .tensorflow import run_in_graph_mode
@@ -321,6 +322,47 @@ def test_inputs_remapping():
321322 assert z_mt .inputs [1 ].obj == z .op .inputs [2 ]
322323
323324
325+ @pytest .mark .usefixtures ("run_with_tensorflow" )
326+ def test_opdef_sig ():
327+ """Make sure we can construct an `inspect.Signature` object for a protobuf OpDef when its corresponding function isn't present in `tf.raw_ops`."""
328+ from tensorflow .core .framework import op_def_pb2
329+
330+ custom_opdef_tf = op_def_pb2 .OpDef ()
331+ custom_opdef_tf .name = "MyOpDef"
332+
333+ arg1_tf = op_def_pb2 .OpDef .ArgDef ()
334+ arg1_tf .name = "arg1"
335+ arg1_tf .type_attr = "T"
336+
337+ arg2_tf = op_def_pb2 .OpDef .ArgDef ()
338+ arg2_tf .name = "arg2"
339+ arg2_tf .type_attr = "T"
340+
341+ custom_opdef_tf .input_arg .extend ([arg1_tf , arg2_tf ])
342+
343+ attr1_tf = op_def_pb2 .OpDef .AttrDef ()
344+ attr1_tf .name = "T"
345+ attr1_tf .type = "type"
346+
347+ attr2_tf = op_def_pb2 .OpDef .AttrDef ()
348+ attr2_tf .name = "axis"
349+ attr2_tf .type = "int"
350+ attr2_tf .default_value .i = 1
351+
352+ custom_opdef_tf .attr .extend ([attr1_tf , attr2_tf ])
353+
354+ opdef_sig = MetaOpDefLibrary .make_opdef_sig (custom_opdef_tf )
355+
356+ import inspect
357+ # These are standard inputs
358+ assert opdef_sig .parameters ['arg1' ].default == inspect ._empty
359+ assert opdef_sig .parameters ['arg2' ].default == inspect ._empty
360+ # These are attributes that are sometimes required by the OpDef
361+ assert opdef_sig .parameters ['axis' ].default == inspect ._empty
362+ # The obligatory tensor name parameter
363+ assert opdef_sig .parameters ['name' ].default is None
364+
365+
324366@pytest .mark .usefixtures ("run_with_tensorflow" )
325367@run_in_graph_mode
326368def test_nodedef ():
@@ -337,3 +379,25 @@ def test_nodedef():
337379 norm_rv = mt .RandomStandardNormal (mean = 0 , stddev = 1 , shape = (1000 ,), dtype = tf .float32 , name = var ())
338380 assert isinstance (norm_rv , TFlowMetaTensor )
339381 assert norm_rv .dtype == tf .float32
382+
383+ # We shouldn't be metatizing all parsed `node_def.attr` values; otherwise,
384+ # we won't be able to reconstruct corresponding meta Ops using their meta
385+ # OpDefs and inputs.
386+ x_test = tf .constant ([1.8 , 2.2 ], dtype = tf .float32 )
387+ y_test = tf .dtypes .cast (x_test , dtype = tf .int32 , name = "y" )
388+ y_test_mt = mt (y_test )
389+
390+ # `ytest_mt.inputs` should have two `.attr` values that are Python
391+ # primitives (i.e. int and bool); these shouldn't get metatized and break
392+ # our ability to reconstruct the object from its rator + rands.
393+ assert y_test_mt == y_test_mt .op .op_def (* y_test_mt .inputs )
394+
395+
396+ @pytest .mark .usefixtures ("run_with_tensorflow" )
397+ @run_in_graph_mode
398+ def test_metatize ():
399+ class CustomClass (object ):
400+ pass
401+
402+ with pytest .raises (ValueError ):
403+ mt (CustomClass ())
0 commit comments