Skip to content

Commit 16e369c

Browse files
Validate and reset altered metatized objects
1 parent 36278dc commit 16e369c

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

symbolic_pymc/tensorflow/meta.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,23 @@ def _metatize_tf_eager(obj):
189189
class TFlowMetaSymbol(MetaSymbol):
190190
__slots__ = ()
191191

192+
@classmethod
193+
def _metatize(cls, obj):
194+
195+
res = super()._metatize(obj)
196+
res.validate_objs()
197+
198+
return res
199+
200+
def validate_objs(self):
201+
# If there is no base object associated with the inputs, then we can't
202+
# trust a base object associated with this object (e.g. for the case in
203+
# which metatize altered a property in an input).
204+
for prop in self.rands():
205+
if isinstance(prop, MetaSymbol) and prop.obj is None:
206+
self.reset()
207+
break
208+
192209

193210
class OpDefFactoryType(MetaSymbolType):
194211
__opdefs__ = {}
@@ -501,8 +518,7 @@ def _metatize(cls, obj):
501518
]
502519
res = cls(*new_args, obj=obj)
503520

504-
if meta._lvar_defaults_enabled.issuperset(["node_attrs", "names"]):
505-
res.reset()
521+
res.validate_objs()
506522

507523
return res
508524

@@ -688,13 +704,8 @@ class TFlowMetaTensor(TFlowMetaSymbol, MetaVariable):
688704
@classmethod
689705
@cachedmethod(lambda cls: tf_metatize_cache)
690706
def _metatize(cls, obj):
691-
692-
res = super()._metatize(obj)
693-
694-
if meta._lvar_defaults_enabled.issuperset(["node_attrs", "names"]):
695-
res.reset()
696-
697-
return res
707+
"""Cache Tensors specifically."""
708+
return super()._metatize(obj)
698709

699710
def __init__(self, op, value_index, dtype, obj=None):
700711
self.op = metatize(op)

tests/tensorflow/test_meta.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,3 +593,8 @@ def test_global_options():
593593
assert isvar(b_mt.name)
594594
assert isvar(b_mt.op.node_def.attr)
595595
assert b_mt.op.inputs[1] is a_mt
596+
597+
# Make sure we clear out the `.obj` so that the names won't mismatch
598+
with tf.Graph().as_default(), enable_lvar_defaults('names'):
599+
a_mt = mt(1.0)
600+
assert isvar(a_mt.name)

0 commit comments

Comments
 (0)