88from . import shape_helpers as sh
99from . import xps
1010from .algos import broadcast_shapes
11- from .test_manipulation_functions import assert_equals as assert_equals_
12- from .test_statistical_functions import assert_equals , assert_keepdimable_shape
13- from .typing import DataType
14-
15-
16- def assert_default_index (func_name : str , dtype : DataType , repr_name = "out.dtype" ):
17- f_dtype = dh .dtype_to_name [dtype ]
18- msg = (
19- f"{ repr_name } ={ f_dtype } , should be the default index dtype, "
20- f"which is either int32 or int64 [{ func_name } ()]"
21- )
22- assert dtype in (xp .int32 , xp .int64 ), msg
2311
2412
2513@given (
@@ -41,9 +29,9 @@ def test_argmax(x, data):
4129
4230 out = xp .argmax (x , ** kw )
4331
44- assert_default_index ("argmax" , out .dtype )
32+ ph . assert_default_index ("argmax" , out .dtype )
4533 axes = sh .normalise_axis (kw .get ("axis" , None ), x .ndim )
46- assert_keepdimable_shape (
34+ ph . assert_keepdimable_shape (
4735 "argmax" , out .shape , x .shape , axes , kw .get ("keepdims" , False ), ** kw
4836 )
4937 scalar_type = dh .get_scalar_type (x .dtype )
@@ -54,7 +42,7 @@ def test_argmax(x, data):
5442 s = scalar_type (x [idx ])
5543 elements .append (s )
5644 expected = max (range (len (elements )), key = elements .__getitem__ )
57- assert_equals ("argmax" , int , out_idx , max_i , expected )
45+ ph . assert_scalar_equals ("argmax" , int , out_idx , max_i , expected )
5846
5947
6048@given (
@@ -76,9 +64,9 @@ def test_argmin(x, data):
7664
7765 out = xp .argmin (x , ** kw )
7866
79- assert_default_index ("argmin" , out .dtype )
67+ ph . assert_default_index ("argmin" , out .dtype )
8068 axes = sh .normalise_axis (kw .get ("axis" , None ), x .ndim )
81- assert_keepdimable_shape (
69+ ph . assert_keepdimable_shape (
8270 "argmin" , out .shape , x .shape , axes , kw .get ("keepdims" , False ), ** kw
8371 )
8472 scalar_type = dh .get_scalar_type (x .dtype )
@@ -89,7 +77,7 @@ def test_argmin(x, data):
8977 s = scalar_type (x [idx ])
9078 elements .append (s )
9179 expected = min (range (len (elements )), key = elements .__getitem__ )
92- assert_equals ("argmin" , int , out_idx , min_i , expected )
80+ ph . assert_scalar_equals ("argmin" , int , out_idx , min_i , expected )
9381
9482
9583# TODO: skip if opted out
@@ -106,7 +94,7 @@ def test_nonzero(x):
10694 assert (
10795 out [i ].size == size
10896 ), f"out[{ i } ].size={ x .size } , but should be out[0].size={ size } "
109- assert_default_index ("nonzero" , out [i ].dtype , repr_name = f"out[{ i } ].dtype" )
97+ ph . assert_default_index ("nonzero" , out [i ].dtype , repr_name = f"out[{ i } ].dtype" )
11098 indices = []
11199 if x .dtype == xp .bool :
112100 for idx in sh .ndindex (x .shape ):
@@ -151,6 +139,10 @@ def test_where(shapes, dtypes, data):
151139 _x2 = xp .broadcast_to (x2 , shape )
152140 for idx in sh .ndindex (shape ):
153141 if _cond [idx ]:
154- assert_equals_ ("where" , f"_x1[{ idx } ]" , _x1 [idx ], f"out[{ idx } ]" , out [idx ])
142+ ph .assert_0d_equals (
143+ "where" , f"_x1[{ idx } ]" , _x1 [idx ], f"out[{ idx } ]" , out [idx ]
144+ )
155145 else :
156- assert_equals_ ("where" , f"_x2[{ idx } ]" , _x2 [idx ], f"out[{ idx } ]" , out [idx ])
146+ ph .assert_0d_equals (
147+ "where" , f"_x2[{ idx } ]" , _x2 [idx ], f"out[{ idx } ]" , out [idx ]
148+ )
0 commit comments