|
| 1 | +import tensorflow as tf |
| 2 | + |
| 3 | +from tensorflow.python.framework.ops import disable_tensor_equality |
| 4 | + |
| 5 | +from tensorflow_probability import distributions as tfd |
| 6 | + |
1 | 7 | from unification import var, unify |
2 | 8 |
|
3 | 9 | from kanren import run, eq, lall |
|
6 | 12 |
|
7 | 13 | from symbolic_pymc.meta import enable_lvar_defaults |
8 | 14 | from symbolic_pymc.tensorflow.meta import mt |
| 15 | +from symbolic_pymc.tensorflow.graph import normalize_tf_graph |
9 | 16 |
|
10 | 17 | from tests.tensorflow import run_in_graph_mode |
| 18 | +from tests.tensorflow.utils import mt_normal_log_prob |
| 19 | + |
| 20 | +disable_tensor_equality() |
11 | 21 |
|
12 | 22 |
|
13 | 23 | @run_in_graph_mode |
@@ -47,3 +57,36 @@ def test_commutativity(): |
47 | 57 |
|
48 | 58 | res = run(0, q, eq_comm(add_1_mt, add_pattern_mt)) |
49 | 59 | assert res[0] == add_1_mt.base_arguments[0] |
| 60 | + |
| 61 | + |
@run_in_graph_mode
def test_commutativity_tfp():
    """Unify a TFP normal log-likelihood graph against a meta-graph pattern.

    Exercises both plain `eq` and commutative `eq_comm` unification, before
    and after Grappler normalization of the TF graph.
    """

    with tf.Graph().as_default():
        # Float32 vector placeholders for the distribution parameters and
        # the observed value.
        def _vec_placeholder(name):
            return tf.compat.v1.placeholder(
                tf.float32, name=name, shape=tf.TensorShape([None])
            )

        loc_ph = _vec_placeholder("mu")
        scale_ph = _vec_placeholder("tau")
        obs_ph = _vec_placeholder("value")

        log_lik_tf = tfd.normal.Normal(loc_ph, scale_ph).log_prob(obs_ph)
        # Grappler-normalized version of the same log-likelihood graph.
        log_lik_opt_tf = normalize_tf_graph(log_lik_tf)

        # Build a meta-graph pattern for a TFP normal log-PDF, with logic
        # variables standing in for op names and node attributes.
        with enable_lvar_defaults("names", "node_attrs"):
            pattern_mt = mt_normal_log_prob(var(), var(), var())

        log_lik_mt = mt(log_lik_tf)
        log_lik_opt_mt = mt(log_lik_opt_tf)

        # Our pattern is the form of an unnormalized TFP normal PDF.
        assert run(0, True, eq(log_lik_mt, pattern_mt)) == (True,)
        # Our pattern should *not* match the Grappler-optimized graph, because
        # Grappler will reorder terms (e.g. the log + constant
        # variance/normalization term)
        assert run(0, True, eq(log_lik_opt_mt, pattern_mt)) == ()

        # XXX: `eq_comm` is, unfortunately, order sensitive!  LHS should be ground.
        assert run(0, True, eq_comm(log_lik_mt, pattern_mt)) == (True,)
        assert run(0, True, eq_comm(log_lik_opt_mt, pattern_mt)) == (True,)
0 commit comments