99from inspect import Parameter , Signature
1010
1111from collections import OrderedDict
12- from collections .abc import Sequence
1312
1413from functools import partial
1514
@@ -220,11 +219,11 @@ def __init__(self, obj=None):
220219 super ().__init__ (obj = obj )
221220 self ._apply_func_sig , self ._apply_func = op_def_lib .get_op_info (obj )
222221
223- def out_meta_types (self , inputs = None ):
222+ def out_meta_types (self , inputs = None , node_def = None ):
224223 def _convert_outputs (o ):
225- if o .type_attr == "T" :
226- return (TFlowMetaTensor , var ())
227- elif o .type_attr == "dtype" :
224+ if o .type_attr == "T" and node_def :
225+ return (TFlowMetaTensor , node_def . attr . get ( "T" , var () ))
226+ elif o .type_attr == "dtype" and inputs :
228227 return (TFlowMetaTensor , inputs .get ("dtype" , var ()))
229228 else :
230229 return (TFlowMetaTensor , var ())
@@ -284,7 +283,6 @@ def __call__(self, *args, **kwargs):
284283 apply_arguments .get (i .name ) for i in self .obj .input_arg if i .name in apply_arguments
285284 )
286285
287- # Get the `OpDef`-instantiating parameters and call them a "node_def".
288286 node_attr = {a .name : apply_arguments .get (a .name , a ) for a in self .obj .attr }
289287
290288 op_name = op_kwargs .get ("name" , self .obj .name )
@@ -346,6 +344,8 @@ def _protobuf_convert(cls, k, v):
346344 return metatize (tensor_shape .as_shape (v .shape ))
347345 elif k == "dtype" :
348346 return tf .as_dtype (v .type ).name
347+ elif k == "T" :
348+ return tf .as_dtype (v .type ).name
349349 elif k == "value" :
350350 return tensor_util .MakeNdarray (v .tensor )
351351 else :
@@ -364,22 +364,17 @@ def __init__(self, op, name, attr, obj=None):
364364 self .name = name if isvar (name ) else str (name )
365365
366366 if not isvar (attr ):
367- # We want to limit the attributes we'll consider to those that show
368- # up in an OpDef function's signature (e.g. ignore info about
369- # permissible types).
370367 opdef_sig , _ = op_def_lib .get_op_info (self .op )
371- op_param_names = opdef_sig .parameters .keys ()
372-
373368 _attr = dict ()
369+
374370 for k , v in attr .items ():
375371 if isinstance (v , Message ):
376372 try :
377373 v = self ._protobuf_convert (k , v )
378374 except TypeError :
379- continue
375+ v = var ()
380376
381- if k != "T" and k in op_param_names :
382- _attr [k ] = v
377+ _attr [k ] = v
383378
384379 self .attr = _attr
385380 else :
@@ -532,11 +527,12 @@ def outputs(self):
532527 else :
533528
534529 apply_arguments = self .op_def .input_args (* self .inputs , ** self .node_def .attr )
535- out_types_mt = self .op_def .out_meta_types (inputs = apply_arguments )
530+ out_types_mt = self .op_def .out_meta_types (
531+ inputs = apply_arguments , node_def = self .node_def
532+ )
536533
537534 mt_outs = tuple (
538- o_type (self , i , var () if o_dtype is None else o_dtype )
539- for i , (o_type , o_dtype ) in enumerate (out_types_mt )
535+ o_type (self , i , o_dtype ) for i , (o_type , o_dtype ) in enumerate (out_types_mt )
540536 )
541537
542538 self ._outputs = mt_outs
@@ -574,7 +570,15 @@ def reify(self):
574570 if isvar (self .node_def ):
575571 return self
576572
577- op_attrs , op_attrs_unreified = meta_reify_iter (self .node_def .attr )
573+ op_attrs , op_attrs_unreified = meta_reify_iter (
574+ # Only use NodeDef attrs that appear in the OpDef's call signature.
575+ # Other NodeDef attrs, like dtype and shape, can be computed.
576+ {
577+ k : v
578+ for k , v in self .node_def .attr .items ()
579+ if k in self .op_def ._apply_func_sig .parameters
580+ }
581+ )
578582
579583 if not (op_inputs_unreified or op_attrs_unreified or MetaSymbol .is_meta (self .name )):
580584
@@ -587,6 +591,8 @@ def reify(self):
587591 tf_out = self .op_def ._apply_func (** apply_arguments )
588592 op_tf = tf_out .op
589593
594+ # TODO: Update NodeDef attrs?
595+
590596 assert op_tf is not None
591597 self ._obj = op_tf
592598 return self .obj
@@ -623,14 +629,8 @@ def name(self):
623629
624630 if self .obj is not None and not isinstance (self .obj , Var ):
625631 name = self .obj .name
626- elif (
627- self .op is not None
628- and not isvar (self .op )
629- and not isvar (self .op .name )
630- and not isinstance (self .op .outputs , Sequence )
631- ):
632- out_num = self .op .outputs .index (self )
633- name = f"{ self .op .name } :{ out_num } "
632+ elif isinstance (getattr (self .op , "name" , None ), str ) and not isvar (self .value_index ):
633+ name = f"{ self .op .name } :{ self .value_index } "
634634 else :
635635 name = var ()
636636
0 commit comments