1- import traceback
21import warnings
32
43import numpy as np
@@ -30,7 +29,7 @@ def _get_vector_length_TensorSharedVariable(var_inst, var):
3029 return len (var .get_value (borrow = True ))
3130
3231
33- @shared_constructor
32+ @shared_constructor . register ( np . ndarray )
3433def tensor_constructor (
3534 value ,
3635 name = None ,
@@ -60,14 +59,13 @@ def tensor_constructor(
6059 if target != "cpu" :
6160 raise TypeError ("not for cpu" )
6261
63- if not isinstance (value , np .ndarray ):
64- raise TypeError ()
65-
6662 # If no shape is given, then the default is to assume that the value might
6763 # be resized in any dimension in the future.
6864 if shape is None :
69- shape = (None ,) * len (value .shape )
65+ shape = (None ,) * value .ndim
66+
7067 type = TensorType (value .dtype , shape = shape )
68+
7169 return TensorSharedVariable (
7270 type = type ,
7371 value = np .array (value , copy = (not borrow )),
@@ -81,7 +79,10 @@ class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
8179 pass
8280
8381
84- @shared_constructor
82+ @shared_constructor .register (np .number )
83+ @shared_constructor .register (float )
84+ @shared_constructor .register (int )
85+ @shared_constructor .register (complex )
8586def scalar_constructor (
8687 value , name = None , strict = False , allow_downcast = None , borrow = False , target = "cpu"
8788):
@@ -101,28 +102,22 @@ def scalar_constructor(
101102 if target != "cpu" :
102103 raise TypeError ("not for cpu" )
103104
104- if not isinstance (value , (np .number , float , int , complex )):
105- raise TypeError ()
106105 try :
107106 dtype = value .dtype
108- except Exception :
107+ except AttributeError :
109108 dtype = np .asarray (value ).dtype
110109
111110 dtype = str (dtype )
112111 value = _asarray (value , dtype = dtype )
113- tensor_type = TensorType (dtype = str (value .dtype ), shape = [] )
112+ tensor_type = TensorType (dtype = str (value .dtype ), shape = () )
114113
115- try :
116- # Do not pass the dtype to asarray because we want this to fail if
117- # strict is True and the types do not match.
118- rval = ScalarSharedVariable (
119- type = tensor_type ,
120- value = np .array (value , copy = True ),
121- name = name ,
122- strict = strict ,
123- allow_downcast = allow_downcast ,
124- )
125- return rval
126- except Exception :
127- traceback .print_exc ()
128- raise
114+ # Do not pass the dtype to asarray because we want this to fail if
115+ # strict is True and the types do not match.
116+ rval = ScalarSharedVariable (
117+ type = tensor_type ,
118+ value = np .array (value , copy = True ),
119+ name = name ,
120+ strict = strict ,
121+ allow_downcast = allow_downcast ,
122+ )
123+ return rval
0 commit comments