Skip to content

Commit 28556fd

Browse files
Add a Grappler normalize function
1 parent e7a4c59 commit 28556fd

File tree

6 files changed

+164
-3
lines changed

6 files changed

+164
-3
lines changed

.pylintrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ ignore-mixin-members=yes
263263
# (useful for modules/projects where namespaces are manipulated during runtime
264264
# and thus existing member attributes cannot be deduced by static analysis. It
265265
# supports qualified module names, as well as Unix pattern matching.
266-
ignored-modules=tensorflow.core.framework,tensorflow.python.framework,tensorflow.python.ops.gen_linalg_ops,tensorflow.python.eager.context,tensorflow.compat.v1
266+
ignored-modules=tensorflow.core.framework,tensorflow.python.framework,tensorflow.python.ops.gen_linalg_ops,tensorflow.python.eager.context,tensorflow.compat.v1,tensorflow.core.protobuf,tensorflow.python.grappler
267267

268268
# List of classes names for which member attributes should not be checked
269269
# (useful for classes with attributes dynamically set). This supports can work

symbolic_pymc/tensorflow/graph.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import tensorflow as tf
2+
3+
from tensorflow.core.protobuf import config_pb2
4+
5+
from tensorflow.python.framework import ops
6+
from tensorflow.python.framework import importer
7+
from tensorflow.python.framework import meta_graph
8+
9+
from tensorflow.python.grappler import cluster
10+
from tensorflow.python.grappler import tf_optimizer
11+
12+
13+
try: # pragma: no cover
14+
gcluster = cluster.Cluster()
15+
except tf.errors.UnavailableError: # pragma: no cover
16+
pass
17+
18+
config = config_pb2.ConfigProto()
19+
20+
21+
def normalize_tf_graph(graph_output, new_graph=True, verbose=False):
22+
"""Use grappler to normalize a graph.
23+
24+
Arguments
25+
---------
26+
graph_output: Tensor
27+
A tensor we want to consider as "output" of a `FuncGraph`.
28+
29+
Returns
30+
-------
31+
The simplified graph.
32+
"""
33+
train_op = graph_output.graph.get_collection_ref(ops.GraphKeys.TRAIN_OP)
34+
train_op.clear()
35+
train_op.extend([graph_output])
36+
37+
metagraph = meta_graph.create_meta_graph_def(graph=graph_output.graph)
38+
39+
optimized_graphdef = tf_optimizer.OptimizeGraph(
40+
config, metagraph, verbose=verbose, cluster=gcluster
41+
)
42+
43+
output_name = graph_output.name
44+
45+
if new_graph:
46+
optimized_graph = ops.Graph()
47+
else: # pragma: no cover
48+
optimized_graph = ops.get_default_graph()
49+
del graph_output
50+
51+
with optimized_graph.as_default():
52+
importer.import_graph_def(optimized_graphdef, name="")
53+
54+
opt_graph_output = optimized_graph.get_tensor_by_name(output_name)
55+
56+
return opt_graph_output

symbolic_pymc/tensorflow/meta.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
from inspect import Parameter, Signature
88

9-
from collections import OrderedDict, Sequence
9+
from collections import OrderedDict
10+
from collections.abc import Sequence
1011

1112
from functools import partial
1213

tests/tensorflow/test_graph.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import numpy as np
2+
import tensorflow as tf
3+
4+
from symbolic_pymc.tensorflow.graph import normalize_tf_graph
5+
6+
from tests.tensorflow import run_in_graph_mode
7+
8+
9+
@run_in_graph_mode
10+
def test_normalize():
11+
12+
tf.config.optimizer.set_experimental_options(
13+
{
14+
"shape_optimizations": True,
15+
"arithmetic_optimzation": True,
16+
"function_optimization": True,
17+
"min_graph_nodes": 0,
18+
}
19+
)
20+
with tf.Graph().as_default() as norm_graph:
21+
a_tf = tf.compat.v1.placeholder("float")
22+
const_log_tf = 0.5 * np.log(2.0 * np.pi) + tf.math.log(a_tf)
23+
normal_const_log_tf = normalize_tf_graph(const_log_tf)
24+
25+
# Grappler appears to put log ops before const
26+
assert normal_const_log_tf.op.inputs[0].op.type == "Log"
27+
assert normal_const_log_tf.op.inputs[1].op.type == "Const"

tests/tensorflow/test_kanren.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
import tensorflow as tf
2+
3+
from tensorflow.python.framework.ops import disable_tensor_equality
4+
5+
from tensorflow_probability import distributions as tfd
6+
17
from unification import var, unify
28

39
from kanren import run, eq, lall
@@ -6,8 +12,12 @@
612

713
from symbolic_pymc.meta import enable_lvar_defaults
814
from symbolic_pymc.tensorflow.meta import mt
15+
from symbolic_pymc.tensorflow.graph import normalize_tf_graph
916

1017
from tests.tensorflow import run_in_graph_mode
18+
from tests.tensorflow.utils import mt_normal_log_prob
19+
20+
disable_tensor_equality()
1121

1222

1323
@run_in_graph_mode
@@ -47,3 +57,36 @@ def test_commutativity():
4757

4858
res = run(0, q, eq_comm(add_1_mt, add_pattern_mt))
4959
assert res[0] == add_1_mt.base_arguments[0]
60+
61+
62+
@run_in_graph_mode
63+
def test_commutativity_tfp():
64+
65+
with tf.Graph().as_default():
66+
mu_tf = tf.compat.v1.placeholder(tf.float32, name="mu", shape=tf.TensorShape([None]))
67+
tau_tf = tf.compat.v1.placeholder(tf.float32, name="tau", shape=tf.TensorShape([None]))
68+
69+
normal_tfp = tfd.normal.Normal(mu_tf, tau_tf)
70+
71+
value_tf = tf.compat.v1.placeholder(tf.float32, name="value", shape=tf.TensorShape([None]))
72+
73+
normal_log_lik = normal_tfp.log_prob(value_tf)
74+
75+
normal_log_lik_opt = normalize_tf_graph(normal_log_lik)
76+
77+
with enable_lvar_defaults("names", "node_attrs"):
78+
tfp_normal_pattern_mt = mt_normal_log_prob(var(), var(), var())
79+
80+
normal_log_lik_mt = mt(normal_log_lik)
81+
normal_log_lik_opt_mt = mt(normal_log_lik_opt)
82+
83+
# Our pattern is the form of an unnormalized TFP normal PDF.
84+
assert run(0, True, eq(normal_log_lik_mt, tfp_normal_pattern_mt)) == (True,)
85+
# Our pattern should *not* match the Grappler-optimized graph, because
86+
# Grappler will reorder terms (e.g. the log + constant
87+
# variance/normalization term)
88+
assert run(0, True, eq(normal_log_lik_opt_mt, tfp_normal_pattern_mt)) == ()
89+
90+
# XXX: `eq_comm` is, unfortunately, order sensitive! LHS should be ground.
91+
assert run(0, True, eq_comm(normal_log_lik_mt, tfp_normal_pattern_mt)) == (True,)
92+
assert run(0, True, eq_comm(normal_log_lik_opt_mt, tfp_normal_pattern_mt)) == (True,)

tests/tensorflow/utils.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1-
from collections.abc import Mapping
1+
import numpy as np
22

33
import tensorflow as tf
44
from tensorflow.python.framework import ops
55

6+
from collections.abc import Mapping
7+
8+
from symbolic_pymc.tensorflow.meta import mt
9+
610

711
def assert_ops_equal(a, b, compare_fn=lambda a, b: a.op.type == b.op.type):
812
if hasattr(a, "op") or hasattr(b, "op"):
@@ -32,3 +36,33 @@ def assert_ops_equal(a, b, compare_fn=lambda a, b: a.op.type == b.op.type):
3236

3337
for i_a, i_b in zip(a_inputs, b_inputs):
3438
assert_ops_equal(i_a, i_b)
39+
40+
41+
def tfp_normal_log_prob(x, loc, scale):
42+
"""Create a graph of the Grappler-canonicalized form of a TFP normal log-likelihood."""
43+
log_unnormalized = -0.5 * tf.math.squared_difference(x / scale, loc / scale)
44+
log_normalization = 0.5 * np.log(2.0 * np.pi) + tf.math.log(scale)
45+
return log_unnormalized - log_normalization
46+
47+
48+
def mt_normal_log_prob(x, loc, scale):
49+
"""Create a meta graph for Grappler-canonicalized standard or non-standard TFP normal log-likelihoods."""
50+
if loc == 0:
51+
log_unnormalized_mt = mt(np.array(-0.5, "float32"))
52+
log_unnormalized_mt *= mt.squareddifference(
53+
mt(np.array(0.0, "float32")),
54+
mt.realdiv(x, scale) if scale != 1 else mt.mul(np.array(1.0, "float32"), x),
55+
)
56+
else:
57+
log_unnormalized_mt = mt(np.array(-0.5, "float32"))
58+
log_unnormalized_mt *= mt.squareddifference(
59+
mt.realdiv(x, scale) if scale != 1 else mt.mul(np.array(1.0, "float32"), x),
60+
mt.realdiv(loc, scale) if scale != 1 else mt.mul(np.array(1.0, "float32"), loc),
61+
)
62+
63+
log_normalization_mt = mt((0.5 * np.log(2.0 * np.pi)).astype("float32"))
64+
65+
if scale != 1:
66+
log_normalization_mt = log_normalization_mt + mt.log(scale)
67+
68+
return log_unnormalized_mt - log_normalization_mt

0 commit comments

Comments
 (0)