@@ -47,7 +47,17 @@ def _lerp(a, b, *, t, dtype, out=None):
4747
4848
4949def quantile_or_topk (
50- array , inv_idx , * , q = None , k = None , axis , skipna , group_idx , dtype = None , out = None
50+ array ,
51+ inv_idx ,
52+ * ,
53+ q = None ,
54+ k = None ,
55+ axis ,
56+ skipna ,
57+ group_idx ,
58+ dtype = None ,
59+ out = None ,
60+ fill_value = None ,
5161):
5262 assert q or k
5363
@@ -81,9 +91,8 @@ def quantile_or_topk(
8191
8292 param = q or k
8393 if k is not None :
84- assert k > 0
8594 is_scalar_param = False
86- param = np .arange (k )
95+ param = np .arange (abs ( k ) )
8796 else :
8897 is_scalar_param = is_scalar (q )
8998 param = np .atleast_1d (param )
@@ -111,10 +120,10 @@ def quantile_or_topk(
111120 kth = np .unique (np .concatenate ([lo_ .reshape (- 1 ), hi_ .reshape (- 1 )]))
112121
113122 else :
114- virtual_index = ( actual_sizes - k ) + inv_idx [: - 1 ]
123+ virtual_index = inv_idx [: - 1 ] + (( actual_sizes - k ) if k > 0 else abs ( k ) - 1 )
115124 kth = np .unique (virtual_index )
116125 kth = kth [kth > 0 ]
117- k_offset = np . arange ( k ). reshape ((k ,) + (1 ,) * virtual_index .ndim )
126+ k_offset = param . reshape ((abs ( k ) ,) + (1 ,) * virtual_index .ndim )
118127 lo_ = k_offset + virtual_index [np .newaxis , ...]
119128
120129 # partition the complex array in-place
@@ -137,15 +146,12 @@ def quantile_or_topk(
137146 gamma = np .broadcast_to (virtual_index , idxshape ) - lo_
138147 result = _lerp (loval , hival , t = gamma , out = out , dtype = dtype )
139148 else :
140- import ipdb
141-
142- ipdb .set_trace ()
143149 result = loval
144- result [lo_ < 0 ] = np . nan
150+ result [lo_ < 0 ] = fill_value
145151 if not skipna and np .any (nanmask ):
146- result [..., nanmask ] = np . nan
152+ result [..., nanmask ] = fill_value
147153 if k is not None :
148- result = result .astype (array . dtype , copy = False )
154+ result = result .astype (dtype , copy = False )
149155 np .copyto (out , result )
150156 return result
151157
@@ -175,9 +181,10 @@ def _np_grouped_op(
175181 if not q and not k :
176182 out = np .full (array .shape [:- 1 ] + (size ,), fill_value = fill_value , dtype = dtype )
177183 else :
178- nq = len (np .atleast_1d (q )) if q is not None else k
184+ nq = len (np .atleast_1d (q )) if q is not None else abs ( k )
179185 out = np .full ((nq ,) + array .shape [:- 1 ] + (size ,), fill_value = fill_value , dtype = dtype )
180186 kwargs ["group_idx" ] = group_idx
187+ kwargs ["fill_value" ] = fill_value
181188
182189 if (len (uniques ) == size ) and (uniques == np .arange (size , like = array )).all ():
183190 # The previous version of this if condition
0 commit comments