1616
1717from google .protobuf .message import Message
1818
19- from tensorflow .python .framework import tensor_util , op_def_registry , op_def_library , tensor_shape
19+ from tensorflow .python .framework import (
20+ tensor_util ,
21+ op_def_registry ,
22+ op_def_library ,
23+ tensor_shape ,
24+ ops ,
25+ )
2026from tensorflow .core .framework .op_def_pb2 import OpDef
2127from tensorflow .core .framework .node_def_pb2 import NodeDef
2228
23- # from tensorflow.core.framework.tensor_shape_pb2 import TensorShapeProto
24-
2529from tensorflow_probability import distributions as tfd
2630
2731
3034 MetaSymbolType ,
3135 MetaOp ,
3236 MetaVariable ,
37+ MetaReificationError ,
3338 meta_reify_iter ,
3439 _metatize ,
3540 metatize ,
@@ -61,51 +66,53 @@ class MetaOpDefLibrary(object):
6166 opdef_signatures = {}
6267
6368 @classmethod
64- def apply_op (cls , * args , ** kwargs ):
65- return op_def_library .apply_op (* args , ** kwargs )
69+ def get_op_info (cls , opdef ):
70+ """Return the TF Python API function signature for a given `OpDef`.
71+
72+ Parameter
73+ ---------
74+ opdef: str or `OpDef` object (meta or base)
75+ """
76+ if isinstance (opdef , str ):
77+ opdef_name = opdef
78+ opdef = op_def_registry .get (opdef_name )
79+ else :
80+ opdef_name = opdef .name
81+
82+ opdef_sig = cls .opdef_signatures .get (opdef_name , None )
83+
84+ if opdef_sig is None and opdef is not None :
85+ opdef_func = getattr (tf .raw_ops , opdef .name , None )
86+ opdef_sig = cls .make_opdef_sig (opdef , opdef_func )
87+ cls .opdef_signatures [opdef .name ] = opdef_sig
88+
89+ return opdef_sig
6690
6791 @classmethod
6892 def make_opdef_sig (cls , opdef , opdef_py_func = None ):
6993 """Create a `Signature` object for an `OpDef`.
7094
7195 Annotations are include so that one can partially verify arguments.
7296 """
73- input_args = OrderedDict ([(a .name , a .type or a .type_attr ) for a in opdef .input_arg ])
74- attrs = OrderedDict ([(a .name , a ) for a in opdef .attr ])
75-
76- params = OrderedDict ()
7797 if opdef_py_func :
98+ #
7899 # We assume we're dealing with a function from `tf.raw_ops`.
79- # Those functions have only the necessary `input_arg`s and
80- # `attr` inputs as arguments.
100+ # Those functions have only the necessary `input_arg`s and `attr`
101+ # inputs as arguments.
102+ #
81103 opdef_func_sig = Signature .from_callable (opdef_py_func )
82104 params = opdef_func_sig .parameters
83105
84- # for name, param in opdef_func_sig.parameters.items():
85- # # We make positional parameters permissible (since the
86- # # functions in `tf.raw_ops` are keyword-only), and we use the
87- # # `tf.raw_ops` arguments to determine the *actual* required
88- # # arguments (because `OpDef`'s `input_arg`s and `attrs` aren't
89- # # exactly clear about that).
90- # if name in input_args:
91- # new_default = Parameter.empty
92- # new_annotation = input_args[name]
93- # else:
94- # new_default = None
95- # new_annotation = attrs.get(name, None)
96- # if new_annotation is not None:
97- # new_annotation = new_annotation.type
106+ else :
107+ #
108+ # We're crafting an `Operation` at a low-level via `apply_op`
109+ # (like the functions in `tf.raw_ops` do)
98110 #
99- # new_param = param.replace(
100- # kind=Parameter.POSITIONAL_OR_KEYWORD,
101- # default=new_default,
102- # annotation=new_annotation,
103- # )
104- # params[name] = new_param
111+ input_args = OrderedDict ([(a .name , a .type or a .type_attr ) for a in opdef .input_arg ])
112+ attrs = OrderedDict ([(a .name , a ) for a in opdef .attr ])
113+ params = OrderedDict ()
105114
106- else :
107- # We're crafting the Operation at a low-level via `apply_op`.
108- opdef_py_func = partial (op_def_lib .apply_op , opdef .name )
115+ opdef_py_func = partial (op_def_library .apply_op , opdef .name )
109116
110117 for i_name , i_type in input_args .items ():
111118 p = Parameter (i_name , Parameter .POSITIONAL_OR_KEYWORD , annotation = i_type )
@@ -144,29 +151,6 @@ def make_opdef_sig(cls, opdef, opdef_py_func=None):
144151 )
145152 return opdef_sig , opdef_py_func
146153
147- @classmethod
148- def get_op_info (cls , opdef ):
149- """Return the TF Python API function signature for a given `OpDef`.
150-
151- Parameter
152- ---------
153- opdef: str or `OpDef` object (meta or base)
154- """
155- if isinstance (opdef , str ):
156- opdef_name = opdef
157- opdef = op_def_registry .get (opdef_name )
158- else :
159- opdef_name = opdef .name
160-
161- opdef_sig = cls .opdef_signatures .get (opdef_name , None )
162-
163- if opdef_sig is None and opdef is not None :
164- opdef_func = getattr (tf .raw_ops , opdef .name , None )
165- opdef_sig = cls .make_opdef_sig (opdef , opdef_func )
166- cls .opdef_signatures [opdef .name ] = cls .make_opdef_sig (opdef , opdef_func )
167-
168- return opdef_sig
169-
170154
171155op_def_lib = MetaOpDefLibrary ()
172156
@@ -183,7 +167,6 @@ def _metatize_tf_object(obj):
183167def load_dispatcher ():
184168 """Set/override dispatcher to default to TF objects."""
185169
186- from tensorflow .python .framework .ops import EagerTensor
187170 from tensorflow .python .ops .gen_linalg_ops import _SvdOutput
188171
189172 def _metatize_tf_svd (obj ):
@@ -200,7 +183,7 @@ def _metatize_tf_eager(obj):
200183 " (e.g. within `tensorflow.python.eager.context.graph_mode`)"
201184 )
202185
203- meta ._metatize .add ((EagerTensor ,), _metatize_tf_eager )
186+ meta ._metatize .add ((ops . EagerTensor ,), _metatize_tf_eager )
204187
205188 meta ._metatize .add ((object ,), _metatize_tf_object )
206189 meta ._metatize .add ((HashableNDArray ,), _metatize_tf_object )
@@ -599,12 +582,30 @@ def reify(self):
599582 )
600583
601584 if not (op_inputs_unreified or op_attrs_unreified or isvar (self .name )):
602-
603- apply_arguments = operator .input_args (* op_inputs , name = self .name , ** op_attrs )
604- tf_out = operator ._apply_func (** apply_arguments )
605- op_tf = tf_out .op
606-
607- # TODO: Update NodeDef attrs?
585+ #
586+ # An operation with this name might already exist in the graph
587+ #
588+ try :
589+ existing_op = ops .get_default_graph ().get_operation_by_name (self .name )
590+ except KeyError :
591+ #
592+ # There is no such `Operation`, so we attempt to create it
593+ #
594+ apply_arguments = operator .input_args (* op_inputs , name = self .name , ** op_attrs )
595+ tf_out = operator ._apply_func (** apply_arguments )
596+ op_tf = tf_out .op
597+ else :
598+ #
599+ # An `Operation` with this name exists, let's make sure it's
600+ # equivalent to this meta `Operation`
601+ #
602+ if self != mt (existing_op ):
603+ raise MetaReificationError (
604+ f"An Operation with the name { self .name } "
605+ " already exists in the graph and is not"
606+ " equal to this meta object."
607+ )
608+ op_tf = existing_op
608609
609610 assert op_tf is not None
610611 self ._obj = op_tf
@@ -1149,4 +1150,5 @@ def __getattr__(self, obj):
11491150
11501151mt = TFlowMetaAccessor ()
11511152
1153+
11521154load_dispatcher ()
0 commit comments