|
22 | 22 | try: |
23 | 23 | # torch >=2.3 |
24 | 24 | _int_dtypes |= {torch.uint16, torch.uint32, torch.uint64} |
25 | | - _HAS_LARGE_UINT = True |
26 | 25 | except AttributeError: |
27 | | - _HAS_LARGE_UINT = False |
| 26 | + pass |
| 27 | + |
28 | 28 |
|
29 | 29 | _array_api_dtypes = { |
30 | 30 | torch.bool, |
|
59 | 59 | (torch.float64, torch.complex128): torch.complex128, |
60 | 60 | } |
61 | 61 |
|
62 | | -if _HAS_LARGE_UINT: # torch >=2.3 |
63 | | - _promotion_table.update( |
64 | | - { |
65 | | - # uints |
66 | | - (torch.uint8, torch.uint16): torch.uint16, |
67 | | - (torch.uint8, torch.uint32): torch.uint32, |
68 | | - (torch.uint8, torch.uint64): torch.uint64, |
69 | | - (torch.uint16, torch.uint32): torch.uint32, |
70 | | - (torch.uint16, torch.uint64): torch.uint64, |
71 | | - (torch.uint32, torch.uint64): torch.uint64, |
72 | | - # ints and uints (mixed sign) |
73 | | - (torch.uint16, torch.int8): torch.int32, |
74 | | - (torch.uint16, torch.int16): torch.int32, |
75 | | - (torch.uint16, torch.int32): torch.int32, |
76 | | - (torch.uint16, torch.int64): torch.int64, |
77 | | - (torch.uint32, torch.int8): torch.int64, |
78 | | - (torch.uint32, torch.int16): torch.int64, |
79 | | - (torch.uint32, torch.int32): torch.int64, |
80 | | - (torch.uint32, torch.int64): torch.int64, |
81 | | - } |
82 | | - ) |
83 | | - |
84 | 62 | _promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()}) |
85 | 63 | _promotion_table.update({(a, a): a for a in _array_api_dtypes}) |
86 | 64 |
|
@@ -317,16 +295,10 @@ def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array: |
317 | 295 | if dtype is not None: |
318 | 296 | return x.clone() if dtype == x.dtype else x.to(dtype) |
319 | 297 |
|
320 | | - if x.dtype in (torch.int8, torch.int16, torch.int32): |
321 | | - return x.to(torch.int64) |
322 | | - |
323 | | - if _HAS_LARGE_UINT and x.dtype in (torch.uint8, torch.uint16, torch.uint32): |
324 | | - return x.to(torch.uint64) |
325 | | - |
326 | | - if x.dtype == torch.uint8: |
327 | | - # We can't upcast uint8 according to the spec because there is no |
328 | | - # torch.uint64, so at least upcast to int64 which is what prod does |
329 | | - # when axis=None. |
| 298 | + # We can't upcast uint8 according to the spec because there is no |
| 299 | + # torch.uint64, so at least upcast to int64 which is what prod does |
| 300 | + # when axis=None. |
| 301 | + if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32): |
330 | 302 | return x.to(torch.int64) |
331 | 303 |
|
332 | 304 | return x.clone() |
|
0 commit comments