|
37 | 37 | ] |
38 | 38 |
|
39 | 39 |
|
| 40 | +class EqualityMapping(Mapping): |
| 41 | + """ |
| 42 | + Mapping that uses equality for indexing |
| 43 | +
|
| 44 | + Typical mappings (e.g. the built-in dict) use hashing for indexing. This |
| 45 | + isn't ideal for the Array API, as no __hash__() method is specified for |
| 46 | + dtype objects - but __eq__() is! |
| 47 | +
|
| 48 | + See https://data-apis.org/array-api/latest/API_specification/data_types.html#data-type-objects |
| 49 | + """ |
| 50 | + |
| 51 | + def __init__(self, mapping: Mapping): |
| 52 | + keys = list(mapping.keys()) |
| 53 | + for i, key in enumerate(keys): |
| 54 | + if not (key == key): # specifically checking __eq__, not __neq__ |
| 55 | + raise ValueError("Key {key!r} does not have equality with itself") |
| 56 | + other_keys = keys[:] |
| 57 | + other_keys.pop(i) |
| 58 | + for other_key in other_keys: |
| 59 | + if key == other_key: |
| 60 | + raise ValueError("Key {key!r} has equality with key {other_key!r}") |
| 61 | + self._mapping = mapping |
| 62 | + |
| 63 | + def __getitem__(self, key): |
| 64 | + for k, v in self._mapping.items(): |
| 65 | + if key == k: |
| 66 | + return v |
| 67 | + else: |
| 68 | + raise KeyError(f"{key!r} not found") |
| 69 | + |
| 70 | + def __iter__(self): |
| 71 | + return iter(self._mapping) |
| 72 | + |
| 73 | + def __len__(self): |
| 74 | + return len(self._mapping) |
| 75 | + |
| 76 | + def __repr__(self): |
| 77 | + return f"EqualityMapping({self._mapping!r})" |
| 78 | + |
| 79 | + |
40 | 80 | _uint_names = ("uint8", "uint16", "uint32", "uint64") |
41 | 81 | _int_names = ("int8", "int16", "int32", "int64") |
42 | 82 | _float_names = ("float32", "float64") |
|
52 | 92 | bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes |
53 | 93 |
|
54 | 94 |
|
55 | | -dtype_to_name = {getattr(xp, name): name for name in _dtype_names} |
| 95 | +dtype_to_name = EqualityMapping({getattr(xp, name): name for name in _dtype_names}) |
56 | 96 |
|
57 | 97 |
|
58 | | -dtype_to_scalars = { |
59 | | - xp.bool: [bool], |
60 | | - **{d: [int] for d in all_int_dtypes}, |
61 | | - **{d: [int, float] for d in float_dtypes}, |
62 | | -} |
| 98 | +dtype_to_scalars = EqualityMapping( |
| 99 | + { |
| 100 | + xp.bool: [bool], |
| 101 | + **{d: [int] for d in all_int_dtypes}, |
| 102 | + **{d: [int, float] for d in float_dtypes}, |
| 103 | + } |
| 104 | +) |
63 | 105 |
|
64 | 106 |
|
65 | 107 | def is_int_dtype(dtype): |
@@ -91,31 +133,37 @@ class MinMax(NamedTuple): |
91 | 133 | max: Union[int, float] |
92 | 134 |
|
93 | 135 |
|
94 | | -dtype_ranges = { |
95 | | - xp.int8: MinMax(-128, +127), |
96 | | - xp.int16: MinMax(-32_768, +32_767), |
97 | | - xp.int32: MinMax(-2_147_483_648, +2_147_483_647), |
98 | | - xp.int64: MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807), |
99 | | - xp.uint8: MinMax(0, +255), |
100 | | - xp.uint16: MinMax(0, +65_535), |
101 | | - xp.uint32: MinMax(0, +4_294_967_295), |
102 | | - xp.uint64: MinMax(0, +18_446_744_073_709_551_615), |
103 | | - xp.float32: MinMax(-3.4028234663852886e38, 3.4028234663852886e38), |
104 | | - xp.float64: MinMax(-1.7976931348623157e308, 1.7976931348623157e308), |
105 | | -} |
106 | | - |
107 | | -dtype_nbits = { |
108 | | - **{d: 8 for d in [xp.int8, xp.uint8]}, |
109 | | - **{d: 16 for d in [xp.int16, xp.uint16]}, |
110 | | - **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, |
111 | | - **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, |
112 | | -} |
113 | | - |
114 | | - |
115 | | -dtype_signed = { |
116 | | - **{d: True for d in int_dtypes}, |
117 | | - **{d: False for d in uint_dtypes}, |
118 | | -} |
| 136 | +dtype_ranges = EqualityMapping( |
| 137 | + { |
| 138 | + xp.int8: MinMax(-128, +127), |
| 139 | + xp.int16: MinMax(-32_768, +32_767), |
| 140 | + xp.int32: MinMax(-2_147_483_648, +2_147_483_647), |
| 141 | + xp.int64: MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807), |
| 142 | + xp.uint8: MinMax(0, +255), |
| 143 | + xp.uint16: MinMax(0, +65_535), |
| 144 | + xp.uint32: MinMax(0, +4_294_967_295), |
| 145 | + xp.uint64: MinMax(0, +18_446_744_073_709_551_615), |
| 146 | + xp.float32: MinMax(-3.4028234663852886e38, 3.4028234663852886e38), |
| 147 | + xp.float64: MinMax(-1.7976931348623157e308, 1.7976931348623157e308), |
| 148 | + } |
| 149 | +) |
| 150 | + |
| 151 | +dtype_nbits = EqualityMapping( |
| 152 | + { |
| 153 | + **{d: 8 for d in [xp.int8, xp.uint8]}, |
| 154 | + **{d: 16 for d in [xp.int16, xp.uint16]}, |
| 155 | + **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, |
| 156 | + **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, |
| 157 | + } |
| 158 | +) |
| 159 | + |
| 160 | + |
| 161 | +dtype_signed = EqualityMapping( |
| 162 | + { |
| 163 | + **{d: True for d in int_dtypes}, |
| 164 | + **{d: False for d in uint_dtypes}, |
| 165 | + } |
| 166 | +) |
119 | 167 |
|
120 | 168 |
|
121 | 169 | if isinstance(xp.asarray, _UndefinedStub): |
@@ -179,11 +227,13 @@ class MinMax(NamedTuple): |
179 | 227 | (xp.float32, xp.float64): xp.float64, |
180 | 228 | (xp.float64, xp.float64): xp.float64, |
181 | 229 | } |
182 | | -promotion_table = { |
183 | | - (xp.bool, xp.bool): xp.bool, |
184 | | - **_numeric_promotions, |
185 | | - **{(d2, d1): res for (d1, d2), res in _numeric_promotions.items()}, |
186 | | -} |
| 230 | +promotion_table = EqualityMapping( |
| 231 | + { |
| 232 | + (xp.bool, xp.bool): xp.bool, |
| 233 | + **_numeric_promotions, |
| 234 | + **{(d2, d1): res for (d1, d2), res in _numeric_promotions.items()}, |
| 235 | + } |
| 236 | +) |
187 | 237 |
|
188 | 238 |
|
189 | 239 | def result_type(*dtypes: DataType): |
@@ -405,42 +455,3 @@ def fmt_types(types: Tuple[Union[DataType, ScalarType], ...]) -> str: |
405 | 455 | # i.e. dtype is bool, int, or float |
406 | 456 | f_types.append(type_.__name__) |
407 | 457 | return ", ".join(f_types) |
408 | | - |
409 | | - |
410 | | -class EqualityMapping(Mapping): |
411 | | - """ |
412 | | - Mapping that uses equality for indexing |
413 | | -
|
414 | | - Typical mappings (e.g. the built-in dict) use hashing for indexing. This |
415 | | - isn't ideal for the Array API, as no __hash__() method is specified for |
416 | | - dtype objects - but __eq__() is! |
417 | | -
|
418 | | - See https://data-apis.org/array-api/latest/API_specification/data_types.html#data-type-objects |
419 | | - """ |
420 | | - def __init__(self, mapping: Mapping): |
421 | | - keys = list(mapping.keys()) |
422 | | - for i, key in enumerate(keys): |
423 | | - if not (key == key): # specifically checking __eq__, not __neq__ |
424 | | - raise ValueError("Key {key!r} does not have equality with itself") |
425 | | - other_keys = keys[:] |
426 | | - other_keys.pop(i) |
427 | | - for other_key in other_keys: |
428 | | - if key == other_key: |
429 | | - raise ValueError("Key {key!r} has equality with key {other_key!r}") |
430 | | - self._mapping = mapping |
431 | | - |
432 | | - def __getitem__(self, key): |
433 | | - for k, v in self._mapping.items(): |
434 | | - if key == k: |
435 | | - return v |
436 | | - else: |
437 | | - raise KeyError(f"{key!r} not found") |
438 | | - |
439 | | - def __iter__(self): |
440 | | - return iter(self._mapping) |
441 | | - |
442 | | - def __len__(self): |
443 | | - return len(self._mapping) |
444 | | - |
445 | | - def __repr__(self): |
446 | | - return f"EqualityMapping({self._mapping!r})" |
|
0 commit comments