@@ -55,7 +55,7 @@ def _expected_largest_inds(inp, n, shift, k):
5555@pytest .mark .parametrize (
5656 "dtype" ,
5757 [
58- pytest . param ( "i1" , marks = pytest . mark . skip ( reason = "CPU bug" )) ,
58+ "i1" ,
5959 "u1" ,
6060 "i2" ,
6161 "u2" ,
@@ -74,8 +74,6 @@ def _expected_largest_inds(inp, n, shift, k):
7474def test_top_k_1d_largest (dtype , n ):
7575 q = get_queue_or_skip ()
7676 skip_if_dtype_not_supported (dtype , q )
77- if dtype == "i1" :
78- pytest .skip ()
7977
8078 shift , k = 734 , 5
8179 o = dpt .ones (n , dtype = dtype )
@@ -89,9 +87,9 @@ def test_top_k_1d_largest(dtype, n):
8987 assert s .values .shape == (k ,)
9088 assert s .values .dtype == inp .dtype
9189 assert s .indices .shape == (k ,)
92- assert dpt .all (s .indices == expected_inds )
9390 assert dpt .all (s .values == dpt .ones (k , dtype = dtype )), s .values
9491 assert dpt .all (s .values == inp [s .indices ]), s .indices
92+ assert dpt .all (s .indices == expected_inds ), (s .indices , expected_inds )
9593
9694
9795def _expected_smallest_inds (inp , n , shift , k ):
@@ -128,7 +126,7 @@ def _expected_smallest_inds(inp, n, shift, k):
128126@pytest .mark .parametrize (
129127 "dtype" ,
130128 [
131- pytest . param ( "i1" , marks = pytest . mark . skip ( reason = "CPU bug" )) ,
129+ "i1" ,
132130 "u1" ,
133131 "i2" ,
134132 "u2" ,
@@ -160,6 +158,158 @@ def test_top_k_1d_smallest(dtype, n):
160158 assert s .values .shape == (k ,)
161159 assert s .values .dtype == inp .dtype
162160 assert s .indices .shape == (k ,)
163- assert dpt .all (s .indices == expected_inds )
164161 assert dpt .all (s .values == dpt .zeros (k , dtype = dtype )), s .values
165162 assert dpt .all (s .values == inp [s .indices ]), s .indices
163+ assert dpt .all (s .indices == expected_inds ), (s .indices , expected_inds )
164+
165+
166+ @pytest .mark .parametrize (
167+ "dtype" ,
168+ [
169+ # skip short types to ensure that m*n can be represented
170+ # in the type
171+ "i4" ,
172+ "u4" ,
173+ "i8" ,
174+ "u8" ,
175+ "f2" ,
176+ "f4" ,
177+ "f8" ,
178+ "c8" ,
179+ "c16" ,
180+ ],
181+ )
182+ @pytest .mark .parametrize ("n" , [37 , 39 , 61 , 255 , 257 , 513 , 1021 , 8193 ])
183+ def test_top_k_2d_largest (dtype , n ):
184+ q = get_queue_or_skip ()
185+ skip_if_dtype_not_supported (dtype , q )
186+
187+ m , k = 8 , 3
188+ if dtype == "f2" and m * n > 2000 :
189+ pytest .skip (
190+ "f2 can not distinguish between large integers used in this test"
191+ )
192+
193+ x = dpt .reshape (dpt .arange (m * n , dtype = dtype ), (m , n ))
194+
195+ r = dpt .top_k (x , k , axis = 1 )
196+
197+ assert r .values .shape == (m , k )
198+ assert r .indices .shape == (m , k )
199+ expected_inds = dpt .reshape (dpt .arange (n , dtype = r .indices .dtype ), (1 , n ))[
200+ :, - k :
201+ ]
202+ assert expected_inds .shape == (1 , k )
203+ assert dpt .all (
204+ dpt .sort (r .indices , axis = 1 ) == dpt .sort (expected_inds , axis = 1 )
205+ ), (r .indices , expected_inds )
206+ expected_vals = x [:, - k :]
207+ assert dpt .all (
208+ dpt .sort (r .values , axis = 1 ) == dpt .sort (expected_vals , axis = 1 )
209+ )
210+
211+
212+ @pytest .mark .parametrize (
213+ "dtype" ,
214+ [
215+ # skip short types to ensure that m*n can be represented
216+ # in the type
217+ "i4" ,
218+ "u4" ,
219+ "i8" ,
220+ "u8" ,
221+ "f2" ,
222+ "f4" ,
223+ "f8" ,
224+ "c8" ,
225+ "c16" ,
226+ ],
227+ )
228+ @pytest .mark .parametrize ("n" , [37 , 39 , 61 , 255 , 257 , 513 , 1021 , 8193 ])
229+ def test_top_k_2d_smallest (dtype , n ):
230+ q = get_queue_or_skip ()
231+ skip_if_dtype_not_supported (dtype , q )
232+
233+ m , k = 8 , 3
234+ if dtype == "f2" and m * n > 2000 :
235+ pytest .skip (
236+ "f2 can not distinguish between large integers used in this test"
237+ )
238+
239+ x = dpt .reshape (dpt .arange (m * n , dtype = dtype ), (m , n ))
240+
241+ r = dpt .top_k (x , k , axis = 1 , mode = "smallest" )
242+
243+ assert r .values .shape == (m , k )
244+ assert r .indices .shape == (m , k )
245+ expected_inds = dpt .reshape (dpt .arange (n , dtype = r .indices .dtype ), (1 , n ))[
246+ :, :k
247+ ]
248+ assert dpt .all (
249+ dpt .sort (r .indices , axis = 1 ) == dpt .sort (expected_inds , axis = 1 )
250+ )
251+ assert dpt .all (dpt .sort (r .values , axis = 1 ) == dpt .sort (x [:, :k ], axis = 1 ))
252+
253+
254+ def test_top_k_0d ():
255+ get_queue_or_skip ()
256+
257+ a = dpt .ones (tuple (), dtype = "i4" )
258+ assert a .ndim == 0
259+ assert a .size == 1
260+
261+ r = dpt .top_k (a , 1 )
262+ assert r .values == a
263+ assert r .indices == dpt .zeros_like (a , dtype = r .indices .dtype )
264+
265+
266+ def test_top_k_noncontig ():
267+ get_queue_or_skip ()
268+
269+ a = dpt .arange (256 , dtype = dpt .int32 )[::2 ]
270+ r = dpt .top_k (a , 3 )
271+
272+ assert dpt .all (dpt .sort (r .values ) == dpt .asarray ([250 , 252 , 254 ])), r .values
273+ assert dpt .all (
274+ dpt .sort (r .indices ) == dpt .asarray ([125 , 126 , 127 ])
275+ ), r .indices
276+
277+
278+ def test_top_k_axis0 ():
279+ get_queue_or_skip ()
280+
281+ m , n , k = 128 , 8 , 3
282+ x = dpt .reshape (dpt .arange (m * n , dtype = dpt .int32 ), (m , n ))
283+
284+ r = dpt .top_k (x , k , axis = 0 , mode = "smallest" )
285+ assert r .values .shape == (k , n )
286+ assert r .indices .shape == (k , n )
287+ expected_inds = dpt .reshape (dpt .arange (m , dtype = r .indices .dtype ), (m , 1 ))[
288+ :k , :
289+ ]
290+ assert dpt .all (
291+ dpt .sort (r .indices , axis = 0 ) == dpt .sort (expected_inds , axis = 0 )
292+ )
293+ assert dpt .all (dpt .sort (r .values , axis = 0 ) == dpt .sort (x [:k , :], axis = 0 ))
294+
295+
296+ def test_top_k_validation ():
297+ get_queue_or_skip ()
298+ x = dpt .ones (10 , dtype = dpt .int64 )
299+ with pytest .raises (ValueError ):
300+ # k must be positive
301+ dpt .top_k (x , - 1 )
302+ with pytest .raises (TypeError ):
303+ # argument should be usm_ndarray
304+ dpt .top_k (list (), 2 )
305+ x2 = dpt .reshape (x , (2 , 5 ))
306+ with pytest .raises (ValueError ):
307+ # k must not exceed array dimension
308+ # along specified axis
309+ dpt .top_k (x2 , 100 , axis = 1 )
310+ with pytest .raises (ValueError ):
311+ # for 0d arrays, k must be 1
312+ dpt .top_k (x [0 ], 2 )
313+ with pytest .raises (ValueError ):
314+ # mode must be "largest", or "smallest"
315+ dpt .top_k (x , 2 , mode = "invalid" )
0 commit comments