66
77from itertools import chain
88from functools import partial
9+ from collections import OrderedDict
910from collections .abc import Iterator , Mapping
1011
1112from unification import isvar , Var
@@ -84,33 +85,50 @@ def __new__(cls, name, bases, clsdict):
8485
8586 # We need to track the cumulative slots, because subclasses can define
8687 # their own--yet we'll need to track changes across all of them.
87- all_slots = set (
88- chain .from_iterable (s .__all_slots__ for s in bases if hasattr (s , "__all_slots__" ))
88+ slots = clsdict .get ("__slots__" , ())
89+ all_slots = tuple (
90+ OrderedDict .fromkeys (
91+ chain (
92+ chain .from_iterable (
93+ tuple (s .__all_slots__ ) for s in bases if hasattr (s , "__all_slots__" )
94+ ),
95+ tuple (slots ),
96+ )
97+ )
8998 )
90- all_slots |= set ( clsdict . get ( "__slots__" , []))
99+
91100 clsdict ["__all_slots__" ] = all_slots
101+ clsdict ["__all_props__" ] = tuple (s for s in all_slots if not s .startswith ("_" ))
102+ clsdict ["__volatile_slots__" ] = tuple (s for s in all_slots if s .startswith ("_" ))
103+ clsdict ["__props__" ] = tuple (s for s in slots if not s .startswith ("_" ))
104+
105+ if clsdict ["__volatile_slots__" ]:
106+
107+ def __setattr__ (self , attr , obj ):
108+ """If a slot value is changed, reset cached slots."""
92109
93- def __setattr__ ( self , attr , obj ):
94- """If a slot value is changed, discard any associated non-meta/base objects."""
95- if attr == "obj" :
96- if isinstance ( obj , MetaSymbol ):
97- raise ValueError ( "base object cannot be a meta object!" )
98- elif (
99- getattr ( self , "obj" , None ) is not None
100- and not isinstance (self . obj , Var )
101- and attr in getattr ( self , "__all_slots__" , {})
102- and hasattr (self , attr )
103- and getattr ( self , attr ) != obj
104- ) :
105- self . obj = None
110+ # Underscored-prefixed/volatile/stateful slots can be set
111+ # without affecting other such slots.
112+ if (
113+ attr not in self . __volatile_slots__
114+ # Are we trying to set a custom property?
115+ and attr in getattr ( self , "__all_props__" , ())
116+ # Is it a custom property that's already been set?
117+ and hasattr (self , attr )
118+ # Are we setting it to a new value?
119+ and getattr (self , attr ) is not obj
120+ ):
121+ for s in self . __volatile_slots__ :
122+ object . __setattr__ ( self , s , None )
106123
107- object .__setattr__ (self , attr , obj )
124+ object .__setattr__ (self , attr , obj )
108125
109- clsdict ["__setattr__" ] = __setattr__
126+ clsdict ["__setattr__" ] = __setattr__
110127
111128 @classmethod
112129 def __metatize (cls , obj ):
113- return cls (* [getattr (obj , s ) for s in getattr (cls , "__slots__" , [])], obj = obj )
130+ """Metatize using the `__all_props__` property."""
131+ return cls (* tuple (getattr (obj , s ) for s in getattr (cls , "__all_props__" , ())), obj = obj )
114132
115133 clsdict .setdefault ("_metatize" , __metatize )
116134
@@ -119,6 +137,22 @@ def __metatize(cls, obj):
119137 if isinstance (new_cls .base , type ):
120138 _metatize .add ((new_cls .base ,), new_cls ._metatize )
121139
140+ # Wrap the class implementation of `__hash__` with this value-caching
141+ # code.
142+ if "_hash" in clsdict ["__volatile_slots__" ]:
143+ _orig_hash = new_cls .__hash__
144+ new_cls ._orig_hash = _orig_hash
145+
146+ def _cached_hash (self ):
147+ if getattr (self , "_hash" , None ) is not None :
148+ return self ._hash
149+
150+ object .__setattr__ (self , "_hash" , _orig_hash (self ))
151+
152+ return self ._hash
153+
154+ new_cls .__hash__ = _cached_hash
155+
122156 # TODO: Could register base classes.
123157 # E.g. cls.register(bases)
124158 return new_cls
@@ -130,12 +164,18 @@ class MetaSymbol(metaclass=MetaSymbolType):
130164 TODO: Should `MetaSymbol.obj` be an abstract property and a `weakref`?
131165 """
132166
167+ __slots__ = ("_obj" , "_hash" )
168+
133169 @property
134170 @abc .abstractmethod
135171 def base (self ):
136172 """Return the underlying (e.g. a theano/tensorflow) base type/rator for this meta object."""
137173 raise NotImplementedError ()
138174
175+ @property
176+ def obj (self ):
177+ return object .__getattribute__ (self , "_obj" )
178+
139179 @classmethod
140180 def base_classes (cls , mro_order = True ):
141181 res = tuple (c .base for c in cls .__subclasses__ ())
@@ -149,11 +189,11 @@ def is_meta(cls, obj):
149189 return isinstance (obj , MetaSymbol ) or isvar (obj )
150190
151191 def __init__ (self , obj = None ):
152- self .obj = obj
192+ self ._obj = obj
153193
154194 def rands (self ):
155195 """Create a tuple of the meta object's operator parameters (i.e. "rands")."""
156- return tuple (getattr (self , s ) for s in getattr (self , "__slots__ " , [] ))
196+ return tuple (getattr (self , s ) for s in getattr (self , "__all_props__ " , () ))
157197
158198 def reify (self ):
159199 """Create a concrete base object from this meta object (and its rands)."""
@@ -168,7 +208,7 @@ def reify(self):
168208 res = rator (* reified_rands )
169209
170210 if not any_unreified :
171- self .obj = res
211+ self ._obj = res
172212
173213 return res
174214
@@ -183,11 +223,11 @@ def __eq__(self, other):
183223 if not (self .base == other .base ):
184224 return False
185225
186- a_slots = getattr (self , "__slots__ " , None )
226+ a_slots = getattr (self , "__all_props__ " , None )
187227 if a_slots is not None :
188228 if not all (_check_eq (getattr (self , attr ), getattr (other , attr )) for attr in a_slots ):
189229 return False
190- elif getattr (other , "__slots__ " , None ) is not None :
230+ elif getattr (other , "__all_props__ " , None ) is not None :
191231 # The other object has slots, but this one doesn't.
192232 return False
193233 else :
@@ -204,7 +244,7 @@ def __ne__(self, other):
204244 return not self .__eq__ (other )
205245
206246 def __hash__ (self ):
207- if getattr (self , "__slots__ " , None ) is not None :
247+ if getattr (self , "__props__ " , None ) is not None :
208248 rands = tuple (_make_hashable (p ) for p in self .rands ())
209249 return hash (rands + (self .base ,))
210250 else :
@@ -233,8 +273,8 @@ def _repr_pretty_(self, p, cycle):
233273 with p .group (2 , f"{ self .__class__ .__name__ } (" , ")" ):
234274 p .breakable ()
235275 idx = None
236- if hasattr (self , "__slots__ " ):
237- for idx , (name , item ) in enumerate (zip (self .__slots__ , self .rands ())):
276+ if hasattr (self , "__props__ " ):
277+ for idx , (name , item ) in enumerate (zip (self .__props__ , self .rands ())):
238278 if idx :
239279 p .text ("," )
240280 p .breakable ()
@@ -269,14 +309,12 @@ class MetaOp(MetaSymbol):
269309 implementation.
270310 """
271311
312+ __slots__ = ()
313+
272314 def __init__ (self , * args , ** kwargs ):
273315 super ().__init__ (* args , ** kwargs )
274316
275- @property
276- def obj (self ):
277- return object .__getattribute__ (self , "_obj" )
278-
279- @obj .setter
317+ @MetaSymbol .obj .setter
280318 def obj (self , x ):
281319 if hasattr (self , "_obj" ):
282320 raise ValueError ("Cannot reset obj in an `Op`" )
@@ -293,6 +331,8 @@ def __call__(self, *args, ttype=None, index=None, **kwargs):
293331
294332
295333class MetaVariable (MetaSymbol ):
334+ __slots__ = ()
335+
296336 @property
297337 @abc .abstractmethod
298338 def operator (self ):
0 commit comments