3737
3838pytestmark = pytest .mark .ci
3939
40-
41-
4240# Standin strategy for not yet implemented tests
4341todo = none ()
4442
45- def _test_stacks (f , * args , res = None , dims = 2 , true_val = None , matrix_axes = (- 2 , - 1 ),
43+ def _test_stacks (f , * args , res = None , dims = 2 , true_val = None ,
44+ matrix_axes = (- 2 , - 1 ),
4645 assert_equal = assert_exactly_equal , ** kw ):
4746 """
4847 Test that f(*args, **kw) maps across stacks of matrices
4948
50- dims is the number of dimensions f(*args) should have for a single n x m
51- matrix stack.
49+ dims is the number of dimensions f(*args, *kw ) should have for a single n
50+ x m matrix stack.
5251
5352 matrix_axes are the axes along which matrices (or vectors) are stacked in
5453 the input.
@@ -65,9 +64,13 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1)
6564
6665 shapes = [x .shape for x in args ]
6766
67+ # Assume the result is stacked along the last 'dims' axes of matrix_axes.
68+ # This holds for all the functions tested in this file
69+ res_axes = matrix_axes [::- 1 ][:dims ]
70+
6871 for (x_idxes , (res_idx ,)) in zip (
6972 iter_indices (* shapes , skip_axes = matrix_axes ),
70- iter_indices (res .shape , skip_axes = tuple ( range ( - dims , 0 )) )):
73+ iter_indices (res .shape , skip_axes = res_axes )):
7174 x_idxes = [x_idx .raw for x_idx in x_idxes ]
7275 res_idx = res_idx .raw
7376
0 commit comments