
Commit f385b0d

cache _upcast_impl
1 parent: aadbdfe

File tree

1 file changed: +36 −24 lines


pytensor/scalar/basic.py

Lines changed: 36 additions & 24 deletions
@@ -13,6 +13,7 @@
 import builtins
 import math
 from collections.abc import Callable
+from functools import lru_cache
 from itertools import chain
 from textwrap import dedent
 from typing import Any, TypeAlias
@@ -57,40 +58,51 @@ class IntegerDivisionError(Exception):
     """


-def upcast(dtype, *dtypes) -> str:
+@lru_cache
+def _upcast_pairwise(dtype1, dtype2=None, *, cast_policy, floatX):
     # This tries to keep data in floatX or lower precision, unless we
     # explicitly request a higher precision datatype.
-    keep_float32 = [
-        (config.cast_policy == "numpy+floatX" and config.floatX == "float32")
-    ]
-    keep_float16 = [
-        (config.cast_policy == "numpy+floatX" and config.floatX == "float16")
-    ]
-
-    def make_array(dt):
-        if dt == "float64":
-            # There is an explicit float64 dtype: we cannot keep float32.
-            keep_float32[0] = False
-            keep_float16[0] = False
-        if dt == "float32":
-            keep_float16[0] = False
-        return np.zeros((), dtype=dt)
-
-    z = make_array(dtype)
-    for dt in dtypes:
-        z = z + make_array(dt=dt)
-    rval = str(z.dtype)
+    if dtype1 == "float64":
+        keep_float32, keep_float16 = False, False
+    else:
+        keep_float32 = cast_policy == "numpy+floatX" and floatX == "float32"
+        keep_float16 = cast_policy == "numpy+floatX" and floatX == "float16"
+
+    if dtype2 is not None:
+        if dtype2 == "float64":
+            keep_float32, keep_float16 = False, False
+        elif dtype2 == "float32":
+            keep_float16 = False
+
+    if dtype2 is None:
+        rval = dtype1
+    else:
+        rval = (np.zeros((), dtype=dtype1) + np.zeros((), dtype=dtype2)).dtype.name
+
     if rval == "float64":
-        if keep_float16[0]:
+        if keep_float16:
             return "float16"
-        if keep_float32[0]:
+        if keep_float32:
             return "float32"
     elif rval == "float32":
-        if keep_float16[0]:
+        if keep_float16:
             return "float16"
     return rval


+def upcast(dtype, *dtypes) -> str:
+    # This tries to keep data in floatX or lower precision, unless we
+    # explicitly request a higher precision datatype.
+    floatX = config.floatX
+    cast_policy = config.cast_policy
+    res_dtype = _upcast_pairwise(dtype, cast_policy=cast_policy, floatX=floatX)
+    for dt in dtypes:
+        res_dtype = _upcast_pairwise(
+            res_dtype, dt, cast_policy=cast_policy, floatX=floatX
+        )
+    return res_dtype
+
+
 def as_common_dtype(*vars):
     """
     For for pytensor.scalar.ScalarType and TensorVariable.
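
Below is a minimal usage sketch (not part of this commit) of the refactored upcast and the lru_cache on the new _upcast_pairwise helper; it assumes PyTensor is importable and that config.cast_policy / config.floatX are left at their defaults:

    from pytensor.scalar.basic import _upcast_pairwise, upcast

    # upcast() now reduces pairwise over its arguments:
    # ("int8", "int16") -> "int16", then ("int16", "float32") -> "float32"
    print(upcast("int8", "int16", "float32"))  # "float32" under NumPy promotion rules

    # Repeated dtype pairs are answered from the cache instead of
    # re-building zero-dim NumPy arrays each time.
    for _ in range(3):
        upcast("int8", "int16", "float32")
    print(_upcast_pairwise.cache_info())  # hit count grows on repeated calls

Note that cast_policy and floatX are passed to _upcast_pairwise as explicit arguments rather than read from config inside the helper, so they form part of the cache key and a later change to config.floatX cannot return stale results.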

0 commit comments
