File tree Expand file tree Collapse file tree 1 file changed +9
-0
lines changed Expand file tree Collapse file tree 1 file changed +9
-0
lines changed Original file line number Diff line number Diff line change @@ -43,6 +43,7 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype):
4343 q = get_queue_or_skip ()
4444 skip_if_dtype_not_supported (arg_dtype , q )
4545
46+ # test reduction for C-contiguous input
4647 m = dpt .ones (100 , dtype = arg_dtype )
4748 r = dpt .sum (m )
4849
@@ -55,12 +56,20 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype):
5556 assert r .dtype .kind == "f"
5657 elif m .dtype .kind == "c" :
5758 assert r .dtype .kind == "c"
59+
5860 assert dpt .all (r == 100 )
5961
62+ # test reduction for strided input
6063 m = dpt .ones (200 , dtype = arg_dtype )[:1 :- 2 ]
6164 r = dpt .sum (m )
6265 assert dpt .all (r == 99 )
6366
67+ # test reduction for strided input which can be simplified
68+ # to contiguous computation
69+ m = dpt .ones (100 , dtype = arg_dtype )
70+ r = dpt .sum (dpt .flip (m ))
71+ assert dpt .all (r == 100 )
72+
6473
6574@pytest .mark .parametrize ("arg_dtype" , _all_dtypes )
6675@pytest .mark .parametrize ("out_dtype" , _all_dtypes [1 :])
You can’t perform that action at this time.
0 commit comments