3535int_dtypes = list (map (str , ps .int_types ))
3636uint_dtypes = list (map (str , ps .uint_types ))
3737
38+ _all_dtypes_str : dict [str , str ] = {d : d for d in all_dtypes }
39+ _str_to_numpy_dtype : dict [str , np .dtype ] = {}
40+
3841# TODO: add more type correspondences for e.g. int32, int64, float32,
3942# complex64, etc.
4043dtype_specs_map = {
@@ -99,13 +102,26 @@ def __init__(
99102 )
100103 shape = broadcastable
101104
102- if str (dtype ) == "floatX" :
103- self .dtype = config .floatX
105+ if isinstance (dtype , str ):
106+ if dtype == "floatX" :
107+ dtype = config .floatX
108+ elif dtype not in _all_dtypes_str :
109+ # Check if dtype is a valid numpy dtype
110+ try :
111+ dtype = str (np .dtype (dtype ))
112+ except TypeError as exc :
113+ raise TypeError (
114+ f"Unsupported dtype for TensorType: { dtype } "
115+ ) from exc
116+ else :
117+ _all_dtypes_str [dtype ] = dtype
104118 else :
105119 try :
106- self .dtype = str (np .dtype (dtype ))
107- except TypeError :
108- raise TypeError (f"Invalid dtype: { dtype } " )
120+ dtype = str (np .dtype (dtype ))
121+ except TypeError as exc :
122+ raise TypeError (f"Unsupported dtype for TensorType: { dtype } " ) from exc
123+
124+ self .dtype = dtype
109125
110126 def parse_bcast_and_shape (s ):
111127 if isinstance (s , bool | np .bool_ ):
@@ -121,7 +137,10 @@ def parse_bcast_and_shape(s):
121137 self .shape = tuple (parse_bcast_and_shape (s ) for s in shape )
122138 self .dtype_specs () # error checking is done there
123139 self .name = name
124- self .numpy_dtype = np .dtype (self .dtype )
140+ try :
141+ self .numpy_dtype = _str_to_numpy_dtype [dtype ]
142+ except KeyError :
143+ self .numpy_dtype = _str_to_numpy_dtype [dtype ] = np .dtype (dtype )
125144
126145 def __call__ (self , * args , shape = None , ** kwargs ):
127146 if shape is not None :
0 commit comments