|
6 | 6 |
|
7 | 7 | if TYPE_CHECKING: |
8 | 8 | from typing import Optional, Union, Tuple, List |
9 | | - from ._typing import device, DefaultDataTypes, DataTypes, Capabilities, Info |
| 9 | + from ._typing import device, DefaultDataTypes, DataTypes, Capabilities |
10 | 10 |
|
11 | 11 | from ._array_object import ALL_DEVICES, CPU_DEVICE |
12 | 12 | from ._flags import get_array_api_strict_flags, requires_api_version |
13 | 13 | from ._dtypes import bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128 |
14 | 14 |
|
15 | 15 | @requires_api_version('2023.12') |
16 | | -def __array_namespace_info__() -> Info: |
17 | | - import array_api_strict._info |
18 | | - return array_api_strict._info |
19 | | - |
20 | | -@requires_api_version('2023.12') |
21 | | -def capabilities() -> Capabilities: |
22 | | - flags = get_array_api_strict_flags() |
23 | | - res = {"boolean indexing": flags['boolean_indexing'], |
24 | | - "data-dependent shapes": flags['data_dependent_shapes'], |
25 | | - } |
26 | | - if flags['api_version'] >= '2024.12': |
27 | | - # maxdims is 32 for NumPy 1.x and 64 for NumPy 2.0. Eventually we will |
28 | | - # drop support for NumPy 1 but for now, just compute the number |
29 | | - # directly |
30 | | - for i in range(1, 100): |
31 | | - try: |
32 | | - np.zeros((1,)*i) |
33 | | - except ValueError: |
34 | | - maxdims = i - 1 |
35 | | - break |
36 | | - else: |
37 | | - raise RuntimeError("Could not get max dimensions (this is a bug in array-api-strict)") |
38 | | - res['max dimensions'] = maxdims |
39 | | - return res |
40 | | - |
41 | | -@requires_api_version('2023.12') |
42 | | -def default_device() -> device: |
43 | | - return CPU_DEVICE |
| 16 | +class __array_namespace_info__: |
| 17 | + @requires_api_version('2023.12') |
| 18 | + def capabilities(self) -> Capabilities: |
| 19 | + flags = get_array_api_strict_flags() |
| 20 | + res = {"boolean indexing": flags['boolean_indexing'], |
| 21 | + "data-dependent shapes": flags['data_dependent_shapes'], |
| 22 | + } |
| 23 | + if flags['api_version'] >= '2024.12': |
| 24 | + # maxdims is 32 for NumPy 1.x and 64 for NumPy 2.0. Eventually we will |
| 25 | + # drop support for NumPy 1 but for now, just compute the number |
| 26 | + # directly |
| 27 | + for i in range(1, 100): |
| 28 | + try: |
| 29 | + np.zeros((1,)*i) |
| 30 | + except ValueError: |
| 31 | + maxdims = i - 1 |
| 32 | + break |
| 33 | + else: |
| 34 | + raise RuntimeError("Could not get max dimensions (this is a bug in array-api-strict)") |
| 35 | + res['max dimensions'] = maxdims |
| 36 | + return res |
44 | 37 |
|
45 | | -@requires_api_version('2023.12') |
46 | | -def default_dtypes( |
47 | | - *, |
48 | | - device: Optional[device] = None, |
49 | | -) -> DefaultDataTypes: |
50 | | - return { |
51 | | - "real floating": float64, |
52 | | - "complex floating": complex128, |
53 | | - "integral": int64, |
54 | | - "indexing": int64, |
55 | | - } |
| 38 | + @requires_api_version('2023.12') |
| 39 | + def default_device(self) -> device: |
| 40 | + return CPU_DEVICE |
56 | 41 |
|
57 | | -@requires_api_version('2023.12') |
58 | | -def dtypes( |
59 | | - *, |
60 | | - device: Optional[device] = None, |
61 | | - kind: Optional[Union[str, Tuple[str, ...]]] = None, |
62 | | -) -> DataTypes: |
63 | | - if kind is None: |
| 42 | + @requires_api_version('2023.12') |
| 43 | + def default_dtypes( |
| 44 | + self, |
| 45 | + *, |
| 46 | + device: Optional[device] = None, |
| 47 | + ) -> DefaultDataTypes: |
64 | 48 | return { |
65 | | - "bool": bool, |
66 | | - "int8": int8, |
67 | | - "int16": int16, |
68 | | - "int32": int32, |
69 | | - "int64": int64, |
70 | | - "uint8": uint8, |
71 | | - "uint16": uint16, |
72 | | - "uint32": uint32, |
73 | | - "uint64": uint64, |
74 | | - "float32": float32, |
75 | | - "float64": float64, |
76 | | - "complex64": complex64, |
77 | | - "complex128": complex128, |
| 49 | + "real floating": float64, |
| 50 | + "complex floating": complex128, |
| 51 | + "integral": int64, |
| 52 | + "indexing": int64, |
78 | 53 | } |
79 | | - if kind == "bool": |
80 | | - return {"bool": bool} |
81 | | - if kind == "signed integer": |
82 | | - return { |
83 | | - "int8": int8, |
84 | | - "int16": int16, |
85 | | - "int32": int32, |
86 | | - "int64": int64, |
87 | | - } |
88 | | - if kind == "unsigned integer": |
89 | | - return { |
90 | | - "uint8": uint8, |
91 | | - "uint16": uint16, |
92 | | - "uint32": uint32, |
93 | | - "uint64": uint64, |
94 | | - } |
95 | | - if kind == "integral": |
96 | | - return { |
97 | | - "int8": int8, |
98 | | - "int16": int16, |
99 | | - "int32": int32, |
100 | | - "int64": int64, |
101 | | - "uint8": uint8, |
102 | | - "uint16": uint16, |
103 | | - "uint32": uint32, |
104 | | - "uint64": uint64, |
105 | | - } |
106 | | - if kind == "real floating": |
107 | | - return { |
108 | | - "float32": float32, |
109 | | - "float64": float64, |
110 | | - } |
111 | | - if kind == "complex floating": |
112 | | - return { |
113 | | - "complex64": complex64, |
114 | | - "complex128": complex128, |
115 | | - } |
116 | | - if kind == "numeric": |
117 | | - return { |
118 | | - "int8": int8, |
119 | | - "int16": int16, |
120 | | - "int32": int32, |
121 | | - "int64": int64, |
122 | | - "uint8": uint8, |
123 | | - "uint16": uint16, |
124 | | - "uint32": uint32, |
125 | | - "uint64": uint64, |
126 | | - "float32": float32, |
127 | | - "float64": float64, |
128 | | - "complex64": complex64, |
129 | | - "complex128": complex128, |
130 | | - } |
131 | | - if isinstance(kind, tuple): |
132 | | - res = {} |
133 | | - for k in kind: |
134 | | - res.update(dtypes(kind=k)) |
135 | | - return res |
136 | | - raise ValueError(f"unsupported kind: {kind!r}") |
137 | 54 |
|
138 | | -@requires_api_version('2023.12') |
139 | | -def devices() -> List[device]: |
140 | | - return list(ALL_DEVICES) |
| 55 | + @requires_api_version('2023.12') |
| 56 | + def dtypes( |
| 57 | + self, |
| 58 | + *, |
| 59 | + device: Optional[device] = None, |
| 60 | + kind: Optional[Union[str, Tuple[str, ...]]] = None, |
| 61 | + ) -> DataTypes: |
| 62 | + if kind is None: |
| 63 | + return { |
| 64 | + "bool": bool, |
| 65 | + "int8": int8, |
| 66 | + "int16": int16, |
| 67 | + "int32": int32, |
| 68 | + "int64": int64, |
| 69 | + "uint8": uint8, |
| 70 | + "uint16": uint16, |
| 71 | + "uint32": uint32, |
| 72 | + "uint64": uint64, |
| 73 | + "float32": float32, |
| 74 | + "float64": float64, |
| 75 | + "complex64": complex64, |
| 76 | + "complex128": complex128, |
| 77 | + } |
| 78 | + if kind == "bool": |
| 79 | + return {"bool": bool} |
| 80 | + if kind == "signed integer": |
| 81 | + return { |
| 82 | + "int8": int8, |
| 83 | + "int16": int16, |
| 84 | + "int32": int32, |
| 85 | + "int64": int64, |
| 86 | + } |
| 87 | + if kind == "unsigned integer": |
| 88 | + return { |
| 89 | + "uint8": uint8, |
| 90 | + "uint16": uint16, |
| 91 | + "uint32": uint32, |
| 92 | + "uint64": uint64, |
| 93 | + } |
| 94 | + if kind == "integral": |
| 95 | + return { |
| 96 | + "int8": int8, |
| 97 | + "int16": int16, |
| 98 | + "int32": int32, |
| 99 | + "int64": int64, |
| 100 | + "uint8": uint8, |
| 101 | + "uint16": uint16, |
| 102 | + "uint32": uint32, |
| 103 | + "uint64": uint64, |
| 104 | + } |
| 105 | + if kind == "real floating": |
| 106 | + return { |
| 107 | + "float32": float32, |
| 108 | + "float64": float64, |
| 109 | + } |
| 110 | + if kind == "complex floating": |
| 111 | + return { |
| 112 | + "complex64": complex64, |
| 113 | + "complex128": complex128, |
| 114 | + } |
| 115 | + if kind == "numeric": |
| 116 | + return { |
| 117 | + "int8": int8, |
| 118 | + "int16": int16, |
| 119 | + "int32": int32, |
| 120 | + "int64": int64, |
| 121 | + "uint8": uint8, |
| 122 | + "uint16": uint16, |
| 123 | + "uint32": uint32, |
| 124 | + "uint64": uint64, |
| 125 | + "float32": float32, |
| 126 | + "float64": float64, |
| 127 | + "complex64": complex64, |
| 128 | + "complex128": complex128, |
| 129 | + } |
| 130 | + if isinstance(kind, tuple): |
| 131 | + res = {} |
| 132 | + for k in kind: |
| 133 | + res.update(dtypes(kind=k)) |
| 134 | + return res |
| 135 | + raise ValueError(f"unsupported kind: {kind!r}") |
141 | 136 |
|
142 | | -__all__ = [ |
143 | | - "capabilities", |
144 | | - "default_device", |
145 | | - "default_dtypes", |
146 | | - "devices", |
147 | | - "dtypes", |
148 | | -] |
| 137 | + @requires_api_version('2023.12') |
| 138 | + def devices(self) -> List[device]: |
| 139 | + return list(ALL_DEVICES) |
0 commit comments