Skip to content

Commit 9777aac

Browse files
Merge pull request #65 from brandonwillard/fix-bad-attr-metatizing
Stop metatizing all node_def.attr values
2 parents 2fbc860 + ff60cbd commit 9777aac

File tree

2 files changed

+71
-15
lines changed

2 files changed

+71
-15
lines changed

symbolic_pymc/tensorflow/meta.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
import tensorflow as tf
77
import tensorflow_probability as tfp
88

9-
from contextlib import suppress
10-
119
from inspect import Parameter, Signature
1210

1311
from collections import OrderedDict, UserString
@@ -88,25 +86,26 @@ def make_opdef_sig(cls, opdef):
8886
params[name] = new_param
8987

9088
else:
91-
params = []
92-
for i_name, i_type in input_args:
89+
for i_name, i_type in input_args.items():
9390
p = Parameter(i_name, Parameter.POSITIONAL_OR_KEYWORD, annotation=i_type)
9491
params[i_name] = p
9592

9693
# These are the ambiguities we're attempting to overcome
9794
# with the `tf.raw_ops` functions above.
98-
for a_name, a_type in attrs:
95+
for a_name, a_type in attrs.items():
96+
9997
if a_name == "T":
10098
# This is a type value that will most likely be inferred
10199
# from/by the inputs.
102100
# TODO: We could check for an `allowed_values` attribute.
103101
continue
102+
104103
p = Parameter(
105104
a_name,
106105
Parameter.POSITIONAL_OR_KEYWORD,
107106
# TODO: We could use the `default_value`
108107
# attribute.
109-
default=None,
108+
default=Parameter.empty,
110109
annotation=a_type,
111110
)
112111
params[a_name] = p
@@ -183,7 +182,7 @@ def __hash__(self):
183182
def _metatize_tf_object(obj):
184183
try:
185184
obj = tf.convert_to_tensor(obj)
186-
except TypeError:
185+
except (TypeError, ValueError):
187186
raise ValueError("Could not find a TensorFlow MetaSymbol class for {obj}")
188187

189188
if isinstance(obj, tf.Tensor):
@@ -368,9 +367,8 @@ def _protobuf_convert(cls, k, v):
368367
from google.protobuf.json_format import MessageToDict
369368
MessageToDict(obj, use_integers_for_enums=True)
370369
"""
371-
372370
if k == "shape":
373-
return tensor_shape.as_shape(v.shape)
371+
return metatize(tensor_shape.as_shape(v.shape))
374372
elif k == "dtype":
375373
return tf.as_dtype(v.type).name
376374
elif k == "value":
@@ -406,12 +404,6 @@ def __init__(self, op, name, attr, obj=None):
406404
continue
407405

408406
if k != "T" and k in op_param_names:
409-
# XXX: We can't let `metatize` convert NumPy values;
410-
# otherwise, we'll loop endlessly on "Const" Ops.
411-
if k != "value":
412-
with suppress(ValueError):
413-
v = metatize(v)
414-
415407
self.attr[k] = v
416408
else:
417409
self.attr = attr

tests/tensorflow/test_meta.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
TFlowMetaOpDef,
1717
TFlowMetaNodeDef,
1818
TFlowOpName,
19+
MetaOpDefLibrary,
1920
mt)
2021

2122
from tests.tensorflow import run_in_graph_mode
@@ -321,6 +322,47 @@ def test_inputs_remapping():
321322
assert z_mt.inputs[1].obj == z.op.inputs[2]
322323

323324

325+
@pytest.mark.usefixtures("run_with_tensorflow")
326+
def test_opdef_sig():
327+
"""Make sure we can construct an `inspect.Signature` object for a protobuf OpDef when its corresponding function isn't present in `tf.raw_ops`."""
328+
from tensorflow.core.framework import op_def_pb2
329+
330+
custom_opdef_tf = op_def_pb2.OpDef()
331+
custom_opdef_tf.name = "MyOpDef"
332+
333+
arg1_tf = op_def_pb2.OpDef.ArgDef()
334+
arg1_tf.name = "arg1"
335+
arg1_tf.type_attr = "T"
336+
337+
arg2_tf = op_def_pb2.OpDef.ArgDef()
338+
arg2_tf.name = "arg2"
339+
arg2_tf.type_attr = "T"
340+
341+
custom_opdef_tf.input_arg.extend([arg1_tf, arg2_tf])
342+
343+
attr1_tf = op_def_pb2.OpDef.AttrDef()
344+
attr1_tf.name = "T"
345+
attr1_tf.type = "type"
346+
347+
attr2_tf = op_def_pb2.OpDef.AttrDef()
348+
attr2_tf.name = "axis"
349+
attr2_tf.type = "int"
350+
attr2_tf.default_value.i = 1
351+
352+
custom_opdef_tf.attr.extend([attr1_tf, attr2_tf])
353+
354+
opdef_sig = MetaOpDefLibrary.make_opdef_sig(custom_opdef_tf)
355+
356+
import inspect
357+
# These are standard inputs
358+
assert opdef_sig.parameters['arg1'].default == inspect._empty
359+
assert opdef_sig.parameters['arg2'].default == inspect._empty
360+
# These are attributes that are sometimes required by the OpDef
361+
assert opdef_sig.parameters['axis'].default == inspect._empty
362+
# The obligatory tensor name parameter
363+
assert opdef_sig.parameters['name'].default is None
364+
365+
324366
@pytest.mark.usefixtures("run_with_tensorflow")
325367
@run_in_graph_mode
326368
def test_nodedef():
@@ -337,3 +379,25 @@ def test_nodedef():
337379
norm_rv = mt.RandomStandardNormal(mean=0, stddev=1, shape=(1000,), dtype=tf.float32, name=var())
338380
assert isinstance(norm_rv, TFlowMetaTensor)
339381
assert norm_rv.dtype == tf.float32
382+
383+
# We shouldn't be metatizing all parsed `node_def.attr` values; otherwise,
384+
# we won't be able to reconstruct corresponding meta Ops using their meta
385+
# OpDefs and inputs.
386+
x_test = tf.constant([1.8, 2.2], dtype=tf.float32)
387+
y_test = tf.dtypes.cast(x_test, dtype=tf.int32, name="y")
388+
y_test_mt = mt(y_test)
389+
390+
# `ytest_mt.inputs` should have two `.attr` values that are Python
391+
# primitives (i.e. int and bool); these shouldn't get metatized and break
392+
# our ability to reconstruct the object from its rator + rands.
393+
assert y_test_mt == y_test_mt.op.op_def(*y_test_mt.inputs)
394+
395+
396+
@pytest.mark.usefixtures("run_with_tensorflow")
397+
@run_in_graph_mode
398+
def test_metatize():
399+
class CustomClass(object):
400+
pass
401+
402+
with pytest.raises(ValueError):
403+
mt(CustomClass())

0 commit comments

Comments
 (0)