Skip to content

Commit 6f7f95e

Browse files
Merge pull request #83 from brandonwillard/fix-altered-metatized-objects
Fix altered metatized objects
2 parents 36278dc + 900be8b commit 6f7f95e

File tree

2 files changed

+85
-25
lines changed

2 files changed

+85
-25
lines changed

symbolic_pymc/tensorflow/meta.py

Lines changed: 71 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@
4343

4444

4545
class MetaOpDefLibrary(object):
46+
"""A singleton-like object that holds correspondences between TF Python API functions and the `OpDef`s they construct.
47+
48+
It provides a map of `OpDef` names (lower-cased) to the Python API
49+
functions in `tensorflow.raw_ops`, as well as `inspect.Signature` objects
50+
for said functions so that default values and lists of arguments (keywords
51+
included) can be more easily used.
52+
53+
"""
4654

4755
lower_op_name_to_raw = {
4856
op_name.lower(): op_name
@@ -130,6 +138,12 @@ def make_opdef_sig(cls, opdef, opdef_py_func=None):
130138

131139
@classmethod
132140
def get_op_info(cls, opdef):
141+
"""Return the TF Python API function signature for a given `OpDef`.
142+
143+
Parameter
144+
---------
145+
opdef: str or `OpDef` object (meta or base)
146+
"""
133147
if isinstance(opdef, str):
134148
opdef_name = opdef
135149
opdef = op_def_registry.get(opdef_name)
@@ -189,6 +203,23 @@ def _metatize_tf_eager(obj):
189203
class TFlowMetaSymbol(MetaSymbol):
190204
__slots__ = ()
191205

206+
@classmethod
207+
def _metatize(cls, obj):
208+
209+
res = super()._metatize(obj)
210+
res.validate_objs()
211+
212+
return res
213+
214+
def validate_objs(self):
215+
# If there is no base object associated with the inputs, then we can't
216+
# trust a base object associated with this object (e.g. for the case in
217+
# which metatize altered a property in an input).
218+
for prop in self.rands():
219+
if isinstance(prop, MetaSymbol) and prop.obj is None:
220+
self.reset()
221+
break
222+
192223

193224
class OpDefFactoryType(MetaSymbolType):
194225
__opdefs__ = {}
@@ -272,7 +303,14 @@ def input_args(self, *args, apply_defaults=True, **kwargs):
272303
return op_args.arguments
273304

274305
def __call__(self, *args, **kwargs):
275-
"""Create the meta object(s) resulting from an application of this `OpDef`'s implied `Operation`."""
306+
"""Create the meta object(s) using the TF Python API's operator functions.
307+
308+
Each meta `OpDef` is associated with a TF Python function
309+
(`self._apply_func`) that is used to construct its `Operation`s.
310+
311+
See `TFlowMetaTensor.operator` and `TFlowMetaTensor.operator`.
312+
313+
"""
276314

277315
apply_arguments = self.input_args(*args, **kwargs)
278316

@@ -368,8 +406,7 @@ def __eq__(self, other):
368406
if not (type(self) == type(other)):
369407
return False
370408

371-
if not (self.base == other.base):
372-
return False
409+
assert self.base == other.base
373410

374411
return self.obj.name == other.obj.name
375412

@@ -390,7 +427,7 @@ class TFlowMetaNodeDef(TFlowMetaSymbol):
390427
def _metatize(cls, obj):
391428
res = super()._metatize(obj)
392429

393-
if "node_attrs" in meta._lvar_defaults_enabled:
430+
if obj.op != "Const" and "node_attrs" in meta._lvar_defaults_enabled:
394431
res.attr = var()
395432

396433
if "names" in meta._lvar_defaults_enabled:
@@ -501,8 +538,7 @@ def _metatize(cls, obj):
501538
]
502539
res = cls(*new_args, obj=obj)
503540

504-
if meta._lvar_defaults_enabled.issuperset(["node_attrs", "names"]):
505-
res.reset()
541+
res.validate_objs()
506542

507543
return res
508544

@@ -688,13 +724,8 @@ class TFlowMetaTensor(TFlowMetaSymbol, MetaVariable):
688724
@classmethod
689725
@cachedmethod(lambda cls: tf_metatize_cache)
690726
def _metatize(cls, obj):
691-
692-
res = super()._metatize(obj)
693-
694-
if meta._lvar_defaults_enabled.issuperset(["node_attrs", "names"]):
695-
res.reset()
696-
697-
return res
727+
"""Cache Tensors specifically."""
728+
return super()._metatize(obj)
698729

699730
def __init__(self, op, value_index, dtype, obj=None):
700731
self.op = metatize(op)
@@ -734,28 +765,44 @@ def name(self):
734765

735766
@property
736767
def operator(self):
768+
"""Return the meta OpDef for this tensor.
769+
770+
Since meta OpDefs are callable (and dispatch to the corresponding TF
771+
Python API function), this object called with arguments provided by
772+
`TFlowMetaTensor.inputs` recreates the underlying tensor using the TF
773+
Python interface. This approach has advantages over the purely
774+
graph-level approach to constructing meta objects, because--when all
775+
arguments are reifiable--it allows us to use purely TF means to
776+
construct a meta object (i.e. by first constructing the base object and
777+
then "metatizing" it).
778+
779+
Meta objects produced this way result in less unknown information
780+
(e.g. dtypes and shapes) and have the same default values as their base
781+
object counterparts (e.g. `Operator` names and `NodeDef.attr` values).
782+
"""
737783
if self.op is not None and not isvar(self.op):
738784
return self.op.op_def
739785

740786
@property
741787
def inputs(self):
742-
"""Return the tensor's inputs/rands.
788+
"""Return the inputs necessary to recreate this object using its TF Python API function.
789+
790+
These inputs differ from `self.op.inputs` primarily in that they
791+
contain the `node_def` parameters as keywords (e.g. to Python API
792+
functions like `tf.add`).
743793
744-
NOTE: These inputs differ from `self.op.inputs` in that they contain
745-
the `node_def` parameters, as well.
746-
In other words, these can be used to recreate this object (per
747-
the meta object spec).
794+
See `TFlowMetaTensor.operator` for more information.
748795
"""
749796
# TODO: In keeping with our desire to return logic variables in cases
750797
# where params aren't given/inferred, we could return something like
751798
# `cons(var(), var())` here (although that wouldn't be necessarily imply
752799
# that the result is a proper list/tuple).
753-
if self.op is not None and not isvar(self.op):
754-
input_args = self.op.op_def.input_args(
755-
*self.op.inputs,
756-
name=self.op.name if not isvar(self.op.name) else None,
757-
**self.op.node_def.attr,
758-
)
800+
if self.op is not None and not isvar(self.op) and not isvar(self.op.inputs):
801+
if not isvar(self.op.node_def) and not isvar(self.op.node_def.attr):
802+
attr = self.op.node_def.attr
803+
else:
804+
attr = {}
805+
input_args = self.op.op_def.input_args(*self.op.inputs, name=self.op.name, **attr)
759806
return tuple(input_args.values())
760807

761808
def reify(self):

tests/tensorflow/test_meta.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ def test_meta_eager():
6464
@run_in_graph_mode
6565
def test_meta_basic():
6666

67+
assert mt.Add == mt.Add
68+
assert mt.Add != mt.Sub
69+
6770
var_mt = TFlowMetaTensor(var(), var(), var())
6871
# It should generate a logic variable for the name and use from here on.
6972
var_name = var_mt.name
@@ -248,9 +251,10 @@ def test_meta_lvars():
248251
assert all(isvar(getattr(tn_mt, s)) for s in tn_mt.__all_props__)
249252
assert isinstance(tn_mt.reify(), TFlowMetaTensor)
250253

251-
mo_mt = TFlowMetaOp(mt.Add, [tn_mt, tn_mt], var())
254+
mo_mt = TFlowMetaOp(mt.Add, var(), [tn_mt, var('a')])
252255
assert len(mo_mt.outputs) == 1
253256
assert isinstance(mo_mt.reify(), TFlowMetaOp)
257+
assert mo_mt.outputs[0].inputs == (tn_mt, var('a'), mo_mt.name)
254258

255259

256260
@pytest.mark.usefixtures("run_with_tensorflow")
@@ -593,3 +597,12 @@ def test_global_options():
593597
assert isvar(b_mt.name)
594598
assert isvar(b_mt.op.node_def.attr)
595599
assert b_mt.op.inputs[1] is a_mt
600+
601+
# `NodeDef.attr` for constants should not be turned into lvars
602+
assert not isvar(b_mt.op.inputs[0].op.node_def.attr)
603+
assert not isvar(b_mt.op.inputs[1].op.node_def.attr)
604+
605+
# Make sure we clear out the `.obj` so that the names won't mismatch
606+
with tf.Graph().as_default(), enable_lvar_defaults('names'):
607+
a_mt = mt(1.0)
608+
assert isvar(a_mt.name)

0 commit comments

Comments
 (0)