3939
4040import dpctl .tensor as dpt
4141import numpy
42- from dpctl .tensor ._numpy_helper import (
43- normalize_axis_index ,
44- normalize_axis_tuple ,
45- )
42+ from dpctl .tensor ._numpy_helper import normalize_axis_index
4643
4744import dpnp
4845
4946# pylint: disable=no-name-in-module
5047from .dpnp_algo import dpnp_correlate
51- from .dpnp_array import dpnp_array
5248from .dpnp_utils import call_origin , get_usm_allocations
5349from .dpnp_utils .dpnp_utils_reduction import dpnp_wrap_reduction_call
54- from .dpnp_utils .dpnp_utils_statistics import dpnp_cov
50+ from .dpnp_utils .dpnp_utils_statistics import dpnp_cov , dpnp_median
5551
5652__all__ = [
5753 "amax" ,
@@ -113,22 +109,6 @@ def _count_reduce_items(arr, axis, where=True):
113109 return items
114110
115111
116- def _flatten_array_along_axes (arr , axes_to_flatten ):
117- """Flatten an array along a specific set of axes."""
118-
119- axes_to_keep = (
120- axis for axis in range (arr .ndim ) if axis not in axes_to_flatten
121- )
122-
123- # Move the axes_to_flatten to the front
124- arr_moved = dpnp .moveaxis (arr , axes_to_flatten , range (len (axes_to_flatten )))
125-
126- new_shape = (- 1 ,) + tuple (arr .shape [axis ] for axis in axes_to_keep )
127- flattened_arr = arr_moved .reshape (new_shape )
128-
129- return flattened_arr
130-
131-
132112def _get_comparison_res_dt (a , _dtype , _out ):
133113 """Get a data type used by dpctl for result array in comparison function."""
134114
@@ -765,7 +745,7 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
765745 preserve the contents of the input array. Treat the input as undefined,
766746 but it will probably be fully or partially sorted.
767747 Default: ``False``.
768- keepdims : {None, bool} , optional
748+ keepdims : bool, optional
769749 If ``True``, the reduced axes (dimensions) are included in the result
770750 as singleton dimensions, so that the returned array remains
771751 compatible with the input array according to Array Broadcasting
@@ -775,7 +755,7 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
775755
776756 Returns
777757 -------
778- dpnp.median : dpnp.ndarray
758+ out : dpnp.ndarray
779759 A new array holding the result. If `a` has a floating-point data type,
780760 the returned array will have the same data type as `a`. If `a` has a
781761 boolean or integral data type, the returned array will have the
@@ -808,20 +788,20 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
808788 >>> np.median(a, axis=0)
809789 array([6.5, 4.5, 2.5])
810790 >>> np.median(a, axis=1)
811- array([7., 2.])
791+ array([7., 2.])
812792 >>> np.median(a, axis=(0, 1))
813793 array(3.5)
814794
815795 >>> m = np.median(a, axis=0)
816796 >>> out = np.zeros_like(m)
817797 >>> np.median(a, axis=0, out=m)
818- array([6.5, 4.5, 2.5])
798+ array([6.5, 4.5, 2.5])
819799 >>> m
820- array([6.5, 4.5, 2.5])
800+ array([6.5, 4.5, 2.5])
821801
822802 >>> b = a.copy()
823803 >>> np.median(b, axis=1, overwrite_input=True)
824- array([7., 2.])
804+ array([7., 2.])
825805 >>> assert not np.all(a==b)
826806 >>> b = a.copy()
827807 >>> np.median(b, axis=None, overwrite_input=True)
@@ -831,62 +811,9 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
831811 """
832812
833813 dpnp .check_supported_arrays_type (a )
834- a_ndim = a .ndim
835- a_shape = a .shape
836- _axis = range (a_ndim ) if axis is None else axis
837- _axis = normalize_axis_tuple (_axis , a_ndim )
838-
839- if isinstance (axis , (tuple , list )):
840- if len (axis ) == 1 :
841- axis = axis [0 ]
842- else :
843- # Need to flatten if `axis` is a sequence of axes since `dpnp.sort`
844- # only accepts integer `axis`
845- # Note that the output of _flatten_array_along_axes is not
846- # necessarily a view of the input since `reshape` is used there.
847- # If this is the case, using overwrite_input is meaningless
848- a = _flatten_array_along_axes (a , _axis )
849- axis = 0
850-
851- if overwrite_input :
852- if axis is None :
853- a_sorted = dpnp .ravel (a )
854- a_sorted .sort ()
855- else :
856- if isinstance (a , dpt .usm_ndarray ):
857- # dpnp.ndarray.sort only works with dpnp_array
858- a = dpnp_array ._create_from_usm_ndarray (a )
859- a .sort (axis = axis )
860- a_sorted = a
861- else :
862- a_sorted = dpnp .sort (a , axis = axis )
863-
864- if axis is None :
865- axis = 0
866- indexer = [slice (None )] * a_sorted .ndim
867- index , remainder = divmod (a_sorted .shape [axis ], 2 )
868- if remainder == 1 :
869- # index with slice to allow mean (below) to work
870- indexer [axis ] = slice (index , index + 1 )
871- else :
872- indexer [axis ] = slice (index - 1 , index + 1 )
873-
874- # Use `mean` in odd and even case to coerce data type and use `out` array
875- res = dpnp .mean (a_sorted [tuple (indexer )], axis = axis , out = out )
876- nan_mask = dpnp .isnan (a_sorted ).any (axis = axis )
877- if nan_mask .any ():
878- res [nan_mask ] = dpnp .nan
879-
880- if keepdims :
881- # We can't use dpnp.mean(..., keepdims) and dpnp.any(..., keepdims)
882- # above because of the reshape hack might have been used in
883- # `_flatten_array_along_axes` to handle cases when axis is a tuple.
884- res_shape = list (a_shape )
885- for i in _axis :
886- res_shape [i ] = 1
887- res = res .reshape (tuple (res_shape ))
888-
889- return res
814+ return dpnp_median (
815+ a , axis , out , overwrite_input , keepdims , ignore_nan = False
816+ )
890817
891818
892819def min (a , axis = None , out = None , keepdims = False , initial = None , where = True ):
0 commit comments