Skip to content

Commit 9b81826

Browse files
Merge pull request #68 from brandonwillard/update-tf-dep-versions
Update pinned TF and TFP to latest versions
2 parents 9777aac + 159c528 commit 9b81826

File tree

6 files changed

+58
-17
lines changed

6 files changed

+58
-17
lines changed

.pylintrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ ignore-mixin-members=yes
264264
# (useful for modules/projects where namespaces are manipulated during runtime
265265
# and thus existing member attributes cannot be deduced by static analysis. It
266266
# supports qualified module names, as well as Unix pattern matching.
267-
ignored-modules=
267+
ignored-modules=tensorflow.core.framework,tensorflow.python.framework
268268

269269
# List of classes names for which member attributes should not be checked
270270
# (useful for classes with attributes dynamically set). This supports can work

requirements.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
scipy>=1.2.0
22
Theano>=1.0.4
3-
tf-nightly-2.0-preview==2.0.0.dev20190606
4-
tfp-nightly==0.8.0.dev20190705
3+
gast==0.2.2
4+
tf-nightly-2.0-preview==2.0.0.dev20190908
5+
tensorflow-estimator-2.0-preview==1.14.0.dev2019090801
6+
tfp-nightly==0.9.0.dev20190908
57
pymc3>=3.6
68
pymc4 @ git+https://github.com/pymc-devs/pymc4.git@master#egg=pymc4-0.0.1
79
multipledispatch>=0.6.0

symbolic_pymc/tensorflow/meta.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,10 @@ def __eq__(self, other):
170170
if type(self) != type(other):
171171
return False
172172

173-
return self._unique_name == other._unique_name and self._in_idx == other._in_idx
173+
return (
174+
self._unique_name.lower() == other._unique_name.lower()
175+
and self._in_idx == other._in_idx
176+
)
174177

175178
def __hash__(self):
176179
return hash((self._unique_name, self._in_idx))
@@ -309,7 +312,7 @@ def __call__(self, *args, **kwargs):
309312
# Get the `OpDef`-instantiating parameters and call them a "node_def".
310313
node_attr = {a.name: apply_arguments.get(a.name, a) for a in self.obj.attr}
311314

312-
op_name = op_kwargs.get("name", self.obj.name.lower())
315+
op_name = op_kwargs.get("name", self.obj.name)
313316

314317
# input_arg_names = [(getattr(a, 'name', None), i.name)
315318
# for a, i in zip(args, self.obj.input_arg)]
@@ -453,7 +456,7 @@ def __init__(self, op_def, node_def, inputs, name=None, outputs=None, obj=None):
453456

454457
if isinstance(name, (str, TFlowOpName)) or name is None:
455458
if name is None:
456-
name = op_def.obj.name.lower()
459+
name = op_def.obj.name
457460
# from tensorflow.python.framework import ops
458461
# if name and name[-1] == "/":
459462
# name = ops._name_from_scope_name(str(name))
@@ -512,7 +515,7 @@ def outputs(self):
512515
value_index=i,
513516
shape=var(),
514517
name=(
515-
TFlowOpName(f"{self.name.lower()}:{i}")
518+
TFlowOpName(f"{self.name}:{i}")
516519
if isinstance(self.name, (str, TFlowOpName))
517520
else var()
518521
),
@@ -536,8 +539,6 @@ def default_output(self):
536539

537540
mt_outs = self.outputs
538541

539-
if isvar(mt_outs):
540-
out_var = var()
541542
if len(mt_outs) == 1:
542543
out_var = mt_outs[0]
543544
else:
@@ -553,8 +554,13 @@ def reify(self):
553554
# tt_op = self.op.reify()
554555
# if not self.is_meta(tt_op):
555556
op_inputs, op_inputs_unreified = _meta_reify_iter(self.inputs)
557+
558+
if isvar(self.node_def):
559+
return self
560+
556561
op_attrs, op_attrs_unreified = _meta_reify_iter(self.node_def.attr)
557-
if not op_inputs_unreified and not op_attrs_unreified and not MetaSymbol.is_meta(self.name):
562+
563+
if not (op_inputs_unreified or op_attrs_unreified or MetaSymbol.is_meta(self.name)):
558564

559565
# We have to use a primitive string or TF will complain.
560566
name = self.name
@@ -644,22 +650,24 @@ def operator(self):
644650
def inputs(self):
645651
"""Return the tensor's inputs/rands.
646652
647-
NOTE: These inputs differ from `self.op.inputs` in that contain
653+
NOTE: These inputs differ from `self.op.inputs` in that they contain
648654
the `node_def` parameters, as well.
649655
In other words, these can be used to recreate this object (per
650656
the meta object spec).
651657
"""
652658
if self.op is not None and not isvar(self.op):
653659
input_args = self.op.op_def.input_args(
654-
*self.op.inputs, name=self.op.name, **self.op.node_def.attr
660+
*self.op.inputs,
661+
name=self.op.name if not isvar(self.op.name) else None,
662+
**self.op.node_def.attr,
655663
)
656664
return tuple(input_args.values())
657665

658666
def reify(self):
659667
if self.obj is not None and not isinstance(self.obj, Var):
660668
return self.obj
661669

662-
if not self.op:
670+
if (not self.op) or isvar(self.op):
663671
op_res = super().reify()
664672
return op_res
665673

@@ -793,7 +801,7 @@ def __call__(self, x):
793801
@classmethod
794802
def find_opdef(cls, name):
795803
"""Attempt to create a meta `OpDef` for a given TF function/`Operation` name."""
796-
raw_op_name = op_def_lib.lower_op_name_to_raw.get(name, name)
804+
raw_op_name = op_def_lib.lower_op_name_to_raw.get(name.lower(), name)
797805
op_def = op_def_registry.get_registered_ops()[raw_op_name]
798806

799807
if op_def is not None:

tests/tensorflow/test_meta.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from unification import var, isvar
1010

11+
from symbolic_pymc.meta import MetaSymbol
1112
from symbolic_pymc.tensorflow.meta import (TFlowMetaTensor,
1213
TFlowMetaTensorShape,
1314
TFlowMetaConstant,
@@ -166,6 +167,35 @@ def test_meta_create():
166167
with pytest.raises(TypeError):
167168
TFlowMetaTensor('float64', 'Add', name='q__')
168169

170+
@pytest.mark.usefixtures("run_with_tensorflow")
171+
@run_in_graph_mode
172+
def test_meta_Op():
173+
174+
from tensorflow.python.eager.context import graph_mode
175+
176+
177+
with graph_mode():
178+
t1_tf = tf.convert_to_tensor([[1, 2, 3], [4, 5, 6]])
179+
t2_tf = tf.convert_to_tensor([[7, 8, 9], [10, 11, 12]])
180+
test_out_tf = tf.concat([t1_tf, t2_tf], 0)
181+
182+
# TODO: Without explicit conversion, each element within these arrays gets
183+
# converted to a `Tensor` by `metatize`. That doesn't seem very
184+
# reasonable. Likewise, the `0` gets converted, but it probably shouldn't be.
185+
test_op = TFlowMetaOp(mt.Concat, var(), [[t1_tf, t2_tf], 0])
186+
187+
# Make sure we converted lists to tuples
188+
assert isinstance(test_op.inputs, tuple)
189+
assert isinstance(test_op.inputs[0], tuple)
190+
191+
test_op = TFlowMetaOp(mt.Concat, var(), [[t1_tf, t2_tf], 0], outputs=[test_out_tf])
192+
193+
# NodeDef is a logic variable, so this shouldn't actually reify.
194+
assert MetaSymbol.is_meta(test_op.reify())
195+
assert isinstance(test_op.outputs, tuple)
196+
assert MetaSymbol.is_meta(test_op.outputs[0])
197+
198+
169199

170200
@pytest.mark.usefixtures("run_with_tensorflow")
171201
def test_meta_lvars():

tests/tensorflow/test_unify.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def test_basic_unify_reify():
7070
test_base_res = test_reify_res.reify()
7171
assert isinstance(test_base_res, tf.Tensor)
7272

73-
expected_res = (tf.constant(1, dtype=tf.float64) +
74-
tf.constant(2, dtype=tf.float64) * a)
73+
expected_res = tf.add(tf.constant(1, dtype=tf.float64),
74+
tf.constant(2, dtype=tf.float64) * a)
7575
assert_ops_equal(test_base_res, expected_res)
7676

7777
# Simply make sure that unification succeeds
@@ -94,7 +94,7 @@ def test_sexp_unify_reify():
9494
y = tf.compat.v1.placeholder(tf.float64, name='y',
9595
shape=tf.TensorShape([None, 1]))
9696

97-
z = tf.matmul(A, x + y)
97+
z = tf.matmul(A, tf.add(x, y))
9898

9999
z_sexp = etuplize(z)
100100

tests/tensorflow/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import tensorflow as tf
44

5+
56
def assert_ops_equal(a, b, compare_fn=lambda a, b: a.op.type == b.op.type):
67
if hasattr(a, 'op') or hasattr(b, 'op'):
78
assert hasattr(a, 'op') and hasattr(b, 'op')

0 commit comments

Comments
 (0)