3737)
3838
3939
40- class MetaOpDefLibrary (op_def_library .OpDefLibrary ):
41- def __init__ (self , * args , ** kwargs ):
42- # This is a lame way to fix the numerous naming inconsistencies between
43- # TF `Operation`s, `OpDef`s, and the corresponding user-level functions.
44- self .lower_op_name_to_raw = {
45- op_name .lower (): op_name
46- for op_name in dir (tf .raw_ops )
47- if callable (getattr (tf .raw_ops , op_name ))
48- }
49- super ().__init__ (* args , ** kwargs )
40+ class MetaOpDefLibrary (object ):
41+
42+ lower_op_name_to_raw = {
43+ op_name .lower (): op_name
44+ for op_name in dir (tf .raw_ops )
45+ if callable (getattr (tf .raw_ops , op_name ))
46+ }
47+ opdef_signatures = {}
48+
49+ @classmethod
50+ def apply_op (cls , * args , ** kwargs ):
51+ return op_def_library .apply_op (* args , ** kwargs )
5052
5153 @classmethod
5254 def make_opdef_sig (cls , opdef , opdef_py_func = None ):
@@ -121,21 +123,22 @@ def make_opdef_sig(cls, opdef, opdef_py_func=None):
121123 )
122124 return opdef_sig , opdef_py_func
123125
124- def add_op (self , opdef ):
125- op_info = self ._ops .get (opdef .name , None )
126- if op_info is None :
127- super ().add_op (opdef )
128- op_info = self ._ops [opdef .name ]
126+ @classmethod
127+ def get_op_info (cls , opdef ):
128+ if isinstance (opdef , str ):
129+ opdef_name = opdef
130+ opdef = op_def_registry .get (opdef_name )
131+ else :
132+ opdef_name = opdef .name
133+
134+ opdef_sig = cls .opdef_signatures .get (opdef_name , None )
135+
136+ if opdef_sig is None and opdef is not None :
129137 opdef_func = getattr (tf .raw_ops , opdef .name , None )
130- opdef_sig , opdef_func = self .make_opdef_sig (op_info .op_def , opdef_func )
131- op_info .opdef_sig = opdef_sig
132- op_info .opdef_func = opdef_func
133- return op_info
138+ opdef_sig = cls .make_opdef_sig (opdef , opdef_func )
139+ cls .opdef_signatures [opdef .name ] = cls .make_opdef_sig (opdef , opdef_func )
134140
135- def get_opinfo (self , opdef ):
136- if isinstance (opdef , str ):
137- opdef = op_def_registry .get_registered_ops ()[opdef ]
138- return self .add_op (opdef )
141+ return opdef_sig
139142
140143
141144op_def_lib = MetaOpDefLibrary ()
@@ -251,7 +254,7 @@ class TFlowMetaOpDef(MetaOp, metaclass=OpDefFactoryType):
251254 >>> from google.protobuf import json_format
252255 >>> print(json_format.MessageToJson(opdef))
253256 - If you want to use an `OpDef` to construct a node, see
254- `op_def_library.OpDefLibrary. apply_op`.
257+ `op_def_library.apply_op`.
255258
256259 """
257260
@@ -260,9 +263,7 @@ class TFlowMetaOpDef(MetaOp, metaclass=OpDefFactoryType):
260263
261264 def __init__ (self , obj = None ):
262265 super ().__init__ (obj = obj )
263- op_info = op_def_lib .add_op (obj )
264- self ._apply_func_sig = op_info .opdef_sig
265- self ._apply_func = op_info .opdef_func
266+ self ._apply_func_sig , self ._apply_func = op_def_lib .get_op_info (obj )
266267
267268 def out_meta_types (self , inputs = None ):
268269 def _convert_outputs (o ):
@@ -411,8 +412,8 @@ def __init__(self, op, name, attr, obj=None):
411412 # We want to limit the attributes we'll consider to those that show
412413 # up in an OpDef function's signature (e.g. ignore info about
413414 # permissible types).
414- opinfo = op_def_lib .get_opinfo (self .op )
415- op_param_names = opinfo . opdef_sig .parameters .keys ()
415+ opdef_sig , _ = op_def_lib .get_op_info (self .op )
416+ op_param_names = opdef_sig .parameters .keys ()
416417
417418 _attr = dict ()
418419 for k , v in attr .items ():
@@ -496,7 +497,7 @@ def __init__(self, op_def, node_def, inputs, outputs=None, obj=None):
496497 super ().__init__ (obj = obj )
497498
498499 if isinstance (op_def , str ):
499- op_def = op_def_registry .get_registered_ops ()[ op_def ]
500+ op_def = op_def_registry .get ( op_def )
500501
501502 self .op_def = metatize (op_def )
502503 self .node_def = metatize (node_def )
@@ -798,7 +799,7 @@ def __call__(self, x):
798799 def find_opdef (cls , name ):
799800 """Attempt to create a meta `OpDef` for a given TF function/`Operation` name."""
800801 raw_op_name = op_def_lib .lower_op_name_to_raw .get (name .lower (), name )
801- op_def = op_def_registry .get_registered_ops ()[ raw_op_name ]
802+ op_def = op_def_registry .get ( raw_op_name )
802803
803804 if op_def is not None :
804805 meta_obj = TFlowMetaOpDef (obj = op_def )
0 commit comments