Skip to content

Commit d644e0b

Browse files
Fix TF tensor name construction and add NodeDef dtype information
NodeDefs can carry dtype information in their "T" attr. Since we weren't using that attr, we weren't able to determine dtypes--in some cases--during the construction of outputs for meta Operations.
1 parent 8077d64 commit d644e0b

File tree

2 files changed

+42
-28
lines changed

2 files changed

+42
-28
lines changed

symbolic_pymc/tensorflow/meta.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from inspect import Parameter, Signature
1010

1111
from collections import OrderedDict
12-
from collections.abc import Sequence
1312

1413
from functools import partial
1514

@@ -220,11 +219,11 @@ def __init__(self, obj=None):
220219
super().__init__(obj=obj)
221220
self._apply_func_sig, self._apply_func = op_def_lib.get_op_info(obj)
222221

223-
def out_meta_types(self, inputs=None):
222+
def out_meta_types(self, inputs=None, node_def=None):
224223
def _convert_outputs(o):
225-
if o.type_attr == "T":
226-
return (TFlowMetaTensor, var())
227-
elif o.type_attr == "dtype":
224+
if o.type_attr == "T" and node_def:
225+
return (TFlowMetaTensor, node_def.attr.get("T", var()))
226+
elif o.type_attr == "dtype" and inputs:
228227
return (TFlowMetaTensor, inputs.get("dtype", var()))
229228
else:
230229
return (TFlowMetaTensor, var())
@@ -284,7 +283,6 @@ def __call__(self, *args, **kwargs):
284283
apply_arguments.get(i.name) for i in self.obj.input_arg if i.name in apply_arguments
285284
)
286285

287-
# Get the `OpDef`-instantiating parameters and call them a "node_def".
288286
node_attr = {a.name: apply_arguments.get(a.name, a) for a in self.obj.attr}
289287

290288
op_name = op_kwargs.get("name", self.obj.name)
@@ -346,6 +344,8 @@ def _protobuf_convert(cls, k, v):
346344
return metatize(tensor_shape.as_shape(v.shape))
347345
elif k == "dtype":
348346
return tf.as_dtype(v.type).name
347+
elif k == "T":
348+
return tf.as_dtype(v.type).name
349349
elif k == "value":
350350
return tensor_util.MakeNdarray(v.tensor)
351351
else:
@@ -364,22 +364,17 @@ def __init__(self, op, name, attr, obj=None):
364364
self.name = name if isvar(name) else str(name)
365365

366366
if not isvar(attr):
367-
# We want to limit the attributes we'll consider to those that show
368-
# up in an OpDef function's signature (e.g. ignore info about
369-
# permissible types).
370367
opdef_sig, _ = op_def_lib.get_op_info(self.op)
371-
op_param_names = opdef_sig.parameters.keys()
372-
373368
_attr = dict()
369+
374370
for k, v in attr.items():
375371
if isinstance(v, Message):
376372
try:
377373
v = self._protobuf_convert(k, v)
378374
except TypeError:
379-
continue
375+
v = var()
380376

381-
if k != "T" and k in op_param_names:
382-
_attr[k] = v
377+
_attr[k] = v
383378

384379
self.attr = _attr
385380
else:
@@ -532,11 +527,12 @@ def outputs(self):
532527
else:
533528

534529
apply_arguments = self.op_def.input_args(*self.inputs, **self.node_def.attr)
535-
out_types_mt = self.op_def.out_meta_types(inputs=apply_arguments)
530+
out_types_mt = self.op_def.out_meta_types(
531+
inputs=apply_arguments, node_def=self.node_def
532+
)
536533

537534
mt_outs = tuple(
538-
o_type(self, i, var() if o_dtype is None else o_dtype)
539-
for i, (o_type, o_dtype) in enumerate(out_types_mt)
535+
o_type(self, i, o_dtype) for i, (o_type, o_dtype) in enumerate(out_types_mt)
540536
)
541537

542538
self._outputs = mt_outs
@@ -574,7 +570,15 @@ def reify(self):
574570
if isvar(self.node_def):
575571
return self
576572

577-
op_attrs, op_attrs_unreified = meta_reify_iter(self.node_def.attr)
573+
op_attrs, op_attrs_unreified = meta_reify_iter(
574+
# Only use NodeDef attrs that appear in the OpDef's call signature.
575+
# Other NodeDef attrs, like dtype and shape, can be computed.
576+
{
577+
k: v
578+
for k, v in self.node_def.attr.items()
579+
if k in self.op_def._apply_func_sig.parameters
580+
}
581+
)
578582

579583
if not (op_inputs_unreified or op_attrs_unreified or MetaSymbol.is_meta(self.name)):
580584

@@ -587,6 +591,8 @@ def reify(self):
587591
tf_out = self.op_def._apply_func(**apply_arguments)
588592
op_tf = tf_out.op
589593

594+
# TODO: Update NodeDef attrs?
595+
590596
assert op_tf is not None
591597
self._obj = op_tf
592598
return self.obj
@@ -623,14 +629,8 @@ def name(self):
623629

624630
if self.obj is not None and not isinstance(self.obj, Var):
625631
name = self.obj.name
626-
elif (
627-
self.op is not None
628-
and not isvar(self.op)
629-
and not isvar(self.op.name)
630-
and not isinstance(self.op.outputs, Sequence)
631-
):
632-
out_num = self.op.outputs.index(self)
633-
name = f"{self.op.name}:{out_num}"
632+
elif isinstance(getattr(self.op, "name", None), str) and not isvar(self.value_index):
633+
name = f"{self.op.name}:{self.value_index}"
634634
else:
635635
name = var()
636636

tests/tensorflow/test_meta.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
If you're debugging/running tests manually, it might help to simply
33
disable eager execution entirely:
44
5-
> tf.compat.v1.disable_eager_execution()
5+
tf.compat.v1.disable_eager_execution()
66
"""
77
import pytest
88
import numpy as np
@@ -172,6 +172,21 @@ def test_meta_basic():
172172
assert a_mt.shape.ndims == 2
173173
assert a_mt.shape == TFlowMetaTensorShape([1, 2])
174174

175+
# Make sure that names are properly inferred when there are no base objects
176+
# to reference
177+
with tf.Graph().as_default():
178+
one_mt = mt(1.0)
179+
log_mt = mt.log(one_mt)
180+
assert log_mt.name == 'Log:0'
181+
assert log_mt.dtype == tf.float32
182+
assert log_mt.op.outputs[0].dtype == tf.float32
183+
184+
log_mt._name = None
185+
one_mt._obj = None
186+
log_mt._obj = None
187+
assert log_mt.dtype == tf.float32
188+
assert log_mt.name == 'Log:0'
189+
175190

176191
@pytest.mark.usefixtures("run_with_tensorflow")
177192
@run_in_graph_mode
@@ -394,7 +409,6 @@ def test_nodedef():
394409

395410
assert 'compute_uv' in node_def_mt.attr
396411
assert 'full_matrices' in node_def_mt.attr
397-
assert 'T' not in node_def_mt.attr
398412

399413
# Some outputs use nodedef information; let's test those.
400414
norm_rv = mt.RandomStandardNormal(mean=0, stddev=1, shape=(1000,), dtype=tf.float32, name=var())

0 commit comments

Comments
 (0)