55from hypothesis import strategies as st
66from hypothesis .control import assume
77
8- from .typing import Scalar , ScalarType , Shape
9-
108from . import _array_module as xp
119from . import dtype_helpers as dh
1210from . import hypothesis_helpers as hh
1311from . import pytest_helpers as ph
1412from . import shape_helpers as sh
1513from . import xps
14+ from .typing import Scalar , Shape
1615
1716
1817def assert_scalar_in_set (
1918 func_name : str ,
20- type_ : ScalarType ,
2119 idx : Shape ,
2220 out : Scalar ,
2321 set_ : Set [Scalar ],
@@ -62,13 +60,12 @@ def test_argsort(x, data):
6260 scalar_type = dh .get_scalar_type (x .dtype )
6361 for indices in sh .axes_ndindex (x .shape , axes ):
6462 elements = [scalar_type (x [idx ]) for idx in indices ]
65- orders = sorted (
66- range (len (elements )),
67- key = elements .__getitem__ ,
68- reverse = kw .get ("descending" , False ),
63+ orders = list (range (len (elements )))
64+ sorders = sorted (
65+ orders , key = elements .__getitem__ , reverse = kw .get ("descending" , False )
6966 )
7067 if kw .get ("stable" , True ):
71- for idx , o in zip (indices , orders ):
68+ for idx , o in zip (indices , sorders ):
7269 ph .assert_scalar_equals ("argsort" , int , idx , int (out [idx ]), o , ** kw )
7370 else :
7471 idx_elements = dict (zip (indices , elements ))
@@ -78,17 +75,17 @@ def test_argsort(x, data):
7875 element_orders [e ] = [
7976 idx_orders [idx ] for idx in indices if idx_elements [idx ] == e
8077 ]
81- for idx , e in zip ( indices , elements ):
82- o = int ( out [ idx ])
78+ selements = [ elements [ o ] for o in sorders ]
79+ for idx , e in zip ( indices , selements ):
8380 expected_orders = element_orders [e ]
81+ out_o = int (out [idx ])
8482 if len (expected_orders ) == 1 :
85- expected_order = expected_orders [0 ]
8683 ph .assert_scalar_equals (
87- "argsort" , int , idx , o , expected_order , ** kw
84+ "argsort" , int , idx , out_o , expected_orders [ 0 ] , ** kw
8885 )
8986 else :
9087 assert_scalar_in_set (
91- "argsort" , int , idx , o , set (expected_orders ), ** kw
88+ "argsort" , idx , out_o , set (expected_orders ), ** kw
9289 )
9390
9491
@@ -129,6 +126,7 @@ def test_sort(x, data):
129126 )
130127 for out_idx , o in zip (indices , orders ):
131128 x_idx = indices [o ]
129+ # TODO: error message when unstable should not imply just one idx
132130 ph .assert_0d_equals (
133131 "sort" ,
134132 f"x[{ x_idx } ]" ,
0 commit comments