1+ import math
2+ from typing import Set
3+
14from hypothesis import given
25from hypothesis import strategies as st
36from hypothesis .control import assume
47
8+ from xptests .typing import Scalar , ScalarType , Shape
9+
510from . import _array_module as xp
611from . import dtype_helpers as dh
712from . import hypothesis_helpers as hh
1015from . import xps
1116
1217
18+ def assert_scalar_in_set (
19+ func_name : str ,
20+ type_ : ScalarType ,
21+ idx : Shape ,
22+ out : Scalar ,
23+ set_ : Set [Scalar ],
24+ / ,
25+ ** kw ,
26+ ):
27+ out_repr = "out" if idx == () else f"out[{ idx } ]"
28+ if math .isnan (out ):
29+ raise NotImplementedError ()
30+ msg = f"{ out_repr } ={ out } , but should be in { set_ } [{ func_name } ({ ph .fmt_kw (kw )} )]"
31+ assert out in set_ , msg
32+
33+
1334# TODO: Test with signed zeros and NaNs (and ignore them somehow)
1435@given (
1536 x = xps .arrays (
@@ -34,20 +55,39 @@ def test_argsort(x, data):
3455
3556 out = xp .argsort (x , ** kw )
3657
37- ph .assert_default_index ("sort " , out .dtype )
38- ph .assert_shape ("sort " , out .shape , x .shape , ** kw )
58+ ph .assert_default_index ("argsort " , out .dtype )
59+ ph .assert_shape ("argsort " , out .shape , x .shape , ** kw )
3960 axis = kw .get ("axis" , - 1 )
4061 axes = sh .normalise_axis (axis , x .ndim )
41- descending = kw .get ("descending" , False )
4262 scalar_type = dh .get_scalar_type (x .dtype )
4363 for indices in sh .axes_ndindex (x .shape , axes ):
4464 elements = [scalar_type (x [idx ]) for idx in indices ]
45- indices_order = sorted (range (len (indices )), key = elements .__getitem__ )
46- if descending :
47- # sorted(..., reverse=descending) doesn't always work
48- indices_order = reversed (indices_order )
49- for idx , o in zip (indices , indices_order ):
50- ph .assert_scalar_equals ("argsort" , int , idx , int (out [idx ]), o )
65+ orders = sorted (range (len (elements )), key = elements .__getitem__ )
66+ if kw .get ("descending" , False ):
67+ orders = reversed (orders )
68+ if kw .get ("stable" , True ):
69+ for idx , o in zip (indices , orders ):
70+ ph .assert_scalar_equals ("argsort" , int , idx , int (out [idx ]), o )
71+ else :
72+ idx_elements = dict (zip (indices , elements ))
73+ idx_orders = dict (zip (indices , orders ))
74+ element_orders = {}
75+ for e in set (elements ):
76+ element_orders [e ] = [
77+ idx_orders [idx ] for idx in indices if idx_elements [idx ] == e
78+ ]
79+ for idx , e in zip (indices , elements ):
80+ o = int (out [idx ])
81+ expected_orders = element_orders [e ]
82+ if len (expected_orders ) == 1 :
83+ expected_order = expected_orders [0 ]
84+ ph .assert_scalar_equals (
85+ "argsort" , int , idx , o , expected_order , ** kw
86+ )
87+ else :
88+ assert_scalar_in_set (
89+ "argsort" , int , idx , o , set (expected_orders ), ** kw
90+ )
5191
5292
5393# TODO: Test with signed zeros and NaNs (and ignore them somehow)
@@ -78,15 +118,15 @@ def test_sort(x, data):
78118 ph .assert_shape ("sort" , out .shape , x .shape , ** kw )
79119 axis = kw .get ("axis" , - 1 )
80120 axes = sh .normalise_axis (axis , x .ndim )
81- descending = kw .get ("descending" , False )
82121 scalar_type = dh .get_scalar_type (x .dtype )
83122 for indices in sh .axes_ndindex (x .shape , axes ):
84123 elements = [scalar_type (x [idx ]) for idx in indices ]
85- indices_order = sorted (
86- range (len (indices )), key = elements .__getitem__ , reverse = descending
124+ size = len (elements )
125+ orders = sorted (
126+ range (size ), key = elements .__getitem__ , reverse = kw .get ("descending" , False )
87127 )
88- x_indices = [ indices [ o ] for o in indices_order ]
89- for out_idx , x_idx in zip ( indices , x_indices ):
128+ for out_idx , o in zip ( indices , orders ):
129+ x_idx = indices [ o ]
90130 ph .assert_0d_equals (
91131 "sort" ,
92132 f"x[{ x_idx } ]" ,
0 commit comments