3737
3838"""
3939
40+ # pylint: disable=protected-access
41+
4042import operator
4143
4244import dpctl .tensor as dpt
45+ import dpctl .tensor ._tensor_impl as ti
46+ import dpctl .utils as dpu
4347import numpy
48+ from dpctl .tensor ._copy_utils import _nonzero_impl
49+ from dpctl .tensor ._indexing_functions import _get_indexing_mode
4450from dpctl .tensor ._numpy_helper import normalize_axis_index
4551
4652import dpnp
5561
5662__all__ = [
5763 "choose" ,
64+ "compress" ,
5865 "diag_indices" ,
5966 "diag_indices_from" ,
6067 "diagonal" ,
@@ -155,6 +162,157 @@ def choose(x1, choices, out=None, mode="raise"):
155162 return call_origin (numpy .choose , x1 , choices , out , mode )
156163
157164
165+ def _take_index (x , inds , axis , q , usm_type , out = None , mode = 0 ):
166+ # arg validation assumed done by caller
167+ x_sh = x .shape
168+ axis_end = axis + 1
169+ if 0 in x_sh [axis :axis_end ] and inds .size != 0 :
170+ raise IndexError ("cannot take non-empty indices from an empty axis" )
171+ res_sh = x_sh [:axis ] + inds .shape + x_sh [axis_end :]
172+
173+ if out is not None :
174+ out = dpnp .get_usm_ndarray (out )
175+
176+ if not out .flags .writable :
177+ raise ValueError ("provided `out` array is read-only" )
178+
179+ if out .shape != res_sh :
180+ raise ValueError (
181+ "The shape of input and output arrays are inconsistent. "
182+ f"Expected output shape is { res_sh } , got { out .shape } "
183+ )
184+
185+ if x .dtype != out .dtype :
186+ raise TypeError (
187+ f"Output array of type { x .dtype } is needed, " f"got { out .dtype } "
188+ )
189+
190+ if dpu .get_execution_queue ((q , out .sycl_queue )) is None :
191+ raise dpu .ExecutionPlacementError (
192+ "Input and output allocation queues are not compatible"
193+ )
194+
195+ if ti ._array_overlap (x , out ):
196+ # Allocate a temporary buffer to avoid memory overlapping.
197+ out = dpt .empty_like (out )
198+ else :
199+ out = dpt .empty (res_sh , dtype = x .dtype , usm_type = usm_type , sycl_queue = q )
200+
201+ _manager = dpu .SequentialOrderManager [q ]
202+ dep_evs = _manager .submitted_events
203+
204+ h_ev , take_ev = ti ._take (
205+ src = x ,
206+ ind = (inds ,),
207+ dst = out ,
208+ axis_start = axis ,
209+ mode = mode ,
210+ sycl_queue = q ,
211+ depends = dep_evs ,
212+ )
213+ _manager .add_event_pair (h_ev , take_ev )
214+
215+ return out
216+
217+
218+ def compress (condition , a , axis = None , out = None ):
219+ """
220+ Return selected slices of an array along given axis.
221+
222+ A slice of `a` is returned for each index along `axis` where `condition`
223+ is ``True``.
224+
225+ For full documentation refer to :obj:`numpy.choose`.
226+
227+ Parameters
228+ ----------
229+ condition : {array_like, dpnp.ndarray, usm_ndarray}
230+ Array that selects which entries to extract. If the length of
231+ `condition` is less than the size of `a` along `axis`, then
232+ the output is truncated to the length of `condition`.
233+ a : {dpnp.ndarray, usm_ndarray}
234+ Array to extract from.
235+ axis : {None, int}, optional
236+ Axis along which to extract slices. If ``None``, works over the
237+ flattened array.
238+ Default: ``None``.
239+ out : {None, dpnp.ndarray, usm_ndarray}, optional
240+ If provided, the result will be placed in this array. It should
241+ be of the appropriate shape and dtype.
242+ Default: ``None``.
243+
244+ Returns
245+ -------
246+ out : dpnp.ndarray
247+ A copy of the slices of `a` where `condition` is ``True``.
248+
249+ See also
250+ --------
251+ :obj:`dpnp.take` : Take elements from an array along an axis.
252+ :obj:`dpnp.choose` : Construct an array from an index array and a set of
253+ arrays to choose from.
254+ :obj:`dpnp.diag` : Extract a diagonal or construct a diagonal array.
255+ :obj:`dpnp.diagonal` : Return specified diagonals.
256+ :obj:`dpnp.select` : Return an array drawn from elements in `choicelist`,
257+ depending on conditions.
258+ :obj:`dpnp.ndarray.compress` : Equivalent method.
259+ :obj:`dpnp.extract` : Equivalent function when working on 1-D arrays.
260+
261+ Examples
262+ --------
263+ >>> import numpy as np
264+ >>> a = np.array([[1, 2], [3, 4], [5, 6]])
265+ >>> a
266+ array([[1, 2],
267+ [3, 4],
268+ [5, 6]])
269+ >>> np.compress([0, 1], a, axis=0)
270+ array([[3, 4]])
271+ >>> np.compress([False, True, True], a, axis=0)
272+ array([[3, 4],
273+ [5, 6]])
274+ >>> np.compress([False, True], a, axis=1)
275+ array([[2],
276+ [4],
277+ [6]])
278+
279+ Working on the flattened array does not return slices along an axis but
280+ selects elements.
281+
282+ >>> np.compress([False, True], a)
283+ array([2])
284+ """
285+
286+ dpnp .check_supported_arrays_type (a )
287+ if axis is None :
288+ if a .ndim != 1 :
289+ a = dpnp .ravel (a )
290+ axis = 0
291+ axis = normalize_axis_index (operator .index (axis ), a .ndim )
292+
293+ a_ary = dpnp .get_usm_ndarray (a )
294+ cond_ary = dpnp .as_usm_ndarray (
295+ condition ,
296+ dtype = dpnp .bool ,
297+ usm_type = a_ary .usm_type ,
298+ sycl_queue = a_ary .sycl_queue ,
299+ )
300+
301+ if not cond_ary .ndim == 1 :
302+ raise ValueError (
303+ "`condition` must be a 1-D array or un-nested sequence"
304+ )
305+
306+ res_usm_type , exec_q = get_usm_allocations ([a_ary , cond_ary ])
307+
308+ # _nonzero_impl synchronizes and returns a tuple of usm_ndarray indices
309+ inds = _nonzero_impl (cond_ary )
310+
311+ res = _take_index (a_ary , inds [0 ], axis , exec_q , res_usm_type , out = out )
312+
313+ return dpnp .get_result_array (res , out = out )
314+
315+
158316def diag_indices (n , ndim = 2 , device = None , usm_type = "device" , sycl_queue = None ):
159317 """
160318 Return the indices to access the main diagonal of an array.
@@ -1806,8 +1964,8 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
18061964
18071965 """
18081966
1809- if mode not in ( "wrap" , "clip" ):
1810- raise ValueError ( f"` mode` must be 'wrap' or 'clip', but got ` { mode } `." )
1967+ # sets mode to 0 for "wrap" and 1 for "clip", raises otherwise
1968+ mode = _get_indexing_mode ( mode )
18111969
18121970 usm_a = dpnp .get_usm_ndarray (a )
18131971 if not dpnp .is_supported_array_type (indices ):
@@ -1817,34 +1975,28 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
18171975 else :
18181976 usm_ind = dpnp .get_usm_ndarray (indices )
18191977
1978+ res_usm_type , exec_q = get_usm_allocations ([usm_a , usm_ind ])
1979+
18201980 a_ndim = a .ndim
18211981 if axis is None :
1822- res_shape = usm_ind .shape
1823-
18241982 if a_ndim > 1 :
1825- # dpt.take requires flattened input array
1983+ # flatten input array
18261984 usm_a = dpt .reshape (usm_a , - 1 )
1985+ axis = 0
18271986 elif a_ndim == 0 :
18281987 axis = normalize_axis_index (operator .index (axis ), 1 )
1829- res_shape = usm_ind .shape
18301988 else :
18311989 axis = normalize_axis_index (operator .index (axis ), a_ndim )
1832- a_sh = a .shape
1833- res_shape = a_sh [:axis ] + usm_ind .shape + a_sh [axis + 1 :]
1834-
1835- if usm_ind .ndim != 1 :
1836- # dpt.take supports only 1-D array of indices
1837- usm_ind = dpt .reshape (usm_ind , - 1 )
18381990
18391991 if not dpnp .issubdtype (usm_ind .dtype , dpnp .integer ):
18401992 # dpt.take supports only integer dtype for array of indices
18411993 usm_ind = dpt .astype (usm_ind , dpnp .intp , copy = False , casting = "safe" )
18421994
1843- usm_res = dpt .take (usm_a , usm_ind , axis = axis , mode = mode )
1995+ usm_res = _take_index (
1996+ usm_a , usm_ind , axis , exec_q , res_usm_type , out = out , mode = mode
1997+ )
18441998
1845- # need to reshape the result if shape of indices array was changed
1846- result = dpnp .reshape (usm_res , res_shape )
1847- return dpnp .get_result_array (result , out )
1999+ return dpnp .get_result_array (usm_res , out = out )
18482000
18492001
18502002def take_along_axis (a , indices , axis , mode = "wrap" ):
0 commit comments