Skip to content

Commit e494172

Browse files
Add dtype and NumPy Const printing to TensorFlow debug print
1 parent e644c1f commit e494172

File tree

2 files changed

+109
-26
lines changed

2 files changed

+109
-26
lines changed

symbolic_pymc/tensorflow/printing.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
import sys
22

3+
import numpy as np
4+
import tensorflow as tf
5+
36
from functools import singledispatch
47
from contextlib import contextmanager
58

6-
import tensorflow as tf
9+
from unification import isvar
10+
11+
# from tensorflow.python.framework import tensor_util
712

813
from symbolic_pymc.tensorflow.meta import TFlowMetaTensor
914
from symbolic_pymc.tensorflow.meta import TFlowMetaOp
@@ -77,22 +82,38 @@ def _(obj, printer):
7782
except (ValueError, AttributeError):
7883
shape_str = "Unknown"
7984

80-
prefix = f'Tensor({obj.op.type}):{obj.value_index},\tshape={shape_str}\t"{obj.name}"'
85+
prefix = f'Tensor({getattr(obj.op, "type", obj.op)}):{obj.value_index},\tdtype={getattr(obj.dtype, "name", obj.dtype)},\tshape={shape_str},\t"{obj.name}"'
8186
_tf_dprint(prefix, printer)
82-
if len(obj.op.inputs) > 0:
87+
88+
if isvar(obj.op):
89+
return
90+
elif isvar(obj.op.inputs):
91+
with printer.indented("| "):
92+
_tf_dprint(f"{obj.op.inputs}", printer)
93+
elif len(obj.op.inputs) > 0:
8394
with printer.indented("| "):
8495
if obj not in printer.printed_subgraphs:
8596
printer.printed_subgraphs.add(obj)
8697
_tf_dprint(obj.op, printer)
8798
else:
8899
_tf_dprint("...", printer)
100+
elif obj.op.type == "Const":
101+
with printer.indented("| "):
102+
if isinstance(obj, tf.Tensor):
103+
numpy_val = obj.eval(session=tf.compat.v1.Session(graph=obj.graph))
104+
elif isvar(obj.op.node_def):
105+
_tf_dprint(f"{obj.op.node_def}", printer)
106+
return
107+
else:
108+
numpy_val = obj.op.node_def.attr["value"]
109+
110+
_tf_dprint(
111+
np.array2string(numpy_val, threshold=20, prefix=printer.indentation), printer
112+
)
89113

90114

91115
@_tf_dprint.register(tf.Operation)
92116
@_tf_dprint.register(TFlowMetaOp)
93117
def _(obj, printer):
94-
prefix = f'Op({obj.type})\t"{obj.name}"'
95-
_tf_dprint(prefix, printer)
96-
with printer.indented("| "):
97-
for op_input in obj.inputs:
98-
_tf_dprint(op_input, printer)
118+
for op_input in obj.inputs:
119+
_tf_dprint(op_input, printer)

tests/tensorflow/test_printing.py

Lines changed: 80 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77

88
from contextlib import redirect_stdout
99

10-
from tensorflow.python.eager.context import graph_mode
10+
from unification import var, Var
1111

12+
from symbolic_pymc.tensorflow.meta import mt
1213
from symbolic_pymc.tensorflow.printing import tf_dprint
1314

1415
from tests.tensorflow import run_in_graph_mode
@@ -23,11 +24,7 @@ def test_eager_mode():
2324
X_tf = tf.convert_to_tensor(X)
2425

2526
with pytest.raises(ValueError):
26-
tf_dprint(X_tf)
27-
28-
with graph_mode():
29-
X_tf = tf.convert_to_tensor(X)
30-
tf_dprint(X_tf)
27+
_ = tf_dprint(X_tf)
3128

3229

3330
@run_in_graph_mode
@@ -47,17 +44,39 @@ def test_ascii_printing():
4744
tf_dprint(z)
4845

4946
expected_out = textwrap.dedent('''
50-
Tensor(MatMul):0,\tshape=[None, 1]\t"A_dot:0"
51-
| Op(MatMul)\t"A_dot"
52-
| | Tensor(Placeholder):0,\tshape=[None, None]\t"A:0"
53-
| | Tensor(Add):0,\tshape=[None, 1]\t"x_p_y:0"
54-
| | | Op(Add)\t"x_p_y"
55-
| | | | Tensor(Mul):0, shape=[None, 1] "y:0"
56-
| | | | | Op(Mul)\t"y"
57-
| | | | | | Tensor(Const):0,\tshape=[]\t"y/x:0"
58-
| | | | | | Tensor(Placeholder):0,\tshape=[None, 1]\t"x:0"
59-
| | | | Tensor(Mul):0,\tshape=[None, 1]\t"y:0"
60-
| | | | | ...
47+
Tensor(MatMul):0,\tdtype=float32,\tshape=[None, 1],\t"A_dot:0"
48+
| Tensor(Placeholder):0,\tdtype=float32,\tshape=[None, None],\t"A:0"
49+
| Tensor(Add):0,\tdtype=float32,\tshape=[None, 1],\t"x_p_y:0"
50+
| | Tensor(Mul):0,\tdtype=float32,\tshape=[None, 1],\t"y:0"
51+
| | | Tensor(Const):0,\tdtype=float32,\tshape=[],\t"y/x:0"
52+
| | | | 1.
53+
| | | Tensor(Placeholder):0,\tdtype=float32,\tshape=[None, 1],\t"x:0"
54+
| | Tensor(Mul):0,\tdtype=float32,\tshape=[None, 1],\t"y:0"
55+
| | | ...
56+
''')
57+
58+
assert std_out.getvalue() == expected_out.lstrip()
59+
60+
std_out = io.StringIO()
61+
with tf.Graph().as_default(), redirect_stdout(std_out):
62+
Var._id = 0
63+
tt_lv_inputs_mt = mt.Tensor(mt.Operation(var(), var(), var()), 0, var())
64+
tt_const_lv_nodedef_mt = mt.Tensor(mt.Operation(mt.Const.op_def, var(), ()), 0, var())
65+
tt_lv_op_mt = mt.Tensor(var(), 0, var())
66+
test_mt = mt(1) + tt_lv_inputs_mt + tt_const_lv_nodedef_mt + tt_lv_op_mt
67+
tf_dprint(test_mt)
68+
69+
expected_out = textwrap.dedent('''
70+
Tensor(AddV2):0,\tdtype=int32,\tshape=~_11,\t"add:0"
71+
| Tensor(AddV2):0,\tdtype=int32,\tshape=~_12,\t"add:0"
72+
| | Tensor(AddV2):0,\tdtype=int32,\tshape=~_13,\t"add:0"
73+
| | | Tensor(Const):0,\tdtype=int32,\tshape=[],\t"Const:0"
74+
| | | | 1
75+
| | | Tensor(~_15):0,\tdtype=~_3,\tshape=~_14,\t"~_17"
76+
| | | | ~_2
77+
| | Tensor(Const):0,\tdtype=~_5,\tshape=~_18,\t"~_20"
78+
| | | ~_4
79+
| Tensor(~_6):0,\tdtype=~_7,\tshape=~_21,\t"~_22"
6180
''')
6281

6382
assert std_out.getvalue() == expected_out.lstrip()
@@ -73,6 +92,49 @@ def test_unknown_shape():
7392
with redirect_stdout(std_out):
7493
tf_dprint(A)
7594

76-
expected_out = 'Tensor(Placeholder):0,\tshape=Unknown\t"A:0"\n'
95+
expected_out = 'Tensor(Placeholder):0,\tdtype=float64,\tshape=Unknown,\t"A:0"\n'
96+
97+
assert std_out.getvalue() == expected_out.lstrip()
98+
99+
100+
@run_in_graph_mode
101+
def test_numpy():
102+
"""Make sure we can ascii/text print constant tensors with large Numpy arrays."""
103+
104+
with tf.Graph().as_default():
105+
A = tf.convert_to_tensor(np.arange(100))
106+
107+
std_out = io.StringIO()
108+
with redirect_stdout(std_out):
109+
tf_dprint(A)
110+
111+
expected_out = textwrap.dedent('''
112+
Tensor(Const):0,\tdtype=int64,\tshape=[100],\t"Const:0"
113+
| [ 0 1 2 ... 97 98 99]
114+
''')
115+
116+
assert std_out.getvalue() == expected_out.lstrip()
117+
118+
N = 100
119+
np.random.seed(12345)
120+
X = np.vstack([np.random.randn(N), np.ones(N)]).T
121+
122+
with tf.Graph().as_default():
123+
X_tf = tf.convert_to_tensor(X)
124+
125+
std_out = io.StringIO()
126+
with redirect_stdout(std_out):
127+
tf_dprint(X_tf)
128+
129+
expected_out = textwrap.dedent('''
130+
Tensor(Const):0,\tdtype=float64,\tshape=[100, 2],\t"Const:0"
131+
| [[-0.20470766 1. ]
132+
[ 0.47894334 1. ]
133+
[-0.51943872 1. ]
134+
...
135+
[-0.74853155 1. ]
136+
[ 0.58496974 1. ]
137+
[ 0.15267657 1. ]]
138+
''')
77139

78140
assert std_out.getvalue() == expected_out.lstrip()

0 commit comments

Comments
 (0)