Skip to content

Commit 7c80d10

Browse files
Fix TF logic variable name and NodeDef handling
1 parent 3b905bb commit 7c80d10

File tree

2 files changed

+33
-15
lines changed

2 files changed

+33
-15
lines changed

symbolic_pymc/tensorflow/meta.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def out_meta_types(self, inputs=None, node_def=None):
235235
"""Return a list of tuples containing object types and corresponding dtypes for the outputs of this OpDef."""
236236

237237
def _convert_outputs(o):
238-
if o.type_attr == "T" and hasattr(node_def, "attr"):
238+
if o.type_attr == "T" and isinstance(getattr(node_def, "attr", None), dict):
239239
return (TFlowMetaTensor, node_def.attr.get("T", var()))
240240
elif o.type_attr == "dtype" and inputs:
241241
return (TFlowMetaTensor, inputs.get("dtype", var()))
@@ -323,9 +323,10 @@ def __call__(self, *args, **kwargs):
323323
else:
324324
node_attr = var()
325325

326-
op_name = op_kwargs.get(
327-
"name", self.obj.name if "names" not in meta._lvar_defaults_enabled else var()
328-
)
326+
if "names" not in meta._lvar_defaults_enabled:
327+
op_name = op_kwargs.get("name", self.obj.name)
328+
else:
329+
op_name = var()
329330

330331
node_def = TFlowMetaNodeDef(self.obj.name, op_name, node_attr)
331332

@@ -561,7 +562,7 @@ def outputs(self):
561562
self._outputs = var()
562563
else:
563564

564-
if isvar(self.node_def) or isvar(getattr(self.node_def, "attr")):
565+
if isvar(self.node_def) or not isinstance(getattr(self.node_def, "attr", None), dict):
565566
node_attr = {}
566567
else:
567568
node_attr = self.node_def.attr
@@ -613,24 +614,23 @@ def reify(self):
613614
if self.obj and not isinstance(self.obj, Var):
614615
return self.obj
615616

616-
# tt_op = self.op.reify()
617-
# if not self.is_meta(tt_op):
617+
if isvar(self.inputs):
618+
return self
619+
618620
op_inputs, op_inputs_unreified = meta_reify_iter(self.inputs)
619621

620-
if isvar(self.node_def):
622+
node_attr = getattr(self.node_def, "attr", None)
623+
624+
if node_attr is None or isvar(node_attr):
621625
return self
622626

623627
op_attrs, op_attrs_unreified = meta_reify_iter(
624628
# Only use NodeDef attrs that appear in the OpDef's call signature.
625629
# Other NodeDef attrs, like dtype and shape, can be computed.
626-
{
627-
k: v
628-
for k, v in self.node_def.attr.items()
629-
if k in self.op_def._apply_func_sig.parameters
630-
}
630+
{k: v for k, v in node_attr.items() if k in self.op_def._apply_func_sig.parameters}
631631
)
632632

633-
if not (op_inputs_unreified or op_attrs_unreified or MetaSymbol.is_meta(self.name)):
633+
if not (op_inputs_unreified or op_attrs_unreified or isvar(self.name)):
634634

635635
# We have to use a primitive string or TF will complain.
636636
name = self.name

tests/tensorflow/test_meta.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ def test_meta_lvars():
224224

225225
nd_mt = TFlowMetaNodeDef(var(), var(), var())
226226
assert all(isvar(getattr(nd_mt, s)) for s in nd_mt.__all_props__)
227+
# TODO: Figure out how we want this to work.
228+
# assert isinstance(nd_mt.reify(), TFlowMetaNodeDef)
227229

228230
mo_mt = TFlowMetaOp(var(), var(), var(), var())
229231
assert all(isvar(getattr(mo_mt, s)) for s in mo_mt.__all_props__)
@@ -234,14 +236,21 @@ def test_meta_lvars():
234236

235237
mo_mt = TFlowMetaOp(mt.Add, var(), var())
236238
assert len(mo_mt.outputs) == 1
239+
assert isinstance(mo_mt.reify(), TFlowMetaOp)
237240

238241
ts_mt = TFlowMetaTensorShape(var())
239242
assert all(isvar(getattr(ts_mt, s)) for s in ts_mt.__all_props__)
243+
assert isinstance(ts_mt.reify(), TFlowMetaTensorShape)
240244

241245
assert isvar(ts_mt.as_list())
242246

243247
tn_mt = TFlowMetaTensor(var(), var(), var())
244248
assert all(isvar(getattr(tn_mt, s)) for s in tn_mt.__all_props__)
249+
assert isinstance(tn_mt.reify(), TFlowMetaTensor)
250+
251+
mo_mt = TFlowMetaOp(mt.Add, [tn_mt, tn_mt], var())
252+
assert len(mo_mt.outputs) == 1
253+
assert isinstance(mo_mt.reify(), TFlowMetaOp)
245254

246255

247256
@pytest.mark.usefixtures("run_with_tensorflow")
@@ -554,10 +563,19 @@ def test_global_options():
554563
assert isvar(z_mt.name)
555564
assert isvar(z_mt.op.node_def.attr)
556565

557-
with tf.Graph().as_default(), disable_auto_reification(), enable_lvar_defaults('names', 'node_attrs'):
566+
with disable_auto_reification(), enable_lvar_defaults('names', 'node_attrs'):
558567
# This will *not* auto-reify and simply create the object from scratch with meta types
559568
# and the appropriate/desired logic variables.
560569
z_mt = mt.Placeholder('float')
561570
assert z_mt.obj is None
562571
assert isvar(z_mt.name)
563572
assert isvar(z_mt.op.node_def.attr)
573+
574+
with tf.Graph().as_default(), enable_lvar_defaults('names', 'node_attrs'):
575+
y_mt = mt.Placeholder('float') + mt.Placeholder('float')
576+
assert isvar(y_mt.name)
577+
assert isvar(y_mt.op.inputs[0].name)
578+
assert isvar(y_mt.op.inputs[1].name)
579+
assert isvar(y_mt.op.node_def.attr)
580+
assert isvar(y_mt.op.inputs[0].op.node_def.attr)
581+
assert isvar(y_mt.op.inputs[1].op.node_def.attr)

0 commit comments

Comments
 (0)