@@ -820,15 +820,34 @@ class TFlowMetaOperator(TFlowMetaSymbol, MetaOp):
820820 base = None
821821 __slots__ = ("op_def" , "node_def" , "_apply_func_sig" , "_apply_func" )
822822
823+ @classmethod
824+ def get_metaopdef (cls , name ):
825+ """Obtain a MetaOpDef for a given string name.
826+
827+ This is more flexible because it ignores things like string case
828+ (when the non-`raw_ops` name differs from the TF user-level API).
829+ """
830+ raw_op_name = op_def_lib .lower_op_name_to_raw .get (name .lower (), name )
831+ op_def = op_def_registry .get (raw_op_name )
832+ if op_def is not None :
833+ return TFlowMetaOpDef (obj = op_def )
834+
823835 def __init__ (self , op_def , node_def = None , obj = None ):
824836 assert obj is None
825837 super ().__init__ (None )
826838
827839 self .op_def = op_def
828- if not isvar (self .op_def ):
829- self ._apply_func_sig , self ._apply_func = op_def_lib .get_op_info (self .op_def .obj )
830- else :
840+
841+ if isinstance (self .op_def , str ):
842+ self .op_def = self .get_metaopdef (self .op_def )
843+
844+ if self .op_def is None :
845+ raise ValueError (f"Could not find an OpDef for { op_def } " )
846+
847+ if isvar (self .op_def ):
831848 self ._apply_func_sig , self ._apply_func = None , None
849+ else :
850+ self ._apply_func_sig , self ._apply_func = op_def_lib .get_op_info (self .op_def .obj )
832851
833852 self .node_def = node_def
834853
@@ -1097,18 +1116,6 @@ def __init__(self, namespace=None):
10971116 def __call__ (self , x ):
10981117 return metatize (x )
10991118
1100- @classmethod
1101- def find_operator (cls , name ):
1102- """Attempt to create a meta operator for a given TF function/`Operation` name."""
1103- raw_op_name = op_def_lib .lower_op_name_to_raw .get (name .lower (), name )
1104- op_def = op_def_registry .get (raw_op_name )
1105-
1106- if op_def is not None :
1107- meta_obj = TFlowMetaOperator (TFlowMetaOpDef (obj = op_def ), None )
1108- return meta_obj
1109-
1110- return None
1111-
11121119 def __getattr__ (self , obj ):
11131120
11141121 ns_obj = next ((getattr (ns , obj ) for ns in self .namespaces if hasattr (ns , obj )), None )
@@ -1122,13 +1129,23 @@ def __getattr__(self, obj):
11221129 if ns_obj is None :
11231130 ns_obj = f_back .f_globals .get (obj )
11241131
1125- if isinstance (ns_obj , (types .FunctionType , partial )):
1126- # We assume that the user requested an `Operation`
1127- # constructor/helper. Return the meta `OpDef`, because
1128- # it implements a constructor/helper-like `__call__`.
1129- meta_obj = self .find_operator (obj )
1132+ if isinstance (ns_obj , types .ModuleType ):
1133+ # It's a sub-module, so let's create another
1134+ # `TheanoMetaAccessor` and check within there.
1135+ meta_obj = TFlowMetaAccessor (namespace = ns_obj )
1136+ else :
1137+
1138+ # Check for a an OpDef first
1139+ meta_obj = TFlowMetaOperator .get_metaopdef (obj )
1140+
1141+ if meta_obj is not None :
1142+ # We assume that the user requested an `Operation`
1143+ # constructor/helper. Return the meta `OpDef`, because
1144+ # it implements a constructor/helper-like `__call__`.
1145+ if meta_obj is not None :
1146+ meta_obj = TFlowMetaOperator (meta_obj , None )
11301147
1131- # if meta_obj is None :
1148+ # elif isinstance(ns_obj, (types.FunctionType, partial)) :
11321149 # # It's a function, so let's provide a wrapper that converts
11331150 # # to-and-from theano and meta objects.
11341151 # @wraps(ns_obj)
@@ -1137,19 +1154,12 @@ def __getattr__(self, obj):
11371154 # res = ns_obj(*args, **kwargs)
11381155 # return metatize(res)
11391156
1140- elif isinstance (ns_obj , types .ModuleType ):
1141- # It's a sub-module, so let's create another
1142- # `TheanoMetaAccessor` and check within there.
1143- meta_obj = TFlowMetaAccessor (namespace = ns_obj )
1144- else :
1145-
1146- # Hopefully, it's convertible to a meta object...
1147- meta_obj = metatize (ns_obj )
1148-
1149- if meta_obj is None :
1150- # Last resort
1151- meta_obj = self .find_operator (obj )
1157+ else :
1158+ # Hopefully, it's convertible to a meta object...
1159+ meta_obj = metatize (ns_obj )
11521160
1161+ # Finally, we store the result as a meta namespace attribute, or raise
1162+ # an exception.
11531163 if isinstance (
11541164 meta_obj , (MetaSymbol , MetaSymbolType , TFlowMetaOperator , types .FunctionType )
11551165 ):
0 commit comments