Skip to content

Commit e644c1f

Browse files
Merge pull request #88 from brandonwillard/add-meta-operator
Refactor etuplization and introduce TensorFlow Operator meta class
2 parents 3aa5923 + e013c0c commit e644c1f

File tree

16 files changed

+716
-368
lines changed

16 files changed

+716
-368
lines changed

symbolic_pymc/etuple.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from collections import Sequence
77

8-
from cons.core import ConsPair
8+
from cons.core import ConsPair, ConsNull
99

1010
from multipledispatch import dispatch
1111

@@ -112,6 +112,8 @@ def eval_obj(self):
112112
evaled_kwargs = arg_grps.get(True, [])
113113

114114
op = self._tuple[0]
115+
op = getattr(op, "eval_obj", op)
116+
115117
try:
116118
op_sig = inspect.signature(op)
117119
except ValueError:
@@ -234,6 +236,8 @@ def etuplize(x, shallow=False, return_bad_args=False):
234236
"""
235237
if isinstance(x, ExpressionTuple):
236238
return x
239+
elif x is not None and isinstance(x, (ConsNull, ConsPair)):
240+
return etuple(*x)
237241

238242
try:
239243
# This can throw an `IndexError` if `x` is an empty
@@ -242,24 +246,22 @@ def etuplize(x, shallow=False, return_bad_args=False):
242246
args = arguments(x)
243247
except (IndexError, NotImplementedError):
244248
op = None
245-
args = x
249+
args = None
246250

247-
if not isinstance(args, ConsPair):
251+
if not callable(op) or not isinstance(args, (ConsNull, ConsPair)):
248252
if return_bad_args:
249253
return x
250254
else:
251255
raise TypeError(f"x is neither a non-str Sequence nor term: {type(x)}")
252256

253-
# Not everything in a list/tuple should be considered an expression.
254-
if not callable(op):
255-
return etuple(*x)
256-
257257
if shallow:
258+
et_op = op
258259
et_args = args
259260
else:
261+
et_op = etuplize(op, return_bad_args=True)
260262
et_args = tuple(etuplize(a, return_bad_args=True) for a in args)
261263

262-
res = etuple(op, *et_args, eval_obj=x)
264+
res = etuple(et_op, *et_args, eval_obj=x)
263265
return res
264266

265267

symbolic_pymc/meta.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ def obj(self):
239239

240240
@classmethod
241241
def base_subclasses(cls):
242+
"""Return all meta symbols with valid, implemented bases (i.e. base property is a type object)."""
242243
for subclass in cls.__subclasses__():
243244
yield from subclass.base_subclasses()
244245
if isinstance(subclass.base, type):
@@ -252,12 +253,18 @@ def __init__(self, obj=None):
252253
assert obj is None or isvar(obj) or isinstance(obj, self.base)
253254
self._obj = obj
254255

256+
@property
255257
def rands(self):
256258
"""Get a tuple of the meta object's operator parameters (i.e. "rands")."""
257259
if getattr(self, "_rands", None) is not None:
258260
return self._rands
259261

260-
self._rands = tuple(getattr(self, s) for s in getattr(self, "__all_props__", ()))
262+
all_props = getattr(self, "__all_props__", None)
263+
264+
if all_props:
265+
self._rands = tuple(getattr(self, s) for s in all_props)
266+
else:
267+
raise NotImplementedError()
261268

262269
return self._rands
263270

@@ -282,7 +289,7 @@ def reify(self):
282289
if self.obj is not None and not isinstance(self.obj, Var):
283290
return self.obj
284291
else:
285-
reified_rands, any_unreified = meta_reify_iter(self.rands())
292+
reified_rands, any_unreified = meta_reify_iter(self.rands)
286293

287294
# If not all the rands reified, then create another meta
288295
# object--albeit one with potentially more non-`None` `obj` fields.
@@ -304,8 +311,13 @@ def __eq__(self, other):
304311

305312
assert self.base == other.base
306313

307-
if self.rands():
308-
return all(s == o for s, o in zip(self.rands(), other.rands()))
314+
try:
315+
rands = self.rands
316+
except NotImplementedError:
317+
rands = None
318+
319+
if rands:
320+
return all(s == o for s, o in zip(self.rands, other.rands))
309321
else:
310322
return NotImplemented
311323

@@ -315,16 +327,19 @@ def __ne__(self, other):
315327
return not self.__eq__(other)
316328

317329
def __hash__(self):
318-
return hash((self.base, self.rands()))
330+
return hash((self.base, self.rands))
319331

320332
def __str__(self):
321333
return self.__repr__(show_obj=False, _repr=str)
322334

323335
def __repr__(self, show_obj=True, _repr=meta_repr.repr):
324-
rands = self.rands()
336+
try:
337+
rands = self.rands
338+
except NotImplementedError:
339+
rands = None
325340

326341
if rands:
327-
args = _repr(self.rands())[1:-1]
342+
args = _repr(self.rands)[1:-1]
328343
else:
329344
args = ""
330345

@@ -344,8 +359,13 @@ def _repr_pretty_(self, p, cycle):
344359
with p.group(2, f"{self.__class__.__name__}(", ")"):
345360
p.breakable(sep="")
346361
idx = None
347-
if hasattr(self, "__all_props__"):
348-
for idx, (name, item) in enumerate(zip(self.__all_props__, self.rands())):
362+
try:
363+
rands = self.rands
364+
except NotImplementedError:
365+
rands = None
366+
367+
if rands:
368+
for idx, (name, item) in enumerate(zip(self.__all_props__, rands)):
349369
if idx:
350370
p.text(",")
351371
p.breakable()
@@ -386,7 +406,7 @@ def __init__(self, *args, **kwargs):
386406
super().__init__(*args, **kwargs)
387407

388408
@abc.abstractmethod
389-
def out_meta_types(self, inputs=None):
409+
def output_meta_types(self, inputs=None):
390410
"""Return the types of meta variables this `Op` is expected to produce given the inputs."""
391411
raise NotImplementedError()
392412

@@ -411,18 +431,22 @@ class MetaVariable(MetaSymbol):
411431

412432
@property
413433
@abc.abstractmethod
414-
def operator(self):
415-
"""Return a meta object representing an operator, if any, capable of producing this variable.
434+
def base_operator(self):
435+
"""Return a meta object representing a base-level operator.
416436
417437
It should be callable with all inputs necessary to reproduce this
418-
tensor given by `MetaVariable.inputs`.
438+
tensor given by `self.base_arguments`.
419439
"""
420440
raise NotImplementedError()
421441

422442
@property
423443
@abc.abstractmethod
424-
def inputs(self):
425-
"""Return the inputs necessary for `MetaVariable.operator` to produced this variable, if any."""
444+
def base_arguments(self):
445+
"""Return the base-level arguments.
446+
447+
These arguments used in conjunction with the callable
448+
`self.base_operator` should re-produce this variable.
449+
"""
426450
raise NotImplementedError()
427451

428452

@@ -431,7 +455,12 @@ def _find_meta_type(obj_type, meta_abs_type):
431455
obj_cls = None
432456
while True:
433457
try:
434-
obj_cls = next(filter(lambda t: issubclass(obj_type, t.base), cls.__subclasses__()))
458+
obj_cls = next(
459+
filter(
460+
lambda t: isinstance(t.base, type) and issubclass(obj_type, t.base),
461+
cls.__subclasses__(),
462+
)
463+
)
435464
except StopIteration:
436465
# The current class is the best fit.
437466
if cls.base == obj_type:

symbolic_pymc/relations/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from kanren.term import term, operator, arguments
88
from kanren.goals import conso
99

10-
from ..etuple import etuplize, ExpressionTuple
10+
from ..etuple import etuplize
1111

1212

1313
# Hierarchical models that we recognize.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from unification import var
2+
3+
from kanren.facts import fact
4+
from kanren.assoccomm import commutative, associative
5+
6+
from ...tensorflow.meta import mt, TFlowMetaOperator
7+
8+
9+
# TODO: We could use `mt.*.op_def.obj.is_commutative` to capture
10+
# more/all cases.
11+
fact(commutative, TFlowMetaOperator(mt.AddV2.op_def, var()))
12+
fact(commutative, TFlowMetaOperator(mt.AddN.op_def, var()))
13+
fact(commutative, TFlowMetaOperator(mt.Mul.op_def, var()))
14+
15+
fact(associative, TFlowMetaOperator(mt.AddN.op_def, var()))
16+
fact(associative, TFlowMetaOperator(mt.AddV2.op_def, var()))

symbolic_pymc/relations/theano/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,22 @@
44

55
from kanren import eq
66
from kanren.core import lall
7+
from kanren.facts import fact
8+
from kanren.assoccomm import commutative, associative
9+
710

811
from .linalg import buildo
912
from ..graph import graph_applyo, seq_apply_anyo
1013
from ...etuple import etuplize, etuple
1114
from ...theano.meta import mt
1215

1316

17+
fact(commutative, mt.add)
18+
fact(commutative, mt.mul)
19+
fact(associative, mt.add)
20+
fact(associative, mt.mul)
21+
22+
1423
def tt_graph_applyo(relation, a, b, preprocess_graph=partial(etuplize, shallow=True)):
1524
"""Construct a `graph_applyo` goal that judiciously expands a Theano meta graph.
1625

0 commit comments

Comments
 (0)