@@ -48,16 +48,14 @@ def __init__(self, *args, **kwargs):
4848 super ().__init__ (* args , ** kwargs )
4949
5050 @classmethod
51- def make_opdef_sig (cls , opdef ):
51+ def make_opdef_sig (cls , opdef , opdef_py_func = None ):
5252 """Create a `Signature` object for an `OpDef`.
5353
5454 Annotations are include so that one can partially verify arguments.
5555 """
5656 input_args = OrderedDict ([(a .name , a .type or a .type_attr ) for a in opdef .input_arg ])
5757 attrs = OrderedDict ([(a .name , a .type ) for a in opdef .attr ])
5858
59- opdef_py_func = getattr (tf .raw_ops , opdef .name , None )
60-
6159 params = OrderedDict ()
6260 if opdef_py_func :
6361 # We assume we're dealing with a function from `tf.raw_ops`.
@@ -86,6 +84,9 @@ def make_opdef_sig(cls, opdef):
8684 params [name ] = new_param
8785
8886 else :
87+ # We're crafting the Operation from a low-level via `apply_op`.
88+ opdef_py_func = partial (op_def_lib .apply_op , opdef .name )
89+
8990 for i_name , i_type in input_args .items ():
9091 p = Parameter (i_name , Parameter .POSITIONAL_OR_KEYWORD , annotation = i_type )
9192 params [i_name ] = p
@@ -117,15 +118,17 @@ def make_opdef_sig(cls, opdef):
117118 opdef_sig = Signature (
118119 params .values (), return_annotation = [(o .name , o .type_attr ) for o in opdef .output_arg ]
119120 )
120- return opdef_sig
121+ return opdef_sig , opdef_py_func
121122
122123 def add_op (self , opdef ):
123124 op_info = self ._ops .get (opdef .name , None )
124125 if op_info is None :
125126 super ().add_op (opdef )
126127 op_info = self ._ops [opdef .name ]
127- opdef_sig = self .make_opdef_sig (op_info .op_def )
128+ opdef_func = getattr (tf .raw_ops , opdef .name , None )
129+ opdef_sig , opdef_func = self .make_opdef_sig (op_info .op_def , opdef_func )
128130 op_info .opdef_sig = opdef_sig
131+ op_info .opdef_func = opdef_func
129132 return op_info
130133
131134 def get_opinfo (self , opdef ):
@@ -239,7 +242,7 @@ class TFlowMetaOpDef(MetaOp, TFlowMetaSymbol):
239242 def __init__ (self , obj = None ):
240243 op_info = op_def_lib .add_op (obj )
241244 self .apply_func_sig = op_info .opdef_sig
242- self .apply_func = partial ( op_def_lib . apply_op , obj . name )
245+ self .apply_func = op_info . opdef_func
243246 super ().__init__ (obj = obj )
244247
245248 def out_meta_types (self , inputs = None ):
0 commit comments