Skip to content

Commit e58ab7f

Browse files
Remove TFLowOpName
1 parent e1cfb89 commit e58ab7f

File tree

3 files changed

+80
-82
lines changed

3 files changed

+80
-82
lines changed

symbolic_pymc/tensorflow/meta.py

Lines changed: 3 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from inspect import Parameter, Signature
1010

11-
from collections import OrderedDict, UserString
11+
from collections import OrderedDict
1212
from collections.abc import Sequence
1313

1414
from functools import partial
@@ -144,51 +144,6 @@ def get_op_info(cls, opdef):
144144
op_def_lib = MetaOpDefLibrary()
145145

146146

147-
class TFlowOpName(UserString):
148-
"""A wrapper for Tensor names.
149-
150-
TF `Operation` names, and the variables that result from them, cannot be
151-
compared directly due to their uniqueness (within a TF namespace/scope).
152-
153-
This wrapper class ignores those TF distinctions during string comparison.
154-
"""
155-
156-
def __init__(self, s):
157-
super().__init__(s)
158-
159-
if isinstance(s, type(self)):
160-
self._scope_op = s._scope_op
161-
self._in_idx = s._in_idx
162-
self._scope = s._scope
163-
self._op_name = s._op_name
164-
self._unique_name = s._unique_name
165-
else:
166-
self._scope_op, _, self._in_idx = self.data.partition(":")
167-
self._scope, _, self._op_name = self._scope_op.rpartition("/")
168-
self._unique_name = self._op_name.split("_", 1)[0]
169-
170-
def __eq__(self, other):
171-
if self is other:
172-
return True
173-
174-
if isinstance(other, str):
175-
return self.data == other or self._unique_name == other
176-
177-
if type(self) != type(other):
178-
return False
179-
180-
return (
181-
self._unique_name.lower() == other._unique_name.lower()
182-
and self._in_idx == other._in_idx
183-
)
184-
185-
def __hash__(self):
186-
return hash((self._unique_name, self._in_idx))
187-
188-
189-
_metatize.add((TFlowOpName,), lambda x: x)
190-
191-
192147
def _metatize_tf_object(obj):
193148
try:
194149
obj = tf.convert_to_tensor(obj)
@@ -406,7 +361,7 @@ def __init__(self, op, name, attr, obj=None):
406361
super().__init__(obj=obj)
407362
self.op = metatize(op)
408363
assert name is not None
409-
self.name = name if isvar(name) else TFlowOpName(name)
364+
self.name = name if isvar(name) else str(name)
410365

411366
if not isvar(attr):
412367
# We want to limit the attributes we'll consider to those that show
@@ -549,7 +504,7 @@ def name(self):
549504
if isvar(self.node_def):
550505
self._name = var()
551506
else:
552-
self._name = TFlowOpName(self.node_def.name)
507+
self._name = str(self.node_def.name)
553508

554509
return self._name
555510

tests/tensorflow/test_meta.py

Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
"""
2+
If you're debugging/running tests manually, it might help to simply
3+
disable eager execution entirely:
4+
5+
> tf.compat.v1.disable_eager_execution()
6+
"""
17
import pytest
28
import numpy as np
39

@@ -14,36 +20,13 @@
1420
TFlowMetaOp,
1521
TFlowMetaOpDef,
1622
TFlowMetaNodeDef,
17-
TFlowOpName,
1823
MetaOpDefLibrary,
1924
mt)
2025

2126
from tests.tensorflow import run_in_graph_mode
2227
from tests.tensorflow.utils import assert_ops_equal
2328

2429

25-
@pytest.mark.usefixtures("run_with_tensorflow")
26-
def test_op_names():
27-
"""Make sure equality is flexible for `Operation`/`OpDef` names."""
28-
# Against a string, only the distinct operator-name part matters
29-
assert TFlowOpName('add_1') == 'add'
30-
assert TFlowOpName('blah/add_1:0') == 'add'
31-
assert TFlowOpName('blah/add_1:0') != 'add_1'
32-
assert TFlowOpName('blah/add_1:0') != 'add:0'
33-
# Unless it's the whole thing
34-
assert TFlowOpName('blah/add_1:0') == 'blah/add_1:0'
35-
36-
# Ignore namespaces
37-
assert TFlowOpName('blah/add_1:0') == TFlowOpName('add:0')
38-
assert TFlowOpName('blah/add_1:0') == TFlowOpName('agh/add_1:0')
39-
# and "unique" operator names (for the same operator "type")
40-
assert TFlowOpName('blah/add_1:0') == TFlowOpName('add_2:0')
41-
# but not output numbers
42-
assert TFlowOpName('blah/add_1:0') != TFlowOpName('blah/add:1')
43-
44-
assert isinstance(mt(TFlowOpName('blah/add_1:0')), TFlowOpName)
45-
46-
4730
@pytest.mark.usefixtures("run_with_tensorflow")
4831
def test_meta_helper():
4932
"""Make sure the helper/namespace emulator can find `OpDef`s and create their meta objects."""
@@ -79,14 +62,69 @@ def test_meta_eager():
7962

8063
@pytest.mark.usefixtures("run_with_tensorflow")
8164
@run_in_graph_mode
82-
def test_meta_create():
65+
def test_meta_basic():
66+
67+
var_mt = TFlowMetaTensor(var(), var(), var())
68+
# It should generate a logic variable for the name and use from here on.
69+
var_name = var_mt.name
70+
assert isvar(var_name)
71+
assert var_mt.name is var_name
72+
# Same for a tensor shape
73+
var_shape = var_mt.shape
74+
assert isinstance(var_shape, TFlowMetaTensorShape)
75+
assert isvar(var_shape.dims)
76+
77+
# This essentially logic-variabled tensor should not reify; it should
78+
# create a distinct/new meta object that's either equal to the original
79+
# meta object or partially reified.
80+
assert var_mt.reify() == var_mt
81+
82+
# This operator is reifiable
83+
# NOTE: Const objects are automatically created for the constant inputs, so
84+
# we need to do this in a new graph to make sure that their auto-generated
85+
# names are consistent throughout runs.
86+
with tf.Graph().as_default() as test_graph:
87+
test_op = TFlowMetaOp(mt.Add, TFlowMetaNodeDef('Add', 'Add', {}), [1, 0])
88+
89+
# This tensor has an "unknown"/logic variable output index and dtype, but,
90+
# since the operator fully specifies it, reification should still work.
91+
var_mt = TFlowMetaTensor(test_op, var(), var())
92+
93+
# This should be partially reified
94+
var_tf = var_mt.reify()
95+
96+
assert isinstance(var_tf, tf.Tensor)
97+
98+
# These shouldn't be equal, since `var_mt` has logic variables for
99+
# output index and dtype. (They should be unifiable, though.)
100+
assert mt(var_tf) != var_mt
101+
102+
# NOTE: The operator name specified by the meta NodeDef *can* be
103+
# different from the reified TF tensor (e.g. when meta objects are
104+
# created/reified within a graph already using the NodeDef-specified
105+
# name).
106+
#
107+
# TODO: We could search for existing TF objects in the current graph by
108+
# name and raise exceptions when the desired meta information and name
109+
# do not correspond--effectively making the meta object impossible to
110+
# reify in said graph.
111+
112+
# Next, we convert an existing TF object into a meta object
113+
# and make sure everything corresponds between the two.
83114
N = 100
84115
X = np.vstack([np.random.randn(N), np.ones(N)]).T
116+
85117
X_tf = tf.convert_to_tensor(X)
86-
X_mt = mt(X)
118+
119+
with tf.Graph().as_default() as test_graph:
120+
X_mt = mt(X)
87121

88122
assert isinstance(X_mt, TFlowMetaTensor)
89-
assert X_mt.op.obj.name.startswith('Const')
123+
assert X_mt.op.obj.name == 'Const'
124+
assert not hasattr(X_mt, '_name')
125+
assert X_mt.name == 'Const:0'
126+
assert X_mt._name == 'Const:0'
127+
90128
# Make sure `reify` returns the cached base object.
91129
assert X_mt.reify() is X_mt.obj
92130
assert isinstance(X_mt.reify(), tf.Tensor)
@@ -373,7 +411,12 @@ def test_nodedef():
373411
# `ytest_mt.inputs` should have two `.attr` values that are Python
374412
# primitives (i.e. int and bool); these shouldn't get metatized and break
375413
# our ability to reconstruct the object from its rator + rands.
376-
assert y_test_mt == y_test_mt.op.op_def(*y_test_mt.inputs)
414+
y_test_new_mt = y_test_mt.op.op_def(*y_test_mt.inputs)
415+
416+
# We're changing this simply so we can use ==
417+
y_test_new_mt.op.node_def.name = 'y'
418+
419+
assert y_test_mt == y_test_new_mt
377420

378421

379422
@pytest.mark.usefixtures("run_with_tensorflow")

tests/tensorflow/test_unify.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from unification import unify, reify, var
66

7-
from symbolic_pymc.tensorflow.meta import (TFlowOpName, mt, TFlowMetaTensorShape, TFlowOpName)
7+
from symbolic_pymc.tensorflow.meta import (mt, TFlowMetaTensorShape)
88
from symbolic_pymc.etuple import (ExpressionTuple, etuple, etuplize)
99

1010
from tests.tensorflow import run_in_graph_mode
@@ -15,8 +15,6 @@
1515
@run_in_graph_mode
1616
def test_etuple_term():
1717

18-
str_type = TFlowOpName("blah")
19-
assert etuplize(str_type, return_bad_args=True) == str_type
2018
assert etuplize("blah", return_bad_args=True) == "blah"
2119

2220
a = tf.compat.v1.placeholder(tf.float64, name='a')
@@ -42,13 +40,13 @@ def test_etuple_term():
4240
assert test_e[2] is a_mt.op.node_def.attr['shape']
4341

4442
test_e._eval_obj = ExpressionTuple.null
45-
a_evaled = test_e.eval_obj
43+
with tf.Graph().as_default():
44+
a_evaled = test_e.eval_obj
4645
assert all([a == b for a, b in zip(a_evaled.rands(), a_mt.rands())])
4746

4847
a_reified = a_evaled.reify()
4948
assert isinstance(a_reified, tf.Tensor)
5049
assert a_reified.shape.dims is None
51-
assert TFlowOpName(a_reified.name) == TFlowOpName(a.name)
5250

5351
e2 = mt.add(a, b)
5452
e2_et = etuplize(e2)
@@ -75,8 +73,10 @@ def test_basic_unify_reify():
7573
test_base_res = test_reify_res.reify()
7674
assert isinstance(test_base_res, tf.Tensor)
7775

78-
expected_res = tf.add(tf.constant(1, dtype=tf.float64),
79-
tf.constant(2, dtype=tf.float64) * a)
76+
with tf.Graph().as_default():
77+
a = tf.compat.v1.placeholder(tf.float64, name='a')
78+
expected_res = tf.add(tf.constant(1, dtype=tf.float64),
79+
tf.multiply(tf.constant(2, dtype=tf.float64), a))
8080
assert_ops_equal(test_base_res, expected_res)
8181

8282
# Simply make sure that unification succeeds

0 commit comments

Comments
 (0)