77
88from contextlib import redirect_stdout
99
10- from tensorflow . python . eager . context import graph_mode
10+ from unification import var , Var
1111
12+ from symbolic_pymc .tensorflow .meta import mt
1213from symbolic_pymc .tensorflow .printing import tf_dprint
1314
1415from tests .tensorflow import run_in_graph_mode
@@ -23,11 +24,7 @@ def test_eager_mode():
2324 X_tf = tf .convert_to_tensor (X )
2425
2526 with pytest .raises (ValueError ):
26- tf_dprint (X_tf )
27-
28- with graph_mode ():
29- X_tf = tf .convert_to_tensor (X )
30- tf_dprint (X_tf )
27+ _ = tf_dprint (X_tf )
3128
3229
3330@run_in_graph_mode
@@ -47,17 +44,39 @@ def test_ascii_printing():
4744 tf_dprint (z )
4845
4946 expected_out = textwrap .dedent ('''
50- Tensor(MatMul):0,\t shape=[None, 1]\t "A_dot:0"
51- | Op(MatMul)\t "A_dot"
52- | | Tensor(Placeholder):0,\t shape=[None, None]\t "A:0"
53- | | Tensor(Add):0,\t shape=[None, 1]\t "x_p_y:0"
54- | | | Op(Add)\t "x_p_y"
55- | | | | Tensor(Mul):0, shape=[None, 1] "y:0"
56- | | | | | Op(Mul)\t "y"
57- | | | | | | Tensor(Const):0,\t shape=[]\t "y/x:0"
58- | | | | | | Tensor(Placeholder):0,\t shape=[None, 1]\t "x:0"
59- | | | | Tensor(Mul):0,\t shape=[None, 1]\t "y:0"
60- | | | | | ...
47+ Tensor(MatMul):0,\t dtype=float32,\t shape=[None, 1],\t "A_dot:0"
48+ | Tensor(Placeholder):0,\t dtype=float32,\t shape=[None, None],\t "A:0"
49+ | Tensor(Add):0,\t dtype=float32,\t shape=[None, 1],\t "x_p_y:0"
50+ | | Tensor(Mul):0,\t dtype=float32,\t shape=[None, 1],\t "y:0"
51+ | | | Tensor(Const):0,\t dtype=float32,\t shape=[],\t "y/x:0"
52+ | | | | 1.
53+ | | | Tensor(Placeholder):0,\t dtype=float32,\t shape=[None, 1],\t "x:0"
54+ | | Tensor(Mul):0,\t dtype=float32,\t shape=[None, 1],\t "y:0"
55+ | | | ...
56+ ''' )
57+
58+ assert std_out .getvalue () == expected_out .lstrip ()
59+
60+ std_out = io .StringIO ()
61+ with tf .Graph ().as_default (), redirect_stdout (std_out ):
62+ Var ._id = 0
63+ tt_lv_inputs_mt = mt .Tensor (mt .Operation (var (), var (), var ()), 0 , var ())
64+ tt_const_lv_nodedef_mt = mt .Tensor (mt .Operation (mt .Const .op_def , var (), ()), 0 , var ())
65+ tt_lv_op_mt = mt .Tensor (var (), 0 , var ())
66+ test_mt = mt (1 ) + tt_lv_inputs_mt + tt_const_lv_nodedef_mt + tt_lv_op_mt
67+ tf_dprint (test_mt )
68+
69+ expected_out = textwrap .dedent ('''
70+ Tensor(AddV2):0,\t dtype=int32,\t shape=~_11,\t "add:0"
71+ | Tensor(AddV2):0,\t dtype=int32,\t shape=~_12,\t "add:0"
72+ | | Tensor(AddV2):0,\t dtype=int32,\t shape=~_13,\t "add:0"
73+ | | | Tensor(Const):0,\t dtype=int32,\t shape=[],\t "Const:0"
74+ | | | | 1
75+ | | | Tensor(~_15):0,\t dtype=~_3,\t shape=~_14,\t "~_17"
76+ | | | | ~_2
77+ | | Tensor(Const):0,\t dtype=~_5,\t shape=~_18,\t "~_20"
78+ | | | ~_4
79+ | Tensor(~_6):0,\t dtype=~_7,\t shape=~_21,\t "~_22"
6180 ''' )
6281
6382 assert std_out .getvalue () == expected_out .lstrip ()
@@ -73,6 +92,49 @@ def test_unknown_shape():
7392 with redirect_stdout (std_out ):
7493 tf_dprint (A )
7594
76- expected_out = 'Tensor(Placeholder):0,\t shape=Unknown\t "A:0"\n '
95+ expected_out = 'Tensor(Placeholder):0,\t dtype=float64,\t shape=Unknown,\t "A:0"\n '
96+
97+ assert std_out .getvalue () == expected_out .lstrip ()
98+
99+
100+ @run_in_graph_mode
101+ def test_numpy ():
102+ """Make sure we can ascii/text print constant tensors with large Numpy arrays."""
103+
104+ with tf .Graph ().as_default ():
105+ A = tf .convert_to_tensor (np .arange (100 ))
106+
107+ std_out = io .StringIO ()
108+ with redirect_stdout (std_out ):
109+ tf_dprint (A )
110+
111+ expected_out = textwrap .dedent ('''
112+ Tensor(Const):0,\t dtype=int64,\t shape=[100],\t "Const:0"
113+ | [ 0 1 2 ... 97 98 99]
114+ ''' )
115+
116+ assert std_out .getvalue () == expected_out .lstrip ()
117+
118+ N = 100
119+ np .random .seed (12345 )
120+ X = np .vstack ([np .random .randn (N ), np .ones (N )]).T
121+
122+ with tf .Graph ().as_default ():
123+ X_tf = tf .convert_to_tensor (X )
124+
125+ std_out = io .StringIO ()
126+ with redirect_stdout (std_out ):
127+ tf_dprint (X_tf )
128+
129+ expected_out = textwrap .dedent ('''
130+ Tensor(Const):0,\t dtype=float64,\t shape=[100, 2],\t "Const:0"
131+ | [[-0.20470766 1. ]
132+ [ 0.47894334 1. ]
133+ [-0.51943872 1. ]
134+ ...
135+ [-0.74853155 1. ]
136+ [ 0.58496974 1. ]
137+ [ 0.15267657 1. ]]
138+ ''' )
77139
78140 assert std_out .getvalue () == expected_out .lstrip ()
0 commit comments