|
16 | 16 | import pytest |
17 | 17 | from hypothesis import assume, given |
18 | 18 | from hypothesis.strategies import (booleans, composite, none, tuples, integers, |
19 | | - shared, sampled_from, data, just) |
| 19 | + shared, sampled_from, one_of, data, just) |
20 | 20 | from ndindex import iter_indices |
21 | 21 |
|
22 | 22 | from .array_helpers import assert_exactly_equal, asarray |
|
29 | 29 | SQRT_MAX_ARRAY_SIZE, finite_matrices) |
30 | 30 | from . import dtype_helpers as dh |
31 | 31 | from . import pytest_helpers as ph |
32 | | -from . import shape_helpers as sh |
33 | 32 |
|
34 | 33 | from . import _array_module |
35 | 34 | from . import _array_module as xp |
@@ -162,26 +161,18 @@ def test_cross(x1_x2_kw): |
162 | 161 | assert res.dtype == dh.result_type(x1.dtype, x2.dtype), "cross() did not return the correct dtype" |
163 | 162 | assert res.shape == shape, "cross() did not return the correct shape" |
164 | 163 |
|
165 | | - # cross is too different from other functions to use _test_stacks, and it |
166 | | - # is the only function that works the way it does, so it's not really |
167 | | - # worth generalizing _test_stacks to handle it. |
168 | | - a = axis if axis >= 0 else axis + len(shape) |
169 | | - for _idx in sh.ndindex(shape[:a] + shape[a+1:]): |
170 | | - idx = _idx[:a] + (slice(None),) + _idx[a:] |
171 | | - assert len(idx) == len(shape), "Invalid index. This indicates a bug in the test suite." |
172 | | - res_stack = res[idx] |
173 | | - x1_stack = x1[idx] |
174 | | - x2_stack = x2[idx] |
175 | | - assert x1_stack.shape == x2_stack.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite." |
176 | | - decomp_res_stack = linalg.cross(x1_stack, x2_stack) |
177 | | - assert_exactly_equal(res_stack, decomp_res_stack) |
178 | | - |
179 | | - exact_cross = asarray([ |
180 | | - x1_stack[1]*x2_stack[2] - x1_stack[2]*x2_stack[1], |
181 | | - x1_stack[2]*x2_stack[0] - x1_stack[0]*x2_stack[2], |
182 | | - x1_stack[0]*x2_stack[1] - x1_stack[1]*x2_stack[0], |
183 | | - ], dtype=res.dtype) |
184 | | - assert_exactly_equal(res_stack, exact_cross) |
| 164 | + def exact_cross(a, b): |
| 165 | + assert a.shape == b.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite." |
| 166 | + return asarray([ |
| 167 | + a[1]*b[2] - a[2]*b[1], |
| 168 | + a[2]*b[0] - a[0]*b[2], |
| 169 | + a[0]*b[1] - a[1]*b[0], |
| 170 | + ], dtype=res.dtype) |
| 171 | + |
| 172 | + # We don't want to pass in **kw here because that would pass axis to |
| 173 | + # cross() on a single stack, but the axis is not meaningful on unstacked |
| 174 | + # vectors. |
| 175 | + _test_stacks(linalg.cross, x1, x2, dims=1, matrix_axes=(axis,), res=res, true_val=exact_cross) |
185 | 176 |
|
186 | 177 | @pytest.mark.xp_extension('linalg') |
187 | 178 | @given( |
|
0 commit comments