@@ -1890,3 +1890,75 @@ def test_put_along_axis_uint64_indices():
18901890 dpt .put_along_axis (x , inds , dpt .asarray (2 , dtype = x .dtype ), axis = 1 )
18911891 expected = dpt .tile (dpt .asarray ([0 , 2 ], dtype = "i4" ), (2 , 5 ))
18921892 assert dpt .all (expected == x )
1893+
1894+
1895+ @pytest .mark .parametrize (
1896+ "data_dt" ,
1897+ _all_dtypes ,
1898+ )
1899+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
1900+ def test_take_out (data_dt , order ):
1901+ q = get_queue_or_skip ()
1902+ skip_if_dtype_not_supported (data_dt , q )
1903+
1904+ axis = 0
1905+ x = dpt .reshape (_make_3d (data_dt , q ), (9 , 3 ), order = order )
1906+ ind = dpt .arange (2 , dtype = "i8" , sycl_queue = q )
1907+ out_sh = x .shape [:axis ] + ind .shape + x .shape [axis + 1 :]
1908+ out = dpt .empty (out_sh , dtype = data_dt , sycl_queue = q )
1909+
1910+ expected = dpt .take (x , ind , axis = axis )
1911+
1912+ dpt .take (x , ind , axis = axis , out = out )
1913+
1914+ assert dpt .all (out == expected )
1915+
1916+
1917+ @pytest .mark .parametrize (
1918+ "data_dt" ,
1919+ _all_dtypes ,
1920+ )
1921+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
1922+ def test_take_out_overlap (data_dt , order ):
1923+ q = get_queue_or_skip ()
1924+ skip_if_dtype_not_supported (data_dt , q )
1925+
1926+ axis = 0
1927+ x = dpt .reshape (_make_3d (data_dt , q ), (9 , 3 ), order = order )
1928+ ind = dpt .arange (2 , dtype = "i8" , sycl_queue = q )
1929+ out = x [x .shape [axis ] - ind .shape [axis ] : x .shape [axis ], :]
1930+
1931+ expected = dpt .take (x , ind , axis = axis )
1932+
1933+ dpt .take (x , ind , axis = axis , out = out )
1934+
1935+ assert dpt .all (out == expected )
1936+ assert dpt .all (x [x .shape [0 ] - ind .shape [0 ] : x .shape [0 ], :] == out )
1937+
1938+
1939+ def test_take_out_errors ():
1940+ q1 = get_queue_or_skip ()
1941+ q2 = get_queue_or_skip ()
1942+
1943+ x = dpt .arange (10 , dtype = "i4" , sycl_queue = q1 )
1944+ ind = dpt .arange (2 , dtype = "i4" , sycl_queue = q1 )
1945+
1946+ with pytest .raises (TypeError ):
1947+ dpt .take (x , ind , out = dict ())
1948+
1949+ out_read_only = dpt .empty (ind .shape , dtype = x .dtype , sycl_queue = q1 )
1950+ out_read_only .flags ["W" ] = False
1951+ with pytest .raises (ValueError ):
1952+ dpt .take (x , ind , out = out_read_only )
1953+
1954+ out_bad_shape = dpt .empty (0 , dtype = x .dtype , sycl_queue = q1 )
1955+ with pytest .raises (ValueError ):
1956+ dpt .take (x , ind , out = out_bad_shape )
1957+
1958+ out_bad_dt = dpt .empty (ind .shape , dtype = "i8" , sycl_queue = q1 )
1959+ with pytest .raises (ValueError ):
1960+ dpt .take (x , ind , out = out_bad_dt )
1961+
1962+ out_bad_q = dpt .empty (ind .shape , dtype = x .dtype , sycl_queue = q2 )
1963+ with pytest .raises (dpctl .utils .ExecutionPlacementError ):
1964+ dpt .take (x , ind , out = out_bad_q )
0 commit comments