Skip to content

Commit b0944df

Browse files
Do not remove TF NodeDef attributes under enable_lvar_defaults
The `NodeDef` attributes of a "Const" operator hold the distinguishing constant value, so removing them results in a meta object that no longer corresponds to a specific constant.
1 parent 16e369c commit b0944df

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

symbolic_pymc/tensorflow/meta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ class TFlowMetaNodeDef(TFlowMetaSymbol):
407407
def _metatize(cls, obj):
408408
res = super()._metatize(obj)
409409

410-
if "node_attrs" in meta._lvar_defaults_enabled:
410+
if obj.op != "Const" and "node_attrs" in meta._lvar_defaults_enabled:
411411
res.attr = var()
412412

413413
if "names" in meta._lvar_defaults_enabled:

tests/tensorflow/test_meta.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,10 @@ def test_global_options():
594594
assert isvar(b_mt.op.node_def.attr)
595595
assert b_mt.op.inputs[1] is a_mt
596596

597+
# `NodeDef.attr` for constants should not be turned into lvars
598+
assert not isvar(b_mt.op.inputs[0].op.node_def.attr)
599+
assert not isvar(b_mt.op.inputs[1].op.node_def.attr)
600+
597601
# Make sure we clear out the `.obj` so that the names won't mismatch
598602
with tf.Graph().as_default(), enable_lvar_defaults('names'):
599603
a_mt = mt(1.0)

0 commit comments

Comments
 (0)