Skip to content

Commit ef25173

Browse files
Stop metatizing all node_def.attr values
Fixes #64.
1 parent 2fbc860 commit ef25173

File tree

2 files changed

+24
-11
lines changed

2 files changed

+24
-11
lines changed

symbolic_pymc/tensorflow/meta.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
import tensorflow as tf
77
import tensorflow_probability as tfp
88

9-
from contextlib import suppress
10-
119
from inspect import Parameter, Signature
1210

1311
from collections import OrderedDict, UserString
@@ -183,7 +181,7 @@ def __hash__(self):
183181
def _metatize_tf_object(obj):
184182
try:
185183
obj = tf.convert_to_tensor(obj)
186-
except TypeError:
184+
except (TypeError, ValueError):
187185
raise ValueError("Could not find a TensorFlow MetaSymbol class for {obj}")
188186

189187
if isinstance(obj, tf.Tensor):
@@ -368,9 +366,8 @@ def _protobuf_convert(cls, k, v):
368366
from google.protobuf.json_format import MessageToDict
369367
MessageToDict(obj, use_integers_for_enums=True)
370368
"""
371-
372369
if k == "shape":
373-
return tensor_shape.as_shape(v.shape)
370+
return metatize(tensor_shape.as_shape(v.shape))
374371
elif k == "dtype":
375372
return tf.as_dtype(v.type).name
376373
elif k == "value":
@@ -406,12 +403,6 @@ def __init__(self, op, name, attr, obj=None):
406403
continue
407404

408405
if k != "T" and k in op_param_names:
409-
# XXX: We can't let `metatize` convert NumPy values;
410-
# otherwise, we'll loop endlessly on "Const" Ops.
411-
if k != "value":
412-
with suppress(ValueError):
413-
v = metatize(v)
414-
415406
self.attr[k] = v
416407
else:
417408
self.attr = attr

tests/tensorflow/test_meta.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,3 +337,25 @@ def test_nodedef():
337337
norm_rv = mt.RandomStandardNormal(mean=0, stddev=1, shape=(1000,), dtype=tf.float32, name=var())
338338
assert isinstance(norm_rv, TFlowMetaTensor)
339339
assert norm_rv.dtype == tf.float32
340+
341+
# We shouldn't be metatizing all parsed `node_def.attr` values; otherwise,
342+
# we won't be able to reconstruct corresponding meta Ops using their meta
343+
# OpDefs and inputs.
344+
x_test = tf.constant([1.8, 2.2], dtype=tf.float32)
345+
y_test = tf.dtypes.cast(x_test, dtype=tf.int32, name="y")
346+
y_test_mt = mt(y_test)
347+
348+
# `ytest_mt.inputs` should have two `.attr` values that are Python
349+
# primitives (i.e. int and bool); these shouldn't get metatized and break
350+
# our ability to reconstruct the object from its rator + rands.
351+
assert y_test_mt == y_test_mt.op.op_def(*y_test_mt.inputs)
352+
353+
354+
@pytest.mark.usefixtures("run_with_tensorflow")
355+
@run_in_graph_mode
356+
def test_metatize():
357+
class CustomClass(object):
358+
pass
359+
360+
with pytest.raises(ValueError):
361+
mt(CustomClass())

0 commit comments

Comments
 (0)