@@ -322,48 +322,68 @@ def test_choose():
322322 assert_array_equal (expected , result )
323323
324324
325- @pytest .mark .parametrize ("arr_dtype" , get_all_dtypes (no_bool = True ))
326- @pytest .mark .parametrize ("offset" , [0 , 1 ], ids = ["0" , "1" ])
327- @pytest .mark .parametrize (
328- "array" ,
329- [
330- [[0 , 0 ], [0 , 0 ]],
331- [[1 , 2 ], [1 , 2 ]],
332- [[1 , 2 ], [3 , 4 ]],
333- [[0 , 1 , 2 ], [3 , 4 , 5 ], [6 , 7 , 8 ]],
334- [[0 , 1 , 2 , 3 , 4 ], [5 , 6 , 7 , 8 , 9 ]],
335- [[[1 , 2 ], [3 , 4 ]], [[1 , 2 ], [2 , 1 ]], [[1 , 3 ], [3 , 1 ]]],
336- [
337- [[[1 , 2 ], [3 , 4 ]], [[1 , 2 ], [2 , 1 ]]],
338- [[[1 , 3 ], [3 , 1 ]], [[0 , 1 ], [1 , 3 ]]],
339- ],
340- [
341- [[[1 , 2 , 3 ], [3 , 4 , 5 ]], [[1 , 2 , 3 ], [2 , 1 , 0 ]]],
342- [[[1 , 3 , 5 ], [3 , 1 , 0 ]], [[0 , 1 , 2 ], [1 , 3 , 4 ]]],
325+ class TestDiagonal :
326+ @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_bool = True ))
327+ @pytest .mark .parametrize ("offset" , [- 3 , - 1 , 0 , 1 , 3 ])
328+ @pytest .mark .parametrize (
329+ "shape" ,
330+ [(2 , 2 ), (3 , 3 ), (2 , 5 ), (3 , 2 , 2 ), (2 , 2 , 2 , 2 ), (2 , 2 , 2 , 3 )],
331+ ids = [
332+ "(2,2)" ,
333+ "(3,3)" ,
334+ "(2,5)" ,
335+ "(3,2,2)" ,
336+ "(2,2,2,2)" ,
337+ "(2,2,2,3)" ,
343338 ],
339+ )
340+ def test_diagonal_offset (self , shape , dtype , offset ):
341+ a = numpy .arange (numpy .prod (shape ), dtype = dtype ).reshape (shape )
342+ a_dp = dpnp .array (a )
343+ expected = numpy .diagonal (a , offset )
344+ result = dpnp .diagonal (a_dp , offset )
345+ assert_array_equal (expected , result )
346+
347+ @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_bool = True ))
348+ @pytest .mark .parametrize (
349+ "shape, axis_pairs" ,
344350 [
345- [[[1 , 2 , 3 ], [4 , 5 , 6 ]], [[7 , 8 , 9 ], [10 , 11 , 12 ]]],
346- [[[13 , 14 , 15 ], [16 , 17 , 18 ]], [[19 , 20 , 21 ], [22 , 23 , 24 ]]],
351+ ((3 , 4 ), [(0 , 1 ), (1 , 0 )]),
352+ ((3 , 4 , 5 ), [(0 , 1 ), (1 , 2 ), (0 , 2 )]),
353+ ((4 , 3 , 5 , 2 ), [(0 , 1 ), (1 , 2 ), (2 , 3 ), (0 , 3 )]),
347354 ],
348- ],
349- ids = [
350- "[[0, 0], [0, 0]]" ,
351- "[[1, 2], [1, 2]]" ,
352- "[[1, 2], [3, 4]]" ,
353- "[[0, 1, 2], [3, 4, 5], [6, 7, 8]]" ,
354- "[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]" ,
355- "[[[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]]" ,
356- "[[[[1, 2], [3, 4]], [[1, 2], [2, 1]]], [[[1, 3], [3, 1]], [[0, 1], [1, 3]]]]" ,
357- "[[[[1, 2, 3], [3, 4, 5]], [[1, 2, 3], [2, 1, 0]]], [[[1, 3, 5], [3, 1, 0]], [[0, 1, 2], [1, 3, 4]]]]" ,
358- "[[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], [[[13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24]]]]" ,
359- ],
360- )
361- def test_diagonal (array , arr_dtype , offset ):
362- a = numpy .array (array , dtype = arr_dtype )
363- ia = dpnp .array (a )
364- expected = numpy .diagonal (a , offset )
365- result = dpnp .diagonal (ia , offset )
366- assert_array_equal (expected , result )
355+ )
356+ def test_diagonal_axes (self , shape , axis_pairs , dtype ):
357+ a = numpy .arange (numpy .prod (shape ), dtype = dtype ).reshape (shape )
358+ a_dp = dpnp .array (a )
359+ for axis1 , axis2 in axis_pairs :
360+ expected = numpy .diagonal (a , axis1 = axis1 , axis2 = axis2 )
361+ result = dpnp .diagonal (a_dp , axis1 = axis1 , axis2 = axis2 )
362+ assert_array_equal (expected , result )
363+
364+ def test_diagonal_errors (self ):
365+ a = dpnp .arange (12 ).reshape (3 , 4 )
366+
367+ # unsupported type
368+ a_np = dpnp .asnumpy (a )
369+ assert_raises (TypeError , dpnp .diagonal , a_np )
370+
371+ # a.ndim < 2
372+ a_ndim_1 = a .flatten ()
373+ assert_raises (ValueError , dpnp .diagonal , a_ndim_1 )
374+
375+ # unsupported type `offset`
376+ assert_raises (TypeError , dpnp .diagonal , a , offset = 1.0 )
377+ assert_raises (TypeError , dpnp .diagonal , a , offset = [0 ])
378+
379+ # axes are out of bounds
380+ assert_raises (numpy .AxisError , a .diagonal , axis1 = 0 , axis2 = 5 )
381+ assert_raises (numpy .AxisError , a .diagonal , axis1 = 5 , axis2 = 0 )
382+ assert_raises (numpy .AxisError , a .diagonal , axis1 = 5 , axis2 = 5 )
383+
384+ # same axes
385+ assert_raises (ValueError , a .diagonal , axis1 = 1 , axis2 = 1 )
386+ assert_raises (ValueError , a .diagonal , axis1 = 1 , axis2 = - 1 )
367387
368388
369389@pytest .mark .parametrize ("arr_dtype" , get_all_dtypes ())
0 commit comments