Skip to content

Commit 8433b9d

Browse files
Make better use of slots and cache hash values
1 parent 9743bcd commit 8433b9d

File tree

8 files changed

+250
-105
lines changed

8 files changed

+250
-105
lines changed

symbolic_pymc/meta.py

Lines changed: 72 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from itertools import chain
88
from functools import partial
9+
from collections import OrderedDict
910
from collections.abc import Iterator, Mapping
1011

1112
from unification import isvar, Var
@@ -84,33 +85,50 @@ def __new__(cls, name, bases, clsdict):
8485

8586
# We need to track the cumulative slots, because subclasses can define
8687
# their own--yet we'll need to track changes across all of them.
87-
all_slots = set(
88-
chain.from_iterable(s.__all_slots__ for s in bases if hasattr(s, "__all_slots__"))
88+
slots = clsdict.get("__slots__", ())
89+
all_slots = tuple(
90+
OrderedDict.fromkeys(
91+
chain(
92+
chain.from_iterable(
93+
tuple(s.__all_slots__) for s in bases if hasattr(s, "__all_slots__")
94+
),
95+
tuple(slots),
96+
)
97+
)
8998
)
90-
all_slots |= set(clsdict.get("__slots__", []))
99+
91100
clsdict["__all_slots__"] = all_slots
101+
clsdict["__all_props__"] = tuple(s for s in all_slots if not s.startswith("_"))
102+
clsdict["__volatile_slots__"] = tuple(s for s in all_slots if s.startswith("_"))
103+
clsdict["__props__"] = tuple(s for s in slots if not s.startswith("_"))
104+
105+
if clsdict["__volatile_slots__"]:
106+
107+
def __setattr__(self, attr, obj):
108+
"""If a slot value is changed, reset cached slots."""
92109

93-
def __setattr__(self, attr, obj):
94-
"""If a slot value is changed, discard any associated non-meta/base objects."""
95-
if attr == "obj":
96-
if isinstance(obj, MetaSymbol):
97-
raise ValueError("base object cannot be a meta object!")
98-
elif (
99-
getattr(self, "obj", None) is not None
100-
and not isinstance(self.obj, Var)
101-
and attr in getattr(self, "__all_slots__", {})
102-
and hasattr(self, attr)
103-
and getattr(self, attr) != obj
104-
):
105-
self.obj = None
110+
# Underscored-prefixed/volatile/stateful slots can be set
111+
# without affecting other such slots.
112+
if (
113+
attr not in self.__volatile_slots__
114+
# Are we trying to set a custom property?
115+
and attr in getattr(self, "__all_props__", ())
116+
# Is it a custom property that's already been set?
117+
and hasattr(self, attr)
118+
# Are we setting it to a new value?
119+
and getattr(self, attr) is not obj
120+
):
121+
for s in self.__volatile_slots__:
122+
object.__setattr__(self, s, None)
106123

107-
object.__setattr__(self, attr, obj)
124+
object.__setattr__(self, attr, obj)
108125

109-
clsdict["__setattr__"] = __setattr__
126+
clsdict["__setattr__"] = __setattr__
110127

111128
@classmethod
112129
def __metatize(cls, obj):
113-
return cls(*[getattr(obj, s) for s in getattr(cls, "__slots__", [])], obj=obj)
130+
"""Metatize using the `__all_props__` property."""
131+
return cls(*tuple(getattr(obj, s) for s in getattr(cls, "__all_props__", ())), obj=obj)
114132

115133
clsdict.setdefault("_metatize", __metatize)
116134

@@ -119,6 +137,22 @@ def __metatize(cls, obj):
119137
if isinstance(new_cls.base, type):
120138
_metatize.add((new_cls.base,), new_cls._metatize)
121139

140+
# Wrap the class implementation of `__hash__` with this value-caching
141+
# code.
142+
if "_hash" in clsdict["__volatile_slots__"]:
143+
_orig_hash = new_cls.__hash__
144+
new_cls._orig_hash = _orig_hash
145+
146+
def _cached_hash(self):
147+
if getattr(self, "_hash", None) is not None:
148+
return self._hash
149+
150+
object.__setattr__(self, "_hash", _orig_hash(self))
151+
152+
return self._hash
153+
154+
new_cls.__hash__ = _cached_hash
155+
122156
# TODO: Could register base classes.
123157
# E.g. cls.register(bases)
124158
return new_cls
@@ -130,12 +164,18 @@ class MetaSymbol(metaclass=MetaSymbolType):
130164
TODO: Should `MetaSymbol.obj` be an abstract property and a `weakref`?
131165
"""
132166

167+
__slots__ = ("_obj", "_hash")
168+
133169
@property
134170
@abc.abstractmethod
135171
def base(self):
136172
"""Return the underlying (e.g. a theano/tensorflow) base type/rator for this meta object."""
137173
raise NotImplementedError()
138174

175+
@property
176+
def obj(self):
177+
return object.__getattribute__(self, "_obj")
178+
139179
@classmethod
140180
def base_classes(cls, mro_order=True):
141181
res = tuple(c.base for c in cls.__subclasses__())
@@ -149,11 +189,11 @@ def is_meta(cls, obj):
149189
return isinstance(obj, MetaSymbol) or isvar(obj)
150190

151191
def __init__(self, obj=None):
152-
self.obj = obj
192+
self._obj = obj
153193

154194
def rands(self):
155195
"""Create a tuple of the meta object's operator parameters (i.e. "rands")."""
156-
return tuple(getattr(self, s) for s in getattr(self, "__slots__", []))
196+
return tuple(getattr(self, s) for s in getattr(self, "__all_props__", ()))
157197

158198
def reify(self):
159199
"""Create a concrete base object from this meta object (and its rands)."""
@@ -168,7 +208,7 @@ def reify(self):
168208
res = rator(*reified_rands)
169209

170210
if not any_unreified:
171-
self.obj = res
211+
self._obj = res
172212

173213
return res
174214

@@ -183,11 +223,11 @@ def __eq__(self, other):
183223
if not (self.base == other.base):
184224
return False
185225

186-
a_slots = getattr(self, "__slots__", None)
226+
a_slots = getattr(self, "__all_props__", None)
187227
if a_slots is not None:
188228
if not all(_check_eq(getattr(self, attr), getattr(other, attr)) for attr in a_slots):
189229
return False
190-
elif getattr(other, "__slots__", None) is not None:
230+
elif getattr(other, "__all_props__", None) is not None:
191231
# The other object has slots, but this one doesn't.
192232
return False
193233
else:
@@ -204,7 +244,7 @@ def __ne__(self, other):
204244
return not self.__eq__(other)
205245

206246
def __hash__(self):
207-
if getattr(self, "__slots__", None) is not None:
247+
if getattr(self, "__props__", None) is not None:
208248
rands = tuple(_make_hashable(p) for p in self.rands())
209249
return hash(rands + (self.base,))
210250
else:
@@ -233,8 +273,8 @@ def _repr_pretty_(self, p, cycle):
233273
with p.group(2, f"{self.__class__.__name__}(", ")"):
234274
p.breakable()
235275
idx = None
236-
if hasattr(self, "__slots__"):
237-
for idx, (name, item) in enumerate(zip(self.__slots__, self.rands())):
276+
if hasattr(self, "__props__"):
277+
for idx, (name, item) in enumerate(zip(self.__props__, self.rands())):
238278
if idx:
239279
p.text(",")
240280
p.breakable()
@@ -269,14 +309,12 @@ class MetaOp(MetaSymbol):
269309
implementation.
270310
"""
271311

312+
__slots__ = ()
313+
272314
def __init__(self, *args, **kwargs):
273315
super().__init__(*args, **kwargs)
274316

275-
@property
276-
def obj(self):
277-
return object.__getattribute__(self, "_obj")
278-
279-
@obj.setter
317+
@MetaSymbol.obj.setter
280318
def obj(self, x):
281319
if hasattr(self, "_obj"):
282320
raise ValueError("Cannot reset obj in an `Op`")
@@ -293,6 +331,8 @@ def __call__(self, *args, ttype=None, index=None, **kwargs):
293331

294332

295333
class MetaVariable(MetaSymbol):
334+
__slots__ = ()
335+
296336
@property
297337
@abc.abstractmethod
298338
def operator(self):

0 commit comments

Comments
 (0)