|
| 1 | +""" |
| 2 | +If you're debugging/running tests manually, it might help to simply |
| 3 | +disable eager execution entirely: |
| 4 | +
|
| 5 | + > tf.compat.v1.disable_eager_execution() |
| 6 | +""" |
1 | 7 | import pytest |
2 | 8 | import numpy as np |
3 | 9 |
|
|
14 | 20 | TFlowMetaOp, |
15 | 21 | TFlowMetaOpDef, |
16 | 22 | TFlowMetaNodeDef, |
17 | | - TFlowOpName, |
18 | 23 | MetaOpDefLibrary, |
19 | 24 | mt) |
20 | 25 |
|
21 | 26 | from tests.tensorflow import run_in_graph_mode |
22 | 27 | from tests.tensorflow.utils import assert_ops_equal |
23 | 28 |
|
24 | 29 |
|
25 | | -@pytest.mark.usefixtures("run_with_tensorflow") |
26 | | -def test_op_names(): |
27 | | - """Make sure equality is flexible for `Operation`/`OpDef` names.""" |
28 | | - # Against a string, only the distinct operator-name part matters |
29 | | - assert TFlowOpName('add_1') == 'add' |
30 | | - assert TFlowOpName('blah/add_1:0') == 'add' |
31 | | - assert TFlowOpName('blah/add_1:0') != 'add_1' |
32 | | - assert TFlowOpName('blah/add_1:0') != 'add:0' |
33 | | - # Unless it's the whole thing |
34 | | - assert TFlowOpName('blah/add_1:0') == 'blah/add_1:0' |
35 | | - |
36 | | - # Ignore namespaces |
37 | | - assert TFlowOpName('blah/add_1:0') == TFlowOpName('add:0') |
38 | | - assert TFlowOpName('blah/add_1:0') == TFlowOpName('agh/add_1:0') |
39 | | - # and "unique" operator names (for the same operator "type") |
40 | | - assert TFlowOpName('blah/add_1:0') == TFlowOpName('add_2:0') |
41 | | - # but not output numbers |
42 | | - assert TFlowOpName('blah/add_1:0') != TFlowOpName('blah/add:1') |
43 | | - |
44 | | - assert isinstance(mt(TFlowOpName('blah/add_1:0')), TFlowOpName) |
45 | | - |
46 | | - |
47 | 30 | @pytest.mark.usefixtures("run_with_tensorflow") |
48 | 31 | def test_meta_helper(): |
49 | 32 | """Make sure the helper/namespace emulator can find `OpDef`s and create their meta objects.""" |
@@ -79,14 +62,69 @@ def test_meta_eager(): |
79 | 62 |
|
80 | 63 | @pytest.mark.usefixtures("run_with_tensorflow") |
81 | 64 | @run_in_graph_mode |
82 | | -def test_meta_create(): |
| 65 | +def test_meta_basic(): |
| 66 | + |
| 67 | + var_mt = TFlowMetaTensor(var(), var(), var()) |
| 68 | + # It should generate a logic variable for the name and use from here on. |
| 69 | + var_name = var_mt.name |
| 70 | + assert isvar(var_name) |
| 71 | + assert var_mt.name is var_name |
| 72 | + # Same for a tensor shape |
| 73 | + var_shape = var_mt.shape |
| 74 | + assert isinstance(var_shape, TFlowMetaTensorShape) |
| 75 | + assert isvar(var_shape.dims) |
| 76 | + |
| 77 | + # This essentially logic-variabled tensor should not reify; it should |
| 78 | + # create a distinct/new meta object that's either equal to the original |
| 79 | + # meta object or partially reified. |
| 80 | + assert var_mt.reify() == var_mt |
| 81 | + |
| 82 | + # This operator is reifiable |
| 83 | + # NOTE: Const objects are automatically created for the constant inputs, so |
| 84 | + # we need to do this in a new graph to make sure that their auto-generated |
| 85 | + # names are consistent throughout runs. |
| 86 | + with tf.Graph().as_default() as test_graph: |
| 87 | + test_op = TFlowMetaOp(mt.Add, TFlowMetaNodeDef('Add', 'Add', {}), [1, 0]) |
| 88 | + |
| 89 | + # This tensor has an "unknown"/logic variable output index and dtype, but, |
| 90 | + # since the operator fully specifies it, reification should still work. |
| 91 | + var_mt = TFlowMetaTensor(test_op, var(), var()) |
| 92 | + |
| 93 | + # This should be partially reified |
| 94 | + var_tf = var_mt.reify() |
| 95 | + |
| 96 | + assert isinstance(var_tf, tf.Tensor) |
| 97 | + |
| 98 | + # These shouldn't be equal, since `var_mt` has logic variables for |
| 99 | + # output index and dtype. (They should be unifiable, though.) |
| 100 | + assert mt(var_tf) != var_mt |
| 101 | + |
| 102 | + # NOTE: The operator name specified by the meta NodeDef *can* be |
| 103 | + # different from the reified TF tensor (e.g. when meta objects are |
| 104 | + # created/reified within a graph already using the NodeDef-specified |
| 105 | + # name). |
| 106 | + # |
| 107 | + # TODO: We could search for existing TF objects in the current graph by |
| 108 | + # name and raise exceptions when the desired meta information and name |
| 109 | + # do not correspond--effectively making the meta object impossible to |
| 110 | + # reify in said graph. |
| 111 | + |
| 112 | + # Next, we convert an existing TF object into a meta object |
| 113 | + # and make sure everything corresponds between the two. |
83 | 114 | N = 100 |
84 | 115 | X = np.vstack([np.random.randn(N), np.ones(N)]).T |
| 116 | + |
85 | 117 | X_tf = tf.convert_to_tensor(X) |
86 | | - X_mt = mt(X) |
| 118 | + |
| 119 | + with tf.Graph().as_default() as test_graph: |
| 120 | + X_mt = mt(X) |
87 | 121 |
|
88 | 122 | assert isinstance(X_mt, TFlowMetaTensor) |
89 | | - assert X_mt.op.obj.name.startswith('Const') |
| 123 | + assert X_mt.op.obj.name == 'Const' |
| 124 | + assert not hasattr(X_mt, '_name') |
| 125 | + assert X_mt.name == 'Const:0' |
| 126 | + assert X_mt._name == 'Const:0' |
| 127 | + |
90 | 128 | # Make sure `reify` returns the cached base object. |
91 | 129 | assert X_mt.reify() is X_mt.obj |
92 | 130 | assert isinstance(X_mt.reify(), tf.Tensor) |
@@ -373,7 +411,12 @@ def test_nodedef(): |
373 | 411 | # `ytest_mt.inputs` should have two `.attr` values that are Python |
374 | 412 | # primitives (i.e. int and bool); these shouldn't get metatized and break |
375 | 413 | # our ability to reconstruct the object from its rator + rands. |
376 | | - assert y_test_mt == y_test_mt.op.op_def(*y_test_mt.inputs) |
| 414 | + y_test_new_mt = y_test_mt.op.op_def(*y_test_mt.inputs) |
| 415 | + |
| 416 | + # We're changing this simply so we can use == |
| 417 | + y_test_new_mt.op.node_def.name = 'y' |
| 418 | + |
| 419 | + assert y_test_mt == y_test_new_mt |
377 | 420 |
|
378 | 421 |
|
379 | 422 | @pytest.mark.usefixtures("run_with_tensorflow") |
|
0 commit comments