13 | 13 | import builtins |
14 | 14 | import math |
15 | 15 | from collections.abc import Callable |
| 16 | +from functools import lru_cache |
16 | 17 | from itertools import chain |
17 | 18 | from textwrap import dedent |
18 | 19 | from typing import Any, TypeAlias |
@@ -57,40 +58,51 @@ class IntegerDivisionError(Exception): |
57 | 58 | """ |
58 | 59 |
59 | 60 |
60 | | -def upcast(dtype, *dtypes) -> str: |
| 61 | +@lru_cache |
| 62 | +def _upcast_pairwise(dtype1, dtype2=None, *, cast_policy, floatX): |
61 | 63 | # This tries to keep data in floatX or lower precision, unless we |
62 | 64 | # explicitly request a higher precision datatype. |
63 | | - keep_float32 = [ |
64 | | - (config.cast_policy == "numpy+floatX" and config.floatX == "float32") |
65 | | - ] |
66 | | - keep_float16 = [ |
67 | | - (config.cast_policy == "numpy+floatX" and config.floatX == "float16") |
68 | | - ] |
69 | | - |
70 | | - def make_array(dt): |
71 | | - if dt == "float64": |
72 | | - # There is an explicit float64 dtype: we cannot keep float32. |
73 | | - keep_float32[0] = False |
74 | | - keep_float16[0] = False |
75 | | - if dt == "float32": |
76 | | - keep_float16[0] = False |
77 | | - return np.zeros((), dtype=dt) |
78 | | - |
79 | | - z = make_array(dtype) |
80 | | - for dt in dtypes: |
81 | | - z = z + make_array(dt=dt) |
82 | | - rval = str(z.dtype) |
| 65 | + if dtype1 == "float64": |
| 66 | + keep_float32, keep_float16 = False, False |
| 67 | + else: |
| 68 | + keep_float32 = cast_policy == "numpy+floatX" and floatX == "float32" |
| 69 | + keep_float16 = cast_policy == "numpy+floatX" and floatX == "float16" |
| 70 | + |
| 71 | + if dtype2 is not None: |
| 72 | + if dtype2 == "float64": |
| 73 | + keep_float32, keep_float16 = False, False |
| 74 | + elif dtype2 == "float32": |
| 75 | + keep_float16 = False |
| 76 | + |
| 77 | + if dtype2 is None: |
| 78 | + rval = dtype1 |
| 79 | + else: |
| 80 | + rval = (np.zeros((), dtype=dtype1) + np.zeros((), dtype=dtype2)).dtype.name |
| 81 | + |
83 | 82 | if rval == "float64": |
84 | | - if keep_float16[0]: |
| 83 | + if keep_float16: |
85 | 84 | return "float16" |
86 | | - if keep_float32[0]: |
| 85 | + if keep_float32: |
87 | 86 | return "float32" |
88 | 87 | elif rval == "float32": |
89 | | - if keep_float16[0]: |
| 88 | + if keep_float16: |
90 | 89 | return "float16" |
91 | 90 | return rval |
92 | 91 |
93 | 92 |
| 93 | +def upcast(dtype, *dtypes) -> str: |
| 94 | + # This tries to keep data in floatX or lower precision, unless we |
| 95 | + # explicitly request a higher precision datatype. |
| 96 | + floatX = config.floatX |
| 97 | + cast_policy = config.cast_policy |
| 98 | + res_dtype = _upcast_pairwise(dtype, cast_policy=cast_policy, floatX=floatX) |
| 99 | + for dt in dtypes: |
| 100 | + res_dtype = _upcast_pairwise( |
| 101 | + res_dtype, dt, cast_policy=cast_policy, floatX=floatX |
| 102 | + ) |
| 103 | + return res_dtype |
| 104 | + |
| 105 | + |
94 | 106 | def as_common_dtype(*vars): |
95 | 107 | """ |
96 | 108 | For pytensor.scalar.ScalarType and TensorVariable. |
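The refactor above replaces the single variadic `upcast` with a cached pairwise helper plus a left fold over the remaining dtypes. Below is a minimal, self-contained sketch of that pattern, assuming plain NumPy promotion and ignoring the `floatX`/`cast_policy` handling; `promote_pair` and `promote_all` are illustrative names, not pytensor functions.

```python
from functools import lru_cache

import numpy as np


@lru_cache
def promote_pair(dtype1: str, dtype2: str) -> str:
    # NumPy promotes two dtypes at a time; caching the pair means the
    # zero-dim arrays are only allocated once per dtype combination.
    return (np.zeros((), dtype=dtype1) + np.zeros((), dtype=dtype2)).dtype.name


def promote_all(dtype: str, *dtypes: str) -> str:
    # Left-fold the pairwise promotion over the remaining dtypes, the same
    # shape as the refactored `upcast` looping over `_upcast_pairwise`.
    result = dtype
    for dt in dtypes:
        result = promote_pair(result, dt)
    return result


print(promote_all("int8", "int16", "float32"))  # float32 under default NumPy rules
print(promote_pair.cache_info())  # hits accumulate as dtype pairs repeat
```

Passing `cast_policy` and `floatX` into the cached helper as explicit keyword arguments, as the diff does, makes the current config values part of the `lru_cache` key, so changing the config between calls cannot serve stale results.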