|
15 | 15 | from .typing import DataType, Scalar, ScalarType, Shape |
16 | 16 |
|
17 | 17 |
|
18 | | -def axes(ndim: int) -> st.SearchStrategy[Optional[Union[int, Shape]]]: |
19 | | - axes_strats = [st.none()] |
20 | | - if ndim != 0: |
21 | | - axes_strats.append(st.integers(-ndim, ndim - 1)) |
22 | | - axes_strats.append(xps.valid_tuple_axes(ndim)) |
23 | | - return st.one_of(axes_strats) |
24 | | - |
25 | | - |
26 | 18 | def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]: |
27 | 19 | dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype] |
28 | 20 | return st.none() | st.sampled_from(dtypes) |
@@ -108,7 +100,7 @@ def assert_equals( |
108 | 100 | data=st.data(), |
109 | 101 | ) |
110 | 102 | def test_max(x, data): |
111 | | - kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") |
| 103 | + kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") |
112 | 104 |
|
113 | 105 | out = xp.max(x, **kw) |
114 | 106 |
|
@@ -137,7 +129,7 @@ def test_max(x, data): |
137 | 129 | data=st.data(), |
138 | 130 | ) |
139 | 131 | def test_mean(x, data): |
140 | | - kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") |
| 132 | + kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") |
141 | 133 |
|
142 | 134 | out = xp.mean(x, **kw) |
143 | 135 |
|
@@ -166,7 +158,7 @@ def test_mean(x, data): |
166 | 158 | data=st.data(), |
167 | 159 | ) |
168 | 160 | def test_min(x, data): |
169 | | - kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") |
| 161 | + kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") |
170 | 162 |
|
171 | 163 | out = xp.min(x, **kw) |
172 | 164 |
|
@@ -197,7 +189,7 @@ def test_min(x, data): |
197 | 189 | def test_prod(x, data): |
198 | 190 | kw = data.draw( |
199 | 191 | hh.kwargs( |
200 | | - axis=axes(x.ndim), |
| 192 | + axis=hh.axes(x.ndim), |
201 | 193 | dtype=kwarg_dtypes(x.dtype), |
202 | 194 | keepdims=st.booleans(), |
203 | 195 | ), |
@@ -258,7 +250,7 @@ def test_prod(x, data): |
258 | 250 | data=st.data(), |
259 | 251 | ) |
260 | 252 | def test_std(x, data): |
261 | | - axis = data.draw(axes(x.ndim), label="axis") |
| 253 | + axis = data.draw(hh.axes(x.ndim), label="axis") |
262 | 254 | _axes = normalise_axis(axis, x.ndim) |
263 | 255 | N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) |
264 | 256 | correction = data.draw( |
@@ -295,7 +287,7 @@ def test_std(x, data): |
295 | 287 | def test_sum(x, data): |
296 | 288 | kw = data.draw( |
297 | 289 | hh.kwargs( |
298 | | - axis=axes(x.ndim), |
| 290 | + axis=hh.axes(x.ndim), |
299 | 291 | dtype=kwarg_dtypes(x.dtype), |
300 | 292 | keepdims=st.booleans(), |
301 | 293 | ), |
@@ -356,7 +348,7 @@ def test_sum(x, data): |
356 | 348 | data=st.data(), |
357 | 349 | ) |
358 | 350 | def test_var(x, data): |
359 | | - axis = data.draw(axes(x.ndim), label="axis") |
| 351 | + axis = data.draw(hh.axes(x.ndim), label="axis") |
360 | 352 | _axes = normalise_axis(axis, x.ndim) |
361 | 353 | N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) |
362 | 354 | correction = data.draw( |
|
0 commit comments