|
23 | 23 | _metatize, |
24 | 24 | ) |
25 | 25 |
|
| 26 | +from .. import meta |
| 27 | + |
| 28 | +from ..utils import HashableNDArray |
| 29 | + |
26 | 30 |
|
27 | 31 | def _metatize_theano_object(obj): |
28 | 32 | try: |
29 | 33 | obj = tt.as_tensor_variable(obj) |
30 | 34 | except (ValueError, tt.AsTensorError): |
31 | | - raise ValueError("Could not find a MetaSymbol class for {}".format(obj)) |
| 35 | + raise ValueError("Error converting {} to a Theano tensor.".format(obj)) |
| 36 | + except AssertionError: |
| 37 | + # This is a work-around for a Theano bug; specifically, |
| 38 | + # an assert statement in `theano.scalar.basic` that unnecessarily |
| 39 | + # requires the object type be exclusively an ndarray or memmap. |
| 40 | + # See https://github.com/Theano/Theano/pull/6727 |
| 41 | + obj = tt.as_tensor_variable(np.asarray(obj)) |
| 42 | + |
32 | 43 | return _metatize(obj) |
33 | 44 |
|
34 | 45 |
|
35 | 46 | def load_dispatcher(): |
36 | 47 | """Set/override dispatcher to default to TF objects.""" |
37 | | - _metatize.add((object,), _metatize_theano_object) |
| 48 | + meta._metatize.add((object,), _metatize_theano_object) |
| 49 | + meta._metatize.add((HashableNDArray,), _metatize_theano_object) |
38 | 50 |
|
| 51 | + for new_cls in TheanoMetaSymbol.base_subclasses(): |
| 52 | + meta._metatize.add((new_cls.base,), new_cls._metatize) |
39 | 53 |
|
40 | | -load_dispatcher() |
| 54 | + return meta._metatize |
41 | 55 |
|
42 | 56 |
|
43 | 57 | class TheanoMetaSymbol(MetaSymbol): |
@@ -202,7 +216,7 @@ def __call__(self, *args, ttype=None, index=None, **kwargs): |
202 | 216 | # XXX: We don't have a higher-order meta object model, so being |
203 | 217 | # wrong about the exact type of output variable will cause |
204 | 218 | # problems. |
205 | | - out_meta_type, = self.out_meta_types(op_args) |
| 219 | + (out_meta_type,) = self.out_meta_types(op_args) |
206 | 220 | res_var = out_meta_type(ttype, res_apply, index, name) |
207 | 221 | res_var._obj = var() |
208 | 222 |
|
@@ -451,28 +465,9 @@ def _metatize(cls, obj): |
451 | 465 | return res |
452 | 466 |
|
453 | 467 | def __init__(self, type, data, name=None, obj=None): |
454 | | - self.data = data |
| 468 | + self.data = data if not isinstance(data, np.ndarray) else data.view(HashableNDArray) |
455 | 469 | super().__init__(type, None, None, name, obj=obj) |
456 | 470 |
|
457 | | - def __eq__(self, other): |
458 | | - if self is other: |
459 | | - return True |
460 | | - |
461 | | - if type(self) != type(other): |
462 | | - return False |
463 | | - |
464 | | - if all( |
465 | | - (s.tostring() if isinstance(s, np.ndarray) else s) |
466 | | - == (o.tostring() if isinstance(o, np.ndarray) else o) |
467 | | - for s, o in zip(self.rands(), other.rands()) |
468 | | - ): |
469 | | - return True |
470 | | - |
471 | | - return False |
472 | | - |
473 | | - def __hash__(self): |
474 | | - return hash(v.tostring() if isinstance(v, np.ndarray) else v for v in self.rands()) |
475 | | - |
476 | 471 |
|
477 | 472 | class TheanoMetaTensorConstant(TheanoMetaConstant): |
478 | 473 | # TODO: Could extend `theano.tensor.var._tensor_py_operators`, too. |
@@ -606,7 +601,9 @@ def meta_obj(*args, **kwargs): |
606 | 601 |
|
607 | 602 | mt = TheanoMetaAccessor() |
608 | 603 |
|
609 | | -mt.dot = metatize(tt.basic._dot) |
| 604 | +_metatize = load_dispatcher() |
| 605 | + |
| 606 | +mt.dot = _metatize(tt.basic._dot) |
610 | 607 |
|
611 | 608 |
|
612 | 609 | # |
|
0 commit comments