1212
1313from functools import partial
1414
15+ from cachetools import cachedmethod , Cache
16+
1517from unification import Var , var , isvar
1618
1719from google .protobuf .message import Message
3739
3840from .. import meta
3941
42+ tf_metatize_cache = Cache (50 )
43+
4044
4145class MetaOpDefLibrary (object ):
4246
@@ -147,26 +151,17 @@ def get_op_info(cls, opdef):
147151
148152def _metatize_tf_object (obj ):
149153 try :
150- obj = tf .convert_to_tensor (obj )
154+ tf_obj = tf .convert_to_tensor (obj )
151155 except (TypeError , ValueError ):
152156 raise ValueError ("Could not find a TensorFlow MetaSymbol class for {obj}" )
153157
154- if isinstance (obj , tf .Tensor ):
155- try :
156- obj .op
157- except AttributeError :
158- raise AttributeError (
159- f"TensorFlow Operation not available; "
160- "try recreating the object with eager-mode disabled"
161- " (e.g. within `tensorflow.python.eager.context.graph_mode`)"
162- )
163-
164- return _metatize (obj )
158+ return _metatize (tf_obj )
165159
166160
167161def load_dispatcher ():
168162 """Set/override dispatcher to default to TF objects."""
169163
164+ from tensorflow .python .framework .ops import EagerTensor
170165 from tensorflow .python .ops .gen_linalg_ops import _SvdOutput
171166
172167 def _metatize_tf_svd (obj ):
@@ -175,6 +170,16 @@ def _metatize_tf_svd(obj):
175170
176171 _metatize .add ((_SvdOutput ,), _metatize_tf_svd )
177172
173+ def _metatize_tf_eager (obj ):
174+ """Catch eager tensor metatize issues early."""
175+ raise AttributeError (
176+ f"TensorFlow Operation not available; "
177+ "try recreating the object with eager-mode disabled"
178+ " (e.g. within `tensorflow.python.eager.context.graph_mode`)"
179+ )
180+
181+ _metatize .add ((EagerTensor ,), _metatize_tf_eager )
182+
178183 _metatize .add ((object ,), _metatize_tf_object )
179184
180185
@@ -269,16 +274,14 @@ def input_args(self, *args, apply_defaults=True, **kwargs):
269274 def __call__ (self , * args , ** kwargs ):
270275 """Create the meta object(s) resulting from an application of this `OpDef`'s implied `Operation`."""
271276
277+ apply_arguments = self .input_args (* args , ** kwargs )
278+
272279 if not meta ._auto_reification_disabled :
273- op_args , op_args_unreified = meta_reify_iter (args )
274- op_kwargs , op_kwargs_unreified = meta_reify_iter (kwargs )
280+ op_args , op_args_unreified = meta_reify_iter (apply_arguments )
275281 else :
276- op_args , op_args_unreified = args , True
277- op_kwargs , op_kwargs_unreified = kwargs , True
278-
279- apply_arguments = self .input_args (* op_args , ** op_kwargs )
282+ op_args , op_args_unreified = apply_arguments , True
280283
281- if not ( op_args_unreified or op_kwargs_unreified ) :
284+ if not op_args_unreified :
282285
283286 # them into meta objects. Doing so will yield information we
284287 # wouldn't be able to produce otherwise (e.g. shape info).
@@ -289,11 +292,22 @@ def __call__(self, *args, **kwargs):
289292 # the TF-`Operation` inferred values (e.g. shapes, dtypes, etc.)
290293
291294 # We have to use a primitive string or TF will complain.
292- name = apply_arguments .get ("name" , None )
295+ name = op_args .get ("name" , None )
293296 if name is not None :
294- apply_arguments ["name" ] = str (name )
297+ op_args ["name" ] = str (name )
298+
299+ tf_out = self ._apply_func (** op_args )
300+
301+ # Ensure that the original meta objects will result
302+ # from the following `metatize`
303+ tf_metatize_cache .update (
304+ {
305+ k : v
306+ for k , v in zip (op_args .values (), apply_arguments .values ())
307+ if isinstance (k , tf .Tensor )
308+ }
309+ )
295310
296- tf_out = self ._apply_func (** apply_arguments )
297311 res_var = metatize (tf_out )
298312
299313 if "names" in meta ._lvar_defaults_enabled :
@@ -324,7 +338,7 @@ def __call__(self, *args, **kwargs):
324338 node_attr = var ()
325339
326340 if "names" not in meta ._lvar_defaults_enabled :
327- op_name = op_kwargs .get ("name" , self .obj .name )
341+ op_name = kwargs .get ("name" , self .obj .name )
328342 else :
329343 op_name = var ()
330344
@@ -372,6 +386,18 @@ class TFlowMetaNodeDef(TFlowMetaSymbol):
372386 base = NodeDef
373387 __slots__ = ["op" , "name" , "attr" , "_frozen_attr" ]
374388
389+ @classmethod
390+ def _metatize (cls , obj ):
391+ res = super ()._metatize (obj )
392+
393+ if "node_attrs" in meta ._lvar_defaults_enabled :
394+ res .attr = var ()
395+
396+ if "names" in meta ._lvar_defaults_enabled :
397+ res .name = var ()
398+
399+ return res
400+
375401 @classmethod
376402 def _protobuf_convert (cls , k , v ):
377403 """Convert a small subset of protobuf objects.
@@ -473,7 +499,12 @@ def _metatize(cls, obj):
473499 new_args = [
474500 getattr (obj , s ) if s != "inputs" else new_input for s in getattr (cls , "__props__" , [])
475501 ]
476- return cls (* new_args , obj = obj )
502+ res = cls (* new_args , obj = obj )
503+
504+ if meta ._lvar_defaults_enabled .issuperset (["node_attrs" , "names" ]):
505+ res .reset ()
506+
507+ return res
477508
478509 def __init__ (self , op_def , node_def , inputs , outputs = None , obj = None ):
479510 """Create a TensorFlow meta `Operation`.
@@ -654,6 +685,17 @@ class TFlowMetaTensor(TFlowMetaSymbol, MetaVariable):
654685 base = tf .Tensor
655686 __slots__ = ("op" , "value_index" , "dtype" , "_shape" , "_name" )
656687
688+ @classmethod
689+ @cachedmethod (lambda cls : tf_metatize_cache )
690+ def _metatize (cls , obj ):
691+
692+ res = super ()._metatize (obj )
693+
694+ if meta ._lvar_defaults_enabled .issuperset (["node_attrs" , "names" ]):
695+ res .reset ()
696+
697+ return res
698+
657699 def __init__ (self , op , value_index , dtype , obj = None ):
658700 self .op = metatize (op )
659701 # TODO: Sync this value with `op.node_def.attr['dtype']` and/or
@@ -679,13 +721,14 @@ def name(self):
679721 if getattr (self , "_name" , None ):
680722 return self ._name
681723
682- if self .obj is not None and not isinstance (self .obj , Var ):
683- name = self .obj .name
684- elif isinstance (getattr (self .op , "name" , None ), str ) and not isvar (self .value_index ):
724+ if isinstance (getattr (self .op , "name" , None ), str ) and not isvar (self .value_index ):
685725 name = f"{ self .op .name } :{ self .value_index } "
686726 else :
687727 name = var ()
688728
729+ if self .obj is not None and not isinstance (self .obj , Var ):
730+ assert name == self .obj .name
731+
689732 self ._name = name
690733 return self ._name
691734
0 commit comments