22import types
33import reprlib
44
5- import numpy as np
6-
75from itertools import chain
86from functools import partial
97from collections import OrderedDict
1513
1614from multipledispatch import dispatch
1715
16+ from cachetools import cached
17+
1818meta_repr = reprlib .Repr ()
1919meta_repr .maxstring = 100
2020meta_repr .maxother = 100
2121meta_repr .print_obj = False
2222
23+ metatize_cache = {}
24+
2325
2426def metatize (obj ):
2527 """Convert object to base type then meta object."""
@@ -28,38 +30,39 @@ def metatize(obj):
2830 return _metatize (obj )
2931
3032
31- @dispatch ((set , list , tuple ))
33+ @dispatch ((type ( None ), types . FunctionType , partial , str , dict ))
3234def _metatize (obj ):
35+ return obj
36+
37+
38+ @_metatize .register ((set , tuple ))
39+ @cached (metatize_cache )
40+ def _metatize_set_tuple (obj ):
3341 """Convert elements of an iterable to meta objects."""
3442 return type (obj )([metatize (o ) for o in obj ])
3543
3644
37- @dispatch ( Iterator )
38- def _metatize (obj ):
39- """Convert elements of an iterator to meta objects."""
40- return iter ([metatize (o ) for o in obj ])
45+ @_metatize . register ( list )
46+ def _metatize_list (obj ):
47+ """Convert elements of an iterable to meta objects."""
48+ return type ( obj ) ([metatize (o ) for o in obj ])
4149
4250
43- def _make_hashable (x ):
44- if isinstance (x , list ):
45- return tuple (x )
46- elif isinstance (x , Mapping ):
47- return frozenset (x .items ())
48- elif isinstance (x , np .ndarray ):
49- return x .tostring ()
50- else :
51- return x
51+ @_metatize .register (Iterator )
52+ @cached (metatize_cache )
53+ def _metatize_Iterator (obj ):
54+ """Convert elements of an iterator to meta objects."""
55+ return iter ([metatize (o ) for o in obj ])
5256
5357
54- def _meta_reify_iter (rands ):
58+ def meta_reify_iter (rands ):
5559 """Recursively reify an iterable object and return a boolean indicating the presence of un-reifiable objects, if any."""
56- # We want as many of the rands reified as possible,
5760 any_unreified = False
5861 reified_rands = []
59- if isinstance ( rands , Mapping ):
60- _rands = rands . items ()
61- else :
62- _rands = rands
62+
63+ _rands = rands
64+ if isinstance ( _rands , Mapping ) :
65+ _rands = _rands . items ()
6366
6467 for s in _rands :
6568 if isinstance (s , MetaSymbol ):
@@ -71,11 +74,11 @@ def _meta_reify_iter(rands):
7174 reified_rands .append (s )
7275 any_unreified |= True
7376 elif isinstance (s , (list , tuple )):
74- _reified_rands , _any_unreified = _meta_reify_iter (s )
77+ _reified_rands , _any_unreified = meta_reify_iter (s )
7578 reified_rands .append (type (s )(_reified_rands ))
7679 any_unreified |= _any_unreified
7780 else :
78- reified_rands += [ s ]
81+ reified_rands . append ( s )
7982
8083 return type (rands )(reified_rands ), any_unreified
8184
@@ -153,8 +156,6 @@ def _cached_hash(self):
153156
154157 new_cls .__hash__ = _cached_hash
155158
156- # TODO: Could register base classes.
157- # E.g. cls.register(bases)
158159 return new_cls
159160
160161
@@ -164,7 +165,7 @@ class MetaSymbol(metaclass=MetaSymbolType):
164165 TODO: Should `MetaSymbol.obj` be an abstract property and a `weakref`?
165166 """
166167
167- __slots__ = ("_obj" , "_hash" )
168+ __slots__ = ("_obj" , "_hash" , "_rands" )
168169
169170 @property
170171 @abc .abstractmethod
@@ -189,18 +190,40 @@ def is_meta(cls, obj):
189190 return isinstance (obj , MetaSymbol ) or isvar (obj )
190191
191192 def __init__ (self , obj = None ):
193+ assert obj is None or isvar (obj ) or isinstance (obj , self .base )
192194 self ._obj = obj
193195
194196 def rands (self ):
195- """Create a tuple of the meta object's operator parameters (i.e. "rands")."""
196- return tuple (getattr (self , s ) for s in getattr (self , "__all_props__" , ()))
197+ """Get a tuple of the meta object's operator parameters (i.e. "rands")."""
198+ if getattr (self , "_rands" , None ) is not None :
199+ return self ._rands
200+
201+ self ._rands = tuple (getattr (self , s ) for s in getattr (self , "__all_props__" , ()))
202+
203+ return self ._rands
197204
198205 def reify (self ):
199- """Create a concrete base object from this meta object (and its rands)."""
206+ """Attempt to create a concrete base object from this meta object.
207+
208+ During the process, dependent objects will need to be reified, which
209+ may result in updates to the object(s) being reified.
210+
211+ For instance, if a meta tensor's parent operator is fully reifiable to
212+ a base object, then the meta tensor's dtype and shape may be fixed:
213+ e.g. a tensor corresponding to the output of a sum of two float64
214+ scalars is necessarily a float64 scalar.
215+
216+ This function will set any unspecified properties (e.g. dtype and shape
217+ values for the previous example), mutating the object in-place when
218+ possible. It will return a [refined/partially reified] meta object
219+ when it can't fully reify to a base object (in which case, it will
220+ return the base object) or when partial reification results in a meta
221+ object from a subclass.
222+ """
200223 if self .obj is not None and not isinstance (self .obj , Var ):
201224 return self .obj
202225 else :
203- reified_rands , any_unreified = _meta_reify_iter (self .rands ())
226+ reified_rands , any_unreified = meta_reify_iter (self .rands ())
204227
205228 # If not all the rands reified, then create another meta
206229 # object--albeit one with potentially more non-`None` `obj` fields.
@@ -220,35 +243,20 @@ def __eq__(self, other):
220243 if not (type (self ) == type (other )):
221244 return False
222245
223- if not (self .base == other .base ):
224- return False
246+ assert self .base == other .base
225247
226- a_slots = getattr (self , "__all_props__" , None )
227- if a_slots is not None :
228- if not all (_check_eq (getattr (self , attr ), getattr (other , attr )) for attr in a_slots ):
229- return False
230- elif getattr (other , "__all_props__" , None ) is not None :
231- # The other object has slots, but this one doesn't.
232- return False
248+ if self .rands ():
249+ return all (_check_eq (s , o ) for s , o in zip (self .rands (), other .rands ()))
233250 else :
234- # Neither have slots, so best we can do is compare
235- # base objects (if any).
236- # If there aren't base objects, we say they're not equal.
237- # (Maybe we should *require* base objects in this case
238- # and raise an exception?)
239- return getattr (self , "obj" , None ) == getattr (other , "obj" , None ) is not None
251+ return NotImplemented
240252
241- return True
253+ return False
242254
243255 def __ne__ (self , other ):
244256 return not self .__eq__ (other )
245257
246258 def __hash__ (self ):
247- if getattr (self , "__props__" , None ) is not None :
248- rands = tuple (_make_hashable (p ) for p in self .rands ())
249- return hash (rands + (self .base ,))
250- else :
251- return hash ((self .base , self .obj ))
259+ return hash ((self .base , self .rands ()))
252260
253261 def __str__ (self ):
254262 obj = getattr (self , "obj" , None )
@@ -273,8 +281,8 @@ def _repr_pretty_(self, p, cycle):
273281 with p .group (2 , f"{ self .__class__ .__name__ } (" , ")" ):
274282 p .breakable ()
275283 idx = None
276- if hasattr (self , "__props__ " ):
277- for idx , (name , item ) in enumerate (zip (self .__props__ , self .rands ())):
284+ if hasattr (self , "__all_props__ " ):
285+ for idx , (name , item ) in enumerate (zip (self .__all_props__ , self .rands ())):
278286 if idx :
279287 p .text ("," )
280288 p .breakable ()
@@ -292,8 +300,8 @@ def _repr_pretty_(self, p, cycle):
292300 p .pretty (obj )
293301
294302
295- @dispatch (( MetaSymbol , type ( None ), types . FunctionType , partial , str , dict ) )
296- def _metatize (obj ):
303+ @_metatize . register ( MetaSymbol )
304+ def _metatize_MetaSymbol (obj ):
297305 return obj
298306
299307
@@ -314,21 +322,26 @@ class MetaOp(MetaSymbol):
314322 def __init__ (self , * args , ** kwargs ):
315323 super ().__init__ (* args , ** kwargs )
316324
317- @MetaSymbol .obj .setter
318- def obj (self , x ):
319- if hasattr (self , "_obj" ):
320- raise ValueError ("Cannot reset obj in an `Op`" )
321- object .__setattr__ (self , "_obj" , x )
322-
323325 @abc .abstractmethod
324326 def out_meta_types (self , inputs = None ):
325327 """Return the types of meta variables this `Op` is expected to produce given the inputs."""
326328 raise NotImplementedError ()
327329
328330 @abc .abstractmethod
329- def __call__ (self , * args , ttype = None , index = None , ** kwargs ):
331+ def __call__ (self , * args , ** kwargs ):
330332 raise NotImplementedError ()
331333
334+ def __eq__ (self , other ):
335+ res = super ().__eq__ (other )
336+
337+ if res is NotImplemented :
338+ return getattr (self , "obj" , None ) == getattr (other , "obj" , None ) is not None
339+
340+ return res
341+
342+ def __hash__ (self ):
343+ return hash ((self .base , self .obj ))
344+
332345
333346class MetaVariable (MetaSymbol ):
334347 __slots__ = ()
@@ -369,14 +382,25 @@ def _find_meta_type(obj_type, meta_abs_type):
369382 # This object is a subclass of an existing meta class' base type,
370383 # but there is no implemented meta type for this subclass, so we
371384 # dynamically make one.
385+
386+ # FIXME, TODO: We should do something about `Op` constructor
387+ # arguments and properties.
388+ #
389+ # For instance, `tt.nlinalg.SVD` takes `full_matrices` and `compute_uv`
390+ # constructor arguments, but the dynamically constructed `TheanoMetaOp` type for
391+ # SVD is just the base `TheanoMetaOp.__init__`, which doesn't account for those.
392+ # To do this correctly, we would need to dynamically metatize the underlying
393+ # `Op`'s `__init__` and so on.
372394 new_type = type (f"Meta{ obj_type .__name__ } " , (obj_cls ,), {"base" : obj_type })
373- return new_type (obj_type )
395+
396+ return new_type
374397 else :
375398 cls = obj_cls
376399
377400
378- @dispatch (type )
379- def _metatize (obj_type ):
401+ @_metatize .register (type )
402+ @cached (metatize_cache )
403+ def _metatize_type (obj_type ):
380404 """Return an existing meta type/class, or create a new one."""
381405 for meta_type in MetaSymbol .__subclasses__ ():
382406 obj_cls = _find_meta_type (obj_type , meta_type )
0 commit comments