Skip to content

Commit 30809e6

Browse files
committed
.faster tensortype creation
TODO: Do same for ScalarTypes
1 parent 3c26799 commit 30809e6

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

pytensor/tensor/type.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
int_dtypes = list(map(str, ps.int_types))
3636
uint_dtypes = list(map(str, ps.uint_types))
3737

38+
_all_dtypes_str: dict[str, str] = {d: d for d in all_dtypes}
39+
_str_to_numpy_dtype: dict[str, np.dtype] = {}
40+
3841
# TODO: add more type correspondences for e.g. int32, int64, float32,
3942
# complex64, etc.
4043
dtype_specs_map = {
@@ -99,13 +102,26 @@ def __init__(
99102
)
100103
shape = broadcastable
101104

102-
if str(dtype) == "floatX":
103-
self.dtype = config.floatX
105+
if isinstance(dtype, str):
106+
if dtype == "floatX":
107+
dtype = config.floatX
108+
elif dtype not in _all_dtypes_str:
109+
# Check if dtype is a valid numpy dtype
110+
try:
111+
dtype = str(np.dtype(dtype))
112+
except TypeError as exc:
113+
raise TypeError(
114+
f"Unsupported dtype for TensorType: {dtype}"
115+
) from exc
116+
else:
117+
_all_dtypes_str[dtype] = dtype
104118
else:
105119
try:
106-
self.dtype = str(np.dtype(dtype))
107-
except TypeError:
108-
raise TypeError(f"Invalid dtype: {dtype}")
120+
dtype = str(np.dtype(dtype))
121+
except TypeError as exc:
122+
raise TypeError(f"Unsupported dtype for TensorType: {dtype}") from exc
123+
124+
self.dtype = dtype
109125

110126
def parse_bcast_and_shape(s):
111127
if isinstance(s, bool | np.bool_):
@@ -121,7 +137,10 @@ def parse_bcast_and_shape(s):
121137
self.shape = tuple(parse_bcast_and_shape(s) for s in shape)
122138
self.dtype_specs() # error checking is done there
123139
self.name = name
124-
self.numpy_dtype = np.dtype(self.dtype)
140+
try:
141+
self.numpy_dtype = _str_to_numpy_dtype[dtype]
142+
except KeyError:
143+
self.numpy_dtype = _str_to_numpy_dtype[dtype] = np.dtype(dtype)
125144

126145
def __call__(self, *args, shape=None, **kwargs):
127146
if shape is not None:

0 commit comments

Comments
 (0)