88
99from inspect import Parameter , Signature
1010
11- from collections import OrderedDict
11+ from collections import OrderedDict , Sequence
1212
1313from functools import partial
1414
3535 metatize ,
3636)
3737
38+ from .. import meta
39+
3840
3941class MetaOpDefLibrary (object ):
4042
@@ -164,6 +166,15 @@ def _metatize_tf_object(obj):
164166
165167def load_dispatcher ():
166168 """Set/override dispatcher to default to TF objects."""
169+
170+ from tensorflow .python .ops .gen_linalg_ops import _SvdOutput
171+
172+ def _metatize_tf_svd (obj ):
173+ """Turn a TensorFlow `Svd` object/tuple into a standard tuple."""
174+ return _metatize (tuple (obj ))
175+
176+ _metatize .add ((_SvdOutput ,), _metatize_tf_svd )
177+
167178 _metatize .add ((object ,), _metatize_tf_object )
168179
169180
@@ -207,6 +218,7 @@ class TFlowMetaOpDef(MetaOp, metaclass=OpDefFactoryType):
207218
208219 >>> from google.protobuf import json_format
209220 >>> print(json_format.MessageToJson(opdef))
221+
210222 - If you want to use an `OpDef` to construct a node, see
211223 `op_def_library.apply_op`.
212224
@@ -220,39 +232,53 @@ def __init__(self, obj=None):
220232 self ._apply_func_sig , self ._apply_func = op_def_lib .get_op_info (obj )
221233
222234 def out_meta_types (self , inputs = None , node_def = None ):
235+ """Return a list of tuples containing object types and corresponding dtypes for the outputs of this OpDef."""
236+
223237 def _convert_outputs (o ):
224- if o .type_attr == "T" and node_def :
238+ if o .type_attr == "T" and hasattr ( node_def , "attr" ) :
225239 return (TFlowMetaTensor , node_def .attr .get ("T" , var ()))
226240 elif o .type_attr == "dtype" and inputs :
227241 return (TFlowMetaTensor , inputs .get ("dtype" , var ()))
228242 else :
229243 return (TFlowMetaTensor , var ())
230244
245+ # TODO: We also have permissible dtype information from objects in the
246+ # array `self.obj.attr` under the field `allowed_values`.
247+
231248 out_meta_types = tuple (_convert_outputs (o ) for o in self .obj .output_arg )
232- # TODO: We also have permissible dtype information:
233- # from objects in the array `self.obj.attr` under the field
234- # `allowed_values`.
249+
235250 return out_meta_types
236251
237- def input_args (self , * args , ** kwargs ):
252+ def input_args (self , * args , apply_defaults = True , ** kwargs ):
253+ """Return a list of arguments for this OpDef's 'apply function'."""
238254 kwargs = OrderedDict (
239255 (k , v )
240256 for k , v in kwargs .items ()
241257 # Filter out the optional keyword arguments so they we only pass
242258 # expected arguments to the `OpDef`'s apply function.
243259 if k in self ._apply_func_sig .parameters
244260 )
261+
245262 op_args = self ._apply_func_sig .bind (* args , ** kwargs )
246- op_args .apply_defaults ()
263+
264+ if apply_defaults :
265+ op_args .apply_defaults ()
266+
247267 return op_args .arguments
248268
249269 def __call__ (self , * args , ** kwargs ):
250270 """Create the meta object(s) resulting from an application of this `OpDef`'s implied `Operation`."""
251- op_args , op_args_unreified = meta_reify_iter (args )
252- op_kwargs , op_kwargs_unreified = meta_reify_iter (kwargs )
271+
272+ if not meta ._auto_reification_disabled :
273+ op_args , op_args_unreified = meta_reify_iter (args )
274+ op_kwargs , op_kwargs_unreified = meta_reify_iter (kwargs )
275+ else :
276+ op_args , op_args_unreified = args , True
277+ op_kwargs , op_kwargs_unreified = kwargs , True
278+
253279 apply_arguments = self .input_args (* op_args , ** op_kwargs )
254280
255- if not op_args_unreified and not op_kwargs_unreified :
281+ if not ( op_args_unreified or op_kwargs_unreified ) :
256282
257283 # them into meta objects. Doing so will yield information we
258284 # wouldn't be able to produce otherwise (e.g. shape info).
@@ -269,29 +295,43 @@ def __call__(self, *args, **kwargs):
269295
270296 tf_out = self ._apply_func (** apply_arguments )
271297 res_var = metatize (tf_out )
272- return res_var
273-
274- #
275- # If we're here, that means we have to create the meta objects
276- # manually.
277- #
278- # TODO: `tf.placeholder`s are pretty flexible, we could probably use
279- # one as a stand-in for any un-reified tensor arguments and at least
280- # get some partial `dtype`, `shape` and `name` info.
281-
282- op_input_args = tuple (
283- apply_arguments .get (i .name ) for i in self .obj .input_arg if i .name in apply_arguments
284- )
285298
286- node_attr = {a .name : apply_arguments .get (a .name , a ) for a in self .obj .attr }
299+ if "names" in meta ._lvar_defaults_enabled :
300+ # This should also reset the NodeDef's `obj`
301+ res_var .op .node_def .name = var ()
302+ res_var .op .reset ()
303+ res_var .reset ()
287304
288- op_name = op_kwargs .get ("name" , self .obj .name )
305+ if "node_attrs" in meta ._lvar_defaults_enabled :
306+ # This should also reset the NodeDef's `obj`
307+ res_var .op .node_def .attr = var ()
308+ res_var .op .reset ()
309+ res_var .reset ()
289310
290- node_def = TFlowMetaNodeDef (self .obj .name , op_name , node_attr )
311+ else :
312+ #
313+ # If we're here, that means we have to create the meta objects
314+ # manually.
315+ #
291316
292- op_mt = TFlowMetaOp (self , node_def , op_input_args )
317+ op_input_args = tuple (
318+ apply_arguments .get (i .name ) for i in self .obj .input_arg if i .name in apply_arguments
319+ )
320+
321+ if "node_attrs" not in meta ._lvar_defaults_enabled :
322+ node_attr = {a .name : apply_arguments .get (a .name , a ) for a in self .obj .attr }
323+ else :
324+ node_attr = var ()
325+
326+ op_name = op_kwargs .get (
327+ "name" , self .obj .name if "names" not in meta ._lvar_defaults_enabled else var ()
328+ )
329+
330+ node_def = TFlowMetaNodeDef (self .obj .name , op_name , node_attr )
293331
294- res_var = op_mt .default_output
332+ op_mt = TFlowMetaOp (self , node_def , op_input_args )
333+
334+ res_var = op_mt .default_output
295335
296336 return res_var
297337
@@ -517,16 +557,26 @@ def outputs(self):
517557 if getattr (self , "_outputs" , None ) is not None :
518558 return self ._outputs
519559
520- if (
521- isvar (self .op_def )
522- or isvar (self .inputs )
523- or isvar (self .node_def )
524- or isvar (self .node_def .attr )
525- ):
560+ if isvar (self .op_def ):
526561 self ._outputs = var ()
527562 else :
528563
529- apply_arguments = self .op_def .input_args (* self .inputs , ** self .node_def .attr )
564+ if isvar (self .node_def ) or isvar (getattr (self .node_def , "attr" )):
565+ node_attr = {}
566+ else :
567+ node_attr = self .node_def .attr
568+
569+ if isvar (self .inputs ):
570+ inputs = (None ,) * len (self .op_def ._apply_func_sig .parameters )
571+ apply_defaults = False
572+ else :
573+ inputs = self .inputs
574+ apply_defaults = True
575+
576+ apply_arguments = self .op_def .input_args (
577+ * inputs , apply_defaults = apply_defaults , ** node_attr
578+ )
579+
530580 out_types_mt = self .op_def .out_meta_types (
531581 inputs = apply_arguments , node_def = self .node_def
532582 )
@@ -551,7 +601,7 @@ def default_output(self):
551601
552602 mt_outs = self .outputs
553603
554- if len (mt_outs ) == 1 :
604+ if isinstance ( mt_outs , Sequence ) and len (mt_outs ) == 1 :
555605 out_var = mt_outs [0 ]
556606 else :
557607 out_var = mt_outs
0 commit comments