Skip to content

Commit 5a7ec36

Browse files
Make TF meta objects follow their base APIs more closely
* More memoization/caching has been added along with a small amount of refactoring to the meta API. * Hashing for special cases (e.g. Numpy arrays and dicts) has been moved closer to their sources (e.g. constant tensor classes and TF NodeDef). * Superfluous TF Python meta types have been removed. * The name property for TF Tensors is now derived directly from the name field in the corresponding NodeDef. * Dispatch functions now have distinct names (good for debugging). * metatize for types/classes returns a meta type/class, not an instance of the type/class
1 parent 8433b9d commit 5a7ec36

File tree

10 files changed

+685
-426
lines changed

10 files changed

+685
-426
lines changed

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ pytest-cov>=2.6.1
44
pytest-html>=1.20.0
55
pylint>=2.3.1
66
black>=19.3b0
7+
ipython

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ unification>=0.2.2
1111
kanren @ git+https://github.com/pymc-devs/kanren.git@symbolic-pymc#egg=kanren-0.2.3
1212
toolz>=0.9.0
1313
sympy>=1.3
14+
cachetools

symbolic_pymc/meta.py

Lines changed: 90 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
import types
33
import reprlib
44

5-
import numpy as np
6-
75
from itertools import chain
86
from functools import partial
97
from collections import OrderedDict
@@ -15,11 +13,15 @@
1513

1614
from multipledispatch import dispatch
1715

16+
from cachetools import cached
17+
1818
meta_repr = reprlib.Repr()
1919
meta_repr.maxstring = 100
2020
meta_repr.maxother = 100
2121
meta_repr.print_obj = False
2222

23+
metatize_cache = {}
24+
2325

2426
def metatize(obj):
2527
"""Convert object to base type then meta object."""
@@ -28,38 +30,39 @@ def metatize(obj):
2830
return _metatize(obj)
2931

3032

31-
@dispatch((set, list, tuple))
33+
@dispatch((type(None), types.FunctionType, partial, str, dict))
3234
def _metatize(obj):
35+
return obj
36+
37+
38+
@_metatize.register((set, tuple))
39+
@cached(metatize_cache)
40+
def _metatize_set_tuple(obj):
3341
"""Convert elements of an iterable to meta objects."""
3442
return type(obj)([metatize(o) for o in obj])
3543

3644

37-
@dispatch(Iterator)
38-
def _metatize(obj):
39-
"""Convert elements of an iterator to meta objects."""
40-
return iter([metatize(o) for o in obj])
45+
@_metatize.register(list)
46+
def _metatize_list(obj):
47+
"""Convert elements of an iterable to meta objects."""
48+
return type(obj)([metatize(o) for o in obj])
4149

4250

43-
def _make_hashable(x):
44-
if isinstance(x, list):
45-
return tuple(x)
46-
elif isinstance(x, Mapping):
47-
return frozenset(x.items())
48-
elif isinstance(x, np.ndarray):
49-
return x.tostring()
50-
else:
51-
return x
51+
@_metatize.register(Iterator)
52+
@cached(metatize_cache)
53+
def _metatize_Iterator(obj):
54+
"""Convert elements of an iterator to meta objects."""
55+
return iter([metatize(o) for o in obj])
5256

5357

54-
def _meta_reify_iter(rands):
58+
def meta_reify_iter(rands):
5559
"""Recursively reify an iterable object and return a boolean indicating the presence of un-reifiable objects, if any."""
56-
# We want as many of the rands reified as possible,
5760
any_unreified = False
5861
reified_rands = []
59-
if isinstance(rands, Mapping):
60-
_rands = rands.items()
61-
else:
62-
_rands = rands
62+
63+
_rands = rands
64+
if isinstance(_rands, Mapping):
65+
_rands = _rands.items()
6366

6467
for s in _rands:
6568
if isinstance(s, MetaSymbol):
@@ -71,11 +74,11 @@ def _meta_reify_iter(rands):
7174
reified_rands.append(s)
7275
any_unreified |= True
7376
elif isinstance(s, (list, tuple)):
74-
_reified_rands, _any_unreified = _meta_reify_iter(s)
77+
_reified_rands, _any_unreified = meta_reify_iter(s)
7578
reified_rands.append(type(s)(_reified_rands))
7679
any_unreified |= _any_unreified
7780
else:
78-
reified_rands += [s]
81+
reified_rands.append(s)
7982

8083
return type(rands)(reified_rands), any_unreified
8184

@@ -153,8 +156,6 @@ def _cached_hash(self):
153156

154157
new_cls.__hash__ = _cached_hash
155158

156-
# TODO: Could register base classes.
157-
# E.g. cls.register(bases)
158159
return new_cls
159160

160161

@@ -164,7 +165,7 @@ class MetaSymbol(metaclass=MetaSymbolType):
164165
TODO: Should `MetaSymbol.obj` be an abstract property and a `weakref`?
165166
"""
166167

167-
__slots__ = ("_obj", "_hash")
168+
__slots__ = ("_obj", "_hash", "_rands")
168169

169170
@property
170171
@abc.abstractmethod
@@ -189,18 +190,40 @@ def is_meta(cls, obj):
189190
return isinstance(obj, MetaSymbol) or isvar(obj)
190191

191192
def __init__(self, obj=None):
193+
assert obj is None or isvar(obj) or isinstance(obj, self.base)
192194
self._obj = obj
193195

194196
def rands(self):
195-
"""Create a tuple of the meta object's operator parameters (i.e. "rands")."""
196-
return tuple(getattr(self, s) for s in getattr(self, "__all_props__", ()))
197+
"""Get a tuple of the meta object's operator parameters (i.e. "rands")."""
198+
if getattr(self, "_rands", None) is not None:
199+
return self._rands
200+
201+
self._rands = tuple(getattr(self, s) for s in getattr(self, "__all_props__", ()))
202+
203+
return self._rands
197204

198205
def reify(self):
199-
"""Create a concrete base object from this meta object (and its rands)."""
206+
"""Attempt to create a concrete base object from this meta object.
207+
208+
During the process, dependent objects will need to be reified, which
209+
may result in updates to the object(s) being reified.
210+
211+
For instance, if a meta tensor's parent operator is fully reifiable to
212+
a base object, then the meta tensor's dtype and shape may be fixed:
213+
e.g. a tensor corresponding to the output of a sum of two float64
214+
scalars is necessarily a float64 scalar.
215+
216+
This function will set any unspecified properties (e.g. dtype and shape
217+
values for the previous example), mutating the object in-place when
218+
possible. It will return a [refined/partially reified] meta object
219+
when it can't fully reify to a base object (in which case, it will
220+
return the base object) or when partial reification results in a meta
221+
object from a subclass.
222+
"""
200223
if self.obj is not None and not isinstance(self.obj, Var):
201224
return self.obj
202225
else:
203-
reified_rands, any_unreified = _meta_reify_iter(self.rands())
226+
reified_rands, any_unreified = meta_reify_iter(self.rands())
204227

205228
# If not all the rands reified, then create another meta
206229
# object--albeit one with potentially more non-`None` `obj` fields.
@@ -220,35 +243,20 @@ def __eq__(self, other):
220243
if not (type(self) == type(other)):
221244
return False
222245

223-
if not (self.base == other.base):
224-
return False
246+
assert self.base == other.base
225247

226-
a_slots = getattr(self, "__all_props__", None)
227-
if a_slots is not None:
228-
if not all(_check_eq(getattr(self, attr), getattr(other, attr)) for attr in a_slots):
229-
return False
230-
elif getattr(other, "__all_props__", None) is not None:
231-
# The other object has slots, but this one doesn't.
232-
return False
248+
if self.rands():
249+
return all(_check_eq(s, o) for s, o in zip(self.rands(), other.rands()))
233250
else:
234-
# Neither have slots, so best we can do is compare
235-
# base objects (if any).
236-
# If there aren't base objects, we say they're not equal.
237-
# (Maybe we should *require* base objects in this case
238-
# and raise an exception?)
239-
return getattr(self, "obj", None) == getattr(other, "obj", None) is not None
251+
return NotImplemented
240252

241-
return True
253+
return False
242254

243255
def __ne__(self, other):
244256
return not self.__eq__(other)
245257

246258
def __hash__(self):
247-
if getattr(self, "__props__", None) is not None:
248-
rands = tuple(_make_hashable(p) for p in self.rands())
249-
return hash(rands + (self.base,))
250-
else:
251-
return hash((self.base, self.obj))
259+
return hash((self.base, self.rands()))
252260

253261
def __str__(self):
254262
obj = getattr(self, "obj", None)
@@ -273,8 +281,8 @@ def _repr_pretty_(self, p, cycle):
273281
with p.group(2, f"{self.__class__.__name__}(", ")"):
274282
p.breakable()
275283
idx = None
276-
if hasattr(self, "__props__"):
277-
for idx, (name, item) in enumerate(zip(self.__props__, self.rands())):
284+
if hasattr(self, "__all_props__"):
285+
for idx, (name, item) in enumerate(zip(self.__all_props__, self.rands())):
278286
if idx:
279287
p.text(",")
280288
p.breakable()
@@ -292,8 +300,8 @@ def _repr_pretty_(self, p, cycle):
292300
p.pretty(obj)
293301

294302

295-
@dispatch((MetaSymbol, type(None), types.FunctionType, partial, str, dict))
296-
def _metatize(obj):
303+
@_metatize.register(MetaSymbol)
304+
def _metatize_MetaSymbol(obj):
297305
return obj
298306

299307

@@ -314,21 +322,26 @@ class MetaOp(MetaSymbol):
314322
def __init__(self, *args, **kwargs):
315323
super().__init__(*args, **kwargs)
316324

317-
@MetaSymbol.obj.setter
318-
def obj(self, x):
319-
if hasattr(self, "_obj"):
320-
raise ValueError("Cannot reset obj in an `Op`")
321-
object.__setattr__(self, "_obj", x)
322-
323325
@abc.abstractmethod
324326
def out_meta_types(self, inputs=None):
325327
"""Return the types of meta variables this `Op` is expected to produce given the inputs."""
326328
raise NotImplementedError()
327329

328330
@abc.abstractmethod
329-
def __call__(self, *args, ttype=None, index=None, **kwargs):
331+
def __call__(self, *args, **kwargs):
330332
raise NotImplementedError()
331333

334+
def __eq__(self, other):
335+
res = super().__eq__(other)
336+
337+
if res is NotImplemented:
338+
return getattr(self, "obj", None) == getattr(other, "obj", None) is not None
339+
340+
return res
341+
342+
def __hash__(self):
343+
return hash((self.base, self.obj))
344+
332345

333346
class MetaVariable(MetaSymbol):
334347
__slots__ = ()
@@ -369,14 +382,25 @@ def _find_meta_type(obj_type, meta_abs_type):
369382
# This object is a subclass of an existing meta class' base type,
370383
# but there is no implemented meta type for this subclass, so we
371384
# dynamically make one.
385+
386+
# FIXME, TODO: We should do something about `Op` constructor
387+
# arguments and properties.
388+
#
389+
# For instance, `tt.nlinalg.SVD` takes `full_matrices` and `compute_uv`
390+
# constructor arguments, but the dynamically constructed `TheanoMetaOp` type for
391+
# SVD is just the base `TheanoMetaOp.__init__`, which doesn't account for those.
392+
# To do this correctly, we would need to dynamically metatize the underlying
393+
# `Op`'s `__init__` and so on.
372394
new_type = type(f"Meta{obj_type.__name__}", (obj_cls,), {"base": obj_type})
373-
return new_type(obj_type)
395+
396+
return new_type
374397
else:
375398
cls = obj_cls
376399

377400

378-
@dispatch(type)
379-
def _metatize(obj_type):
401+
@_metatize.register(type)
402+
@cached(metatize_cache)
403+
def _metatize_type(obj_type):
380404
"""Return an existing meta type/class, or create a new one."""
381405
for meta_type in MetaSymbol.__subclasses__():
382406
obj_cls = _find_meta_type(obj_type, meta_type)

0 commit comments

Comments
 (0)