Skip to content

Commit e013c0c

Browse files
Set up kanren's associative and commutative unify for TensorFlow
1 parent 0378f78 commit e013c0c

File tree

5 files changed

+59
-8
lines changed

5 files changed

+59
-8
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from unification import var
2+
3+
from kanren.facts import fact
4+
from kanren.assoccomm import commutative, associative
5+
6+
from ...tensorflow.meta import mt, TFlowMetaOperator
7+
8+
9+
# TODO: We could use `mt.*.op_def.obj.is_commutative` to capture
10+
# more/all cases.
11+
fact(commutative, TFlowMetaOperator(mt.AddV2.op_def, var()))
12+
fact(commutative, TFlowMetaOperator(mt.AddN.op_def, var()))
13+
fact(commutative, TFlowMetaOperator(mt.Mul.op_def, var()))
14+
15+
fact(associative, TFlowMetaOperator(mt.AddN.op_def, var()))
16+
fact(associative, TFlowMetaOperator(mt.AddV2.op_def, var()))

symbolic_pymc/relations/theano/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,22 @@
44

55
from kanren import eq
66
from kanren.core import lall
7+
from kanren.facts import fact
8+
from kanren.assoccomm import commutative, associative
9+
710

811
from .linalg import buildo
912
from ..graph import graph_applyo, seq_apply_anyo
1013
from ...etuple import etuplize, etuple
1114
from ...theano.meta import mt
1215

1316

17+
fact(commutative, mt.add)
18+
fact(commutative, mt.mul)
19+
fact(associative, mt.add)
20+
fact(associative, mt.mul)
21+
22+
1423
def tt_graph_applyo(relation, a, b, preprocess_graph=partial(etuplize, shallow=True)):
1524
"""Construct a `graph_applyo` goal that judiciously expands a Theano meta graph.
1625

symbolic_pymc/tensorflow/meta.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,9 @@ def _metatize_tf_eager(obj):
210210

211211
meta._metatize.add((TFlowMetaOpDef.base,), TFlowMetaOpDef._metatize)
212212

213+
# Apply TF-specific `kanren` settings
214+
from ..relations import tensorflow
215+
213216
return meta._metatize
214217

215218

symbolic_pymc/theano/meta.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,6 @@
99

1010
from unification import var, isvar, Var
1111

12-
from kanren.facts import fact
13-
from kanren.assoccomm import commutative, associative
14-
1512
from .ops import RandomVariable
1613
from ..meta import (
1714
MetaSymbol,
@@ -51,6 +48,9 @@ def load_dispatcher():
5148
for new_cls in TheanoMetaSymbol.base_subclasses():
5249
meta._metatize.add((new_cls.base,), new_cls._metatize)
5350

51+
# Apply TF-specific `kanren` settings
52+
from ..relations import theano
53+
5454
return meta._metatize
5555

5656

@@ -641,8 +641,3 @@ def mt_diag(v, k=0):
641641

642642

643643
mt.diag = mt_diag
644-
645-
fact(commutative, mt.add)
646-
fact(commutative, mt.mul)
647-
fact(associative, mt.add)
648-
fact(associative, mt.mul)

tests/tensorflow/test_kanren.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from unification import var
2+
3+
from kanren import run
4+
from kanren.assoccomm import eq_comm, commutative
5+
6+
from symbolic_pymc.meta import enable_lvar_defaults
7+
from symbolic_pymc.tensorflow.meta import mt
8+
9+
from tests.tensorflow import run_in_graph_mode
10+
11+
12+
@run_in_graph_mode
13+
def test_commutativity():
14+
with enable_lvar_defaults('names'):
15+
add_1_mt = mt(1) + mt(2)
16+
add_2_mt = mt(2) + mt(1)
17+
18+
res = run(0, var('q'), commutative(add_1_mt.base_operator))
19+
assert res is not False
20+
21+
res = run(0, var('q'), eq_comm(add_1_mt, add_2_mt))
22+
assert res is not False
23+
24+
with enable_lvar_defaults('names'):
25+
add_pattern_mt = mt(2) + var('q')
26+
27+
res = run(0, var('q'), eq_comm(add_1_mt, add_pattern_mt))
28+
assert res[0] == add_1_mt.base_arguments[0]

0 commit comments

Comments
 (0)