@@ -625,7 +625,7 @@ def test_put_0d_val(data_dt):
625625 skip_if_dtype_not_supported (data_dt , q )
626626
627627 x = dpt .arange (5 , dtype = data_dt , sycl_queue = q )
628- ind = dpt .asarray ([0 ], dtype = np . intp , sycl_queue = q )
628+ ind = dpt .asarray ([0 ], dtype = "i8" , sycl_queue = q )
629629 val = dpt .asarray (2 , dtype = x .dtype , sycl_queue = q )
630630 x [ind ] = val
631631 assert_array_equal (np .asarray (2 , dtype = data_dt ), dpt .asnumpy (x [0 ]))
@@ -644,7 +644,7 @@ def test_take_0d_data(data_dt):
644644 skip_if_dtype_not_supported (data_dt , q )
645645
646646 x = dpt .asarray (0 , dtype = data_dt , sycl_queue = q )
647- ind = dpt .arange (5 , dtype = np . intp , sycl_queue = q )
647+ ind = dpt .arange (5 , dtype = "i8" , sycl_queue = q )
648648
649649 y = dpt .take (x , ind )
650650 assert (
@@ -662,7 +662,7 @@ def test_put_0d_data(data_dt):
662662 skip_if_dtype_not_supported (data_dt , q )
663663
664664 x = dpt .asarray (0 , dtype = data_dt , sycl_queue = q )
665- ind = dpt .arange (5 , dtype = np . intp , sycl_queue = q )
665+ ind = dpt .arange (5 , dtype = "i8" , sycl_queue = q )
666666 val = dpt .asarray (2 , dtype = data_dt , sycl_queue = q )
667667
668668 dpt .put (x , ind , val , axis = 0 )
@@ -710,7 +710,7 @@ def test_take_strided_1d_source(data_dt):
710710 skip_if_dtype_not_supported (data_dt , q )
711711
712712 x = dpt .arange (27 , dtype = data_dt , sycl_queue = q )
713- ind = dpt .arange (4 , 9 , dtype = np . intp , sycl_queue = q )
713+ ind = dpt .arange (4 , 9 , dtype = "i8" , sycl_queue = q )
714714
715715 x_np = dpt .asnumpy (x )
716716 ind_np = dpt .asnumpy (ind )
@@ -748,7 +748,7 @@ def test_take_strided(data_dt, order):
748748 skip_if_dtype_not_supported (data_dt , q )
749749
750750 x = dpt .reshape (_make_3d (data_dt , q ), (9 , 3 ), order = order )
751- ind = dpt .arange (2 , dtype = np . intp , sycl_queue = q )
751+ ind = dpt .arange (2 , dtype = "i8" , sycl_queue = q )
752752
753753 x_np = dpt .asnumpy (x )
754754 ind_np = dpt .asnumpy (ind )
@@ -781,7 +781,7 @@ def test_take_strided_1d_indices(ind_dt):
781781 ind = dpt .arange (12 , 24 , dtype = ind_dt , sycl_queue = q )
782782
783783 x_np = dpt .asnumpy (x )
784- ind_np = dpt .asnumpy (ind ).astype (np . intp )
784+ ind_np = dpt .asnumpy (ind ).astype ("i8" )
785785
786786 for s in (
787787 slice (None , None , 2 ),
@@ -820,7 +820,7 @@ def test_take_strided_indices(ind_dt, order):
820820 )
821821
822822 x_np = dpt .asnumpy (x )
823- ind_np = dpt .asnumpy (ind ).astype (np . intp )
823+ ind_np = dpt .asnumpy (ind ).astype ("i8" )
824824
825825 for s in (
826826 slice (None , None , 2 ),
@@ -845,7 +845,7 @@ def test_put_strided_1d_destination(data_dt, order):
845845 skip_if_dtype_not_supported (data_dt , q )
846846
847847 x = dpt .arange (27 , dtype = data_dt , sycl_queue = q )
848- ind = dpt .arange (4 , 9 , dtype = np . intp , sycl_queue = q )
848+ ind = dpt .arange (4 , 9 , dtype = "i8" , sycl_queue = q )
849849 val = dpt .asarray (9 , dtype = x .dtype , sycl_queue = q )
850850
851851 x_np = dpt .asnumpy (x )
@@ -875,7 +875,7 @@ def test_put_strided_destination(data_dt, order):
875875 skip_if_dtype_not_supported (data_dt , q )
876876
877877 x = dpt .reshape (_make_3d (data_dt , q ), (9 , 3 ), order = order )
878- ind = dpt .arange (2 , dtype = np . intp , sycl_queue = q )
878+ ind = dpt .arange (2 , dtype = "i8" , sycl_queue = q )
879879 val = dpt .asarray (9 , dtype = x .dtype , sycl_queue = q )
880880
881881 x_np = dpt .asnumpy (x )
@@ -924,7 +924,7 @@ def test_put_strided_1d_indices(ind_dt):
924924 val = dpt .asarray (- 1 , dtype = x .dtype , sycl_queue = q )
925925
926926 x_np = dpt .asnumpy (x )
927- ind_np = dpt .asnumpy (ind ).astype (np . intp )
927+ ind_np = dpt .asnumpy (ind ).astype ("i8" )
928928 val_np = dpt .asnumpy (val )
929929
930930 for s in (
@@ -955,7 +955,7 @@ def test_put_strided_indices(ind_dt, order):
955955 val = dpt .asarray (- 1 , sycl_queue = q , dtype = x .dtype )
956956
957957 x_np = dpt .asnumpy (x )
958- ind_np = dpt .asnumpy (ind ).astype (np . intp )
958+ ind_np = dpt .asnumpy (ind ).astype ("i8" )
959959 val_np = dpt .asnumpy (val )
960960
961961 for s in (
@@ -982,15 +982,15 @@ def test_integer_indexing_modes():
982982 x_np = dpt .asnumpy (x )
983983
984984 # wrapping negative indices
985- ind = dpt .asarray ([- 4 , - 3 , 0 , 2 , 4 ], dtype = np . intp , sycl_queue = q )
985+ ind = dpt .asarray ([- 4 , - 3 , 0 , 2 , 4 ], dtype = "i8" , sycl_queue = q )
986986
987987 res = dpt .take (x , ind , mode = "wrap" )
988988 expected_arr = np .take (x_np , dpt .asnumpy (ind ), mode = "raise" )
989989
990990 assert (dpt .asnumpy (res ) == expected_arr ).all ()
991991
992992 # clipping to 0 (disabling negative indices)
993- ind = dpt .asarray ([- 6 , - 3 , 0 , 2 , 6 ], dtype = np . intp , sycl_queue = q )
993+ ind = dpt .asarray ([- 6 , - 3 , 0 , 2 , 6 ], dtype = "i8" , sycl_queue = q )
994994
995995 res = dpt .take (x , ind , mode = "clip" )
996996 expected_arr = np .take (x_np , dpt .asnumpy (ind ), mode = "clip" )
@@ -1002,7 +1002,7 @@ def test_take_arg_validation():
10021002 q = get_queue_or_skip ()
10031003
10041004 x = dpt .arange (4 , dtype = "i4" , sycl_queue = q )
1005- ind0 = dpt .arange (4 , dtype = np . intp , sycl_queue = q )
1005+ ind0 = dpt .arange (4 , dtype = "i8" , sycl_queue = q )
10061006 ind1 = dpt .arange (2.0 , dtype = "f" , sycl_queue = q )
10071007
10081008 with pytest .raises (TypeError ):
@@ -1034,7 +1034,7 @@ def test_put_arg_validation():
10341034 q = get_queue_or_skip ()
10351035
10361036 x = dpt .arange (4 , dtype = "i4" , sycl_queue = q )
1037- ind0 = dpt .arange (4 , dtype = np . intp , sycl_queue = q )
1037+ ind0 = dpt .arange (4 , dtype = "i8" , sycl_queue = q )
10381038 ind1 = dpt .arange (2.0 , dtype = "f" , sycl_queue = q )
10391039 val = dpt .asarray (2 , dtype = x .dtype , sycl_queue = q )
10401040
@@ -1890,3 +1890,69 @@ def test_put_along_axis_uint64_indices():
18901890 dpt .put_along_axis (x , inds , dpt .asarray (2 , dtype = x .dtype ), axis = 1 )
18911891 expected = dpt .tile (dpt .asarray ([0 , 2 ], dtype = "i4" ), (2 , 5 ))
18921892 assert dpt .all (expected == x )
1893+
1894+
1895+ @pytest .mark .parametrize ("data_dt" , _all_dtypes )
1896+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
1897+ def test_take_out (data_dt , order ):
1898+ q = get_queue_or_skip ()
1899+ skip_if_dtype_not_supported (data_dt , q )
1900+
1901+ axis = 0
1902+ x = dpt .reshape (_make_3d (data_dt , q ), (9 , 3 ), order = order )
1903+ ind = dpt .arange (2 , dtype = "i8" , sycl_queue = q )
1904+ out_sh = x .shape [:axis ] + ind .shape + x .shape [axis + 1 :]
1905+ out = dpt .empty (out_sh , dtype = data_dt , sycl_queue = q )
1906+
1907+ expected = dpt .take (x , ind , axis = axis )
1908+
1909+ dpt .take (x , ind , axis = axis , out = out )
1910+
1911+ assert dpt .all (out == expected )
1912+
1913+
1914+ @pytest .mark .parametrize ("data_dt" , _all_dtypes )
1915+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
1916+ def test_take_out_overlap (data_dt , order ):
1917+ q = get_queue_or_skip ()
1918+ skip_if_dtype_not_supported (data_dt , q )
1919+
1920+ axis = 0
1921+ x = dpt .reshape (_make_3d (data_dt , q ), (9 , 3 ), order = order )
1922+ ind = dpt .arange (2 , dtype = "i8" , sycl_queue = q )
1923+ out = x [x .shape [axis ] - ind .shape [axis ] : x .shape [axis ], :]
1924+
1925+ expected = dpt .take (x , ind , axis = axis )
1926+
1927+ dpt .take (x , ind , axis = axis , out = out )
1928+
1929+ assert dpt .all (out == expected )
1930+ assert dpt .all (x [x .shape [0 ] - ind .shape [0 ] : x .shape [0 ], :] == out )
1931+
1932+
1933+ def test_take_out_errors ():
1934+ q1 = get_queue_or_skip ()
1935+ q2 = get_queue_or_skip ()
1936+
1937+ x = dpt .arange (10 , dtype = "i4" , sycl_queue = q1 )
1938+ ind = dpt .arange (2 , dtype = "i4" , sycl_queue = q1 )
1939+
1940+ with pytest .raises (TypeError ):
1941+ dpt .take (x , ind , out = dict ())
1942+
1943+ out_read_only = dpt .empty (ind .shape , dtype = x .dtype , sycl_queue = q1 )
1944+ out_read_only .flags ["W" ] = False
1945+ with pytest .raises (ValueError ):
1946+ dpt .take (x , ind , out = out_read_only )
1947+
1948+ out_bad_shape = dpt .empty (0 , dtype = x .dtype , sycl_queue = q1 )
1949+ with pytest .raises (ValueError ):
1950+ dpt .take (x , ind , out = out_bad_shape )
1951+
1952+ out_bad_dt = dpt .empty (ind .shape , dtype = "i8" , sycl_queue = q1 )
1953+ with pytest .raises (ValueError ):
1954+ dpt .take (x , ind , out = out_bad_dt )
1955+
1956+ out_bad_q = dpt .empty (ind .shape , dtype = x .dtype , sycl_queue = q2 )
1957+ with pytest .raises (dpctl .utils .ExecutionPlacementError ):
1958+ dpt .take (x , ind , out = out_bad_q )
0 commit comments