|
16 | 16 | import pytest |
17 | 17 | from hypothesis import assume, given |
18 | 18 | from hypothesis.strategies import (booleans, composite, none, tuples, integers, |
19 | | - shared, sampled_from, data, just) |
| 19 | + shared, sampled_from, one_of, data, just) |
| 20 | +from ndindex import iter_indices |
20 | 21 |
|
21 | 22 | from .array_helpers import assert_exactly_equal, asarray, equal, zero, infinity |
22 | 23 | from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes, |
|
43 | 44 | # Standin strategy for not yet implemented tests |
44 | 45 | todo = none() |
45 | 46 |
|
46 | | -def _test_stacks(f, *args, res=None, dims=2, true_val=None, **kw): |
| 47 | +def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1), **kw): |
47 | 48 | """ |
48 | 49 | Test that f(*args, **kw) maps across stacks of matrices |
49 | 50 |
|
50 | | - dims is the number of dimensions f should have for a single n x m matrix |
51 | | - stack. |
| 51 | + dims is the number of dimensions f(*args) should have for a single n x m |
| 52 | + matrix stack. |
| 53 | +
|
| 54 | + matrix_axes are the axes along which matrices (or vectors) are stacked in |
| 55 | + the input. |
| 56 | +
|
| 57 | + true_val may be a function such that true_val(*x_stacks, **kw) gives the |
| 58 | + true value for f on a stack. |
| 59 | +
|
| 60 | + res should be the result of f(*args, **kw). It is computed if not passed |
| 61 | + in. |
52 | 62 |
|
53 | | - true_val may be a function such that true_val(*x_stacks) gives the true |
54 | | - value for f on a stack |
55 | 63 | """ |
56 | 64 | if res is None: |
57 | 65 | res = f(*args, **kw) |
58 | 66 |
|
59 | | - shape = args[0].shape if len(args) == 1 else broadcast_shapes(*[x.shape |
60 | | - for x in args]) |
61 | | - for _idx in sh.ndindex(shape[:-2]): |
62 | | - idx = _idx + (slice(None),)*dims |
63 | | - res_stack = res[idx] |
64 | | - x_stacks = [x[_idx + (...,)] for x in args] |
| 67 | + shapes = [x.shape for x in args] |
| 68 | + |
| 69 | + for (x_idxes, (res_idx,)) in zip( |
| 70 | + iter_indices(*shapes, skip_axes=matrix_axes), |
| 71 | + iter_indices(res.shape, skip_axes=tuple(range(-dims, 0)))): |
| 72 | + x_idxes = [x_idx.raw for x_idx in x_idxes] |
| 73 | + res_idx = res_idx.raw |
| 74 | + # res should have `dims` slices in it. Cases where there are more than |
| 75 | + # `dims` slices are ambiguous, but that should only occur in cases |
| 76 | + # where axes = (-2, -1). |
| 77 | + # res_idx2 = [] |
| 78 | + # d = dims |
| 79 | + # for i in res_idx: |
| 80 | + # if isinstance(i, slice): |
| 81 | + # if d: |
| 82 | + # res_idx2.append(i) |
| 83 | + # d -= 1 |
| 84 | + # else: |
| 85 | + # res_idx2.append(i) |
| 86 | + # res_idx2 = tuple(res_idx2) |
| 87 | + |
| 88 | + res_stack = res[res_idx] |
| 89 | + x_stacks = [x[x_idx] for x, x_idx in zip(args, x_idxes)] |
65 | 90 | decomp_res_stack = f(*x_stacks, **kw) |
66 | 91 | assert_exactly_equal(res_stack, decomp_res_stack) |
67 | 92 | if true_val: |
|
0 commit comments