55from hypothesis import strategies as st
66from hypothesis .control import assume
77
8- from .typing import Scalar , ScalarType , Shape
9-
108from . import _array_module as xp
119from . import dtype_helpers as dh
1210from . import hypothesis_helpers as hh
1311from . import pytest_helpers as ph
1412from . import shape_helpers as sh
1513from . import xps
14+ from .typing import Scalar , Shape
1615
1716
1817def assert_scalar_in_set (
1918 func_name : str ,
20- type_ : ScalarType ,
2119 idx : Shape ,
2220 out : Scalar ,
2321 set_ : Set [Scalar ],
@@ -62,12 +60,13 @@ def test_argsort(x, data):
6260 scalar_type = dh .get_scalar_type (x .dtype )
6361 for indices in sh .axes_ndindex (x .shape , axes ):
6462 elements = [scalar_type (x [idx ]) for idx in indices ]
65- orders = sorted (range (len (elements )), key = elements .__getitem__ )
66- if kw .get ("descending" , False ):
67- orders = reversed (orders )
63+ orders = list (range (len (elements )))
64+ sorders = sorted (
65+ orders , key = elements .__getitem__ , reverse = kw .get ("descending" , False )
66+ )
6867 if kw .get ("stable" , True ):
69- for idx , o in zip (indices , orders ):
70- ph .assert_scalar_equals ("argsort" , int , idx , int (out [idx ]), o )
68+ for idx , o in zip (indices , sorders ):
69+ ph .assert_scalar_equals ("argsort" , int , idx , int (out [idx ]), o , ** kw )
7170 else :
7271 idx_elements = dict (zip (indices , elements ))
7372 idx_orders = dict (zip (indices , orders ))
@@ -76,17 +75,17 @@ def test_argsort(x, data):
7675 element_orders [e ] = [
7776 idx_orders [idx ] for idx in indices if idx_elements [idx ] == e
7877 ]
79- for idx , e in zip ( indices , elements ):
80- o = int ( out [ idx ])
78+ selements = [ elements [ o ] for o in sorders ]
79+ for idx , e in zip ( indices , selements ):
8180 expected_orders = element_orders [e ]
81+ out_o = int (out [idx ])
8282 if len (expected_orders ) == 1 :
83- expected_order = expected_orders [0 ]
8483 ph .assert_scalar_equals (
85- "argsort" , int , idx , o , expected_order , ** kw
84+ "argsort" , int , idx , out_o , expected_orders [ 0 ] , ** kw
8685 )
8786 else :
8887 assert_scalar_in_set (
89- "argsort" , int , idx , o , set (expected_orders ), ** kw
88+ "argsort" , idx , out_o , set (expected_orders ), ** kw
9089 )
9190
9291
@@ -127,6 +126,7 @@ def test_sort(x, data):
127126 )
128127 for out_idx , o in zip (indices , orders ):
129128 x_idx = indices [o ]
129+ # TODO: error message when unstable should not imply just one idx
130130 ph .assert_0d_equals (
131131 "sort" ,
132132 f"x[{ x_idx } ]" ,
0 commit comments