@@ -25,25 +25,28 @@ def arange(
2525 xp ,
2626 dtype : Optional [Dtype ] = None ,
2727 device : Optional [Device ] = None ,
28+ ** kwargs
2829) -> ndarray :
2930 _check_device (xp , device )
30- return xp .arange (start , stop = stop , step = step , dtype = dtype )
31+ return xp .arange (start , stop = stop , step = step , dtype = dtype , ** kwargs )
3132
3233def empty (
3334 shape : Union [int , Tuple [int , ...]],
3435 xp ,
3536 * ,
3637 dtype : Optional [Dtype ] = None ,
3738 device : Optional [Device ] = None ,
39+ ** kwargs
3840) -> ndarray :
3941 _check_device (xp , device )
40- return xp .empty (shape , dtype = dtype )
42+ return xp .empty (shape , dtype = dtype , ** kwargs )
4143
4244def empty_like (
43- x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
45+ x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None ,
46+ ** kwargs
4447) -> ndarray :
4548 _check_device (xp , device )
46- return xp .empty_like (x , dtype = dtype )
49+ return xp .empty_like (x , dtype = dtype , ** kwargs )
4750
4851def eye (
4952 n_rows : int ,
@@ -54,9 +57,10 @@ def eye(
5457 k : int = 0 ,
5558 dtype : Optional [Dtype ] = None ,
5659 device : Optional [Device ] = None ,
60+ ** kwargs ,
5761) -> ndarray :
5862 _check_device (xp , device )
59- return xp .eye (n_rows , M = n_cols , k = k , dtype = dtype )
63+ return xp .eye (n_rows , M = n_cols , k = k , dtype = dtype , ** kwargs )
6064
6165def full (
6266 shape : Union [int , Tuple [int , ...]],
@@ -65,9 +69,10 @@ def full(
6569 * ,
6670 dtype : Optional [Dtype ] = None ,
6771 device : Optional [Device ] = None ,
72+ ** kwargs ,
6873) -> ndarray :
6974 _check_device (xp , device )
70- return xp .full (shape , fill_value , dtype = dtype )
75+ return xp .full (shape , fill_value , dtype = dtype , ** kwargs )
7176
7277def full_like (
7378 x : ndarray ,
@@ -77,9 +82,10 @@ def full_like(
7782 xp ,
7883 dtype : Optional [Dtype ] = None ,
7984 device : Optional [Device ] = None ,
85+ ** kwargs ,
8086) -> ndarray :
8187 _check_device (xp , device )
82- return xp .full_like (x , fill_value , dtype = dtype )
88+ return xp .full_like (x , fill_value , dtype = dtype , ** kwargs )
8389
8490def linspace (
8591 start : Union [int , float ],
@@ -91,41 +97,46 @@ def linspace(
9197 dtype : Optional [Dtype ] = None ,
9298 device : Optional [Device ] = None ,
9399 endpoint : bool = True ,
100+ ** kwargs ,
94101) -> ndarray :
95102 _check_device (xp , device )
96- return xp .linspace (start , stop , num , dtype = dtype , endpoint = endpoint )
103+ return xp .linspace (start , stop , num , dtype = dtype , endpoint = endpoint , ** kwargs )
97104
98105def ones (
99106 shape : Union [int , Tuple [int , ...]],
100107 xp ,
101108 * ,
102109 dtype : Optional [Dtype ] = None ,
103110 device : Optional [Device ] = None ,
111+ ** kwargs ,
104112) -> ndarray :
105113 _check_device (xp , device )
106- return xp .ones (shape , dtype = dtype )
114+ return xp .ones (shape , dtype = dtype , ** kwargs )
107115
108116def ones_like (
109- x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
117+ x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None ,
118+ ** kwargs ,
110119) -> ndarray :
111120 _check_device (xp , device )
112- return xp .ones_like (x , dtype = dtype )
121+ return xp .ones_like (x , dtype = dtype , ** kwargs )
113122
114123def zeros (
115124 shape : Union [int , Tuple [int , ...]],
116125 xp ,
117126 * ,
118127 dtype : Optional [Dtype ] = None ,
119128 device : Optional [Device ] = None ,
129+ ** kwargs ,
120130) -> ndarray :
121131 _check_device (xp , device )
122- return xp .zeros (shape , dtype = dtype )
132+ return xp .zeros (shape , dtype = dtype , ** kwargs )
123133
124134def zeros_like (
125- x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
135+ x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None ,
136+ ** kwargs ,
126137) -> ndarray :
127138 _check_device (xp , device )
128- return xp .zeros_like (x , dtype = dtype )
139+ return xp .zeros_like (x , dtype = dtype , ** kwargs )
129140
130141# np.unique() is split into four functions in the array API:
131142# unique_all, unique_counts, unique_inverse, and unique_values (this is done
@@ -219,8 +230,9 @@ def std(
219230 axis : Optional [Union [int , Tuple [int , ...]]] = None ,
220231 correction : Union [int , float ] = 0.0 , # correction instead of ddof
221232 keepdims : bool = False ,
233+ ** kwargs ,
222234) -> ndarray :
223- return xp .std (x , axis = axis , ddof = correction , keepdims = keepdims )
235+ return xp .std (x , axis = axis , ddof = correction , keepdims = keepdims , ** kwargs )
224236
225237def var (
226238 x : ndarray ,
@@ -230,8 +242,9 @@ def var(
230242 axis : Optional [Union [int , Tuple [int , ...]]] = None ,
231243 correction : Union [int , float ] = 0.0 , # correction instead of ddof
232244 keepdims : bool = False ,
245+ ** kwargs ,
233246) -> ndarray :
234- return xp .var (x , axis = axis , ddof = correction , keepdims = keepdims )
247+ return xp .var (x , axis = axis , ddof = correction , keepdims = keepdims , ** kwargs )
235248
236249# Unlike transpose(), the axes argument to permute_dims() is required.
237250def permute_dims (x : ndarray , / , axes : Tuple [int , ...], xp ) -> ndarray :
@@ -255,6 +268,7 @@ def _asarray(
255268 device : Optional [Device ] = None ,
256269 copy : "Optional[Union[bool, np._CopyMode]]" = None ,
257270 namespace = None ,
271+ ** kwargs ,
258272) -> ndarray :
259273 """
260274 Array API compatibility wrapper for asarray().
@@ -296,33 +310,39 @@ def _asarray(
296310 return xp .array (obj , copy = True , dtype = dtype )
297311 return obj
298312
299- return xp .asarray (obj , dtype = dtype )
313+ return xp .asarray (obj , dtype = dtype , ** kwargs )
300314
301315# xp.reshape calls the keyword argument 'newshape' instead of 'shape'
302- def reshape (x : ndarray , / , shape : Tuple [int , ...], xp , copy : Optional [bool ] = None ) -> ndarray :
316+ def reshape (x : ndarray ,
317+ / ,
318+ shape : Tuple [int , ...],
319+ xp , copy : Optional [bool ] = None ,
320+ ** kwargs ) -> ndarray :
303321 if copy is True :
304322 x = x .copy ()
305323 elif copy is False :
306324 x .shape = shape
307325 return x
308- return xp .reshape (x , shape )
326+ return xp .reshape (x , shape , ** kwargs )
309327
310328# The descending keyword is new in sort and argsort, and 'kind' replaced with
311329# 'stable'
312330def argsort (
313- x : ndarray , / , xp , * , axis : int = - 1 , descending : bool = False , stable : bool = True
331+ x : ndarray , / , xp , * , axis : int = - 1 , descending : bool = False , stable : bool = True ,
332+ ** kwargs ,
314333) -> ndarray :
315334 # Note: this keyword argument is different, and the default is different.
316335 kind = "stable" if stable else "quicksort"
317336 if not descending :
318- res = xp .argsort (x , axis = axis , kind = kind )
337+ res = xp .argsort (x , axis = axis , kind = kind , ** kwargs )
319338 else :
320339 # As NumPy has no native descending sort, we imitate it here. Note that
321340 # simply flipping the results of xp.argsort(x, ...) would not
322341 # respect the relative order like it would in native descending sorts.
323342 res = xp .flip (
324343 xp .argsort (xp .flip (x , axis = axis ), axis = axis , kind = kind ),
325344 axis = axis ,
345+ ** kwargs ,
326346 )
327347 # Rely on flip()/argsort() to validate axis
328348 normalised_axis = axis if axis >= 0 else x .ndim + axis
@@ -331,11 +351,12 @@ def argsort(
331351 return res
332352
333353def sort (
334- x : ndarray , / , xp , * , axis : int = - 1 , descending : bool = False , stable : bool = True
354+ x : ndarray , / , xp , * , axis : int = - 1 , descending : bool = False , stable : bool = True ,
355+ ** kwargs ,
335356) -> ndarray :
336357 # Note: this keyword argument is different, and the default is different.
337358 kind = "stable" if stable else "quicksort"
338- res = xp .sort (x , axis = axis , kind = kind )
359+ res = xp .sort (x , axis = axis , kind = kind , ** kwargs )
339360 if descending :
340361 res = xp .flip (res , axis = axis )
341362 return res
@@ -349,11 +370,12 @@ def sum(
349370 axis : Optional [Union [int , Tuple [int , ...]]] = None ,
350371 dtype : Optional [Dtype ] = None ,
351372 keepdims : bool = False ,
373+ ** kwargs ,
352374) -> ndarray :
353375 # `xp.sum` already upcasts integers, but not floats
354376 if dtype is None and x .dtype == xp .float32 :
355377 dtype = xp .float64
356- return xp .sum (x , axis = axis , dtype = dtype , keepdims = keepdims )
378+ return xp .sum (x , axis = axis , dtype = dtype , keepdims = keepdims , ** kwargs )
357379
358380def prod (
359381 x : ndarray ,
@@ -363,27 +385,28 @@ def prod(
363385 axis : Optional [Union [int , Tuple [int , ...]]] = None ,
364386 dtype : Optional [Dtype ] = None ,
365387 keepdims : bool = False ,
388+ ** kwargs ,
366389) -> ndarray :
367390 if dtype is None and x .dtype == xp .float32 :
368391 dtype = xp .float64
369- return xp .prod (x , dtype = dtype , axis = axis , keepdims = keepdims )
392+ return xp .prod (x , dtype = dtype , axis = axis , keepdims = keepdims , ** kwargs )
370393
371394# ceil, floor, and trunc return integers for integer inputs
372395
373- def ceil (x : ndarray , / , xp ) -> ndarray :
396+ def ceil (x : ndarray , / , xp , ** kwargs ) -> ndarray :
374397 if xp .issubdtype (x .dtype , xp .integer ):
375398 return x
376- return xp .ceil (x )
399+ return xp .ceil (x , ** kwargs )
377400
378- def floor (x : ndarray , / , xp ) -> ndarray :
401+ def floor (x : ndarray , / , xp , ** kwargs ) -> ndarray :
379402 if xp .issubdtype (x .dtype , xp .integer ):
380403 return x
381- return xp .floor (x )
404+ return xp .floor (x , ** kwargs )
382405
383- def trunc (x : ndarray , / , xp ) -> ndarray :
406+ def trunc (x : ndarray , / , xp , ** kwargs ) -> ndarray :
384407 if xp .issubdtype (x .dtype , xp .integer ):
385408 return x
386- return xp .trunc (x )
409+ return xp .trunc (x , ** kwargs )
387410
388411__all__ = ['UniqueAllResult' , 'UniqueCountsResult' , 'UniqueInverseResult' ,
389412 'unique_all' , 'unique_counts' , 'unique_inverse' , 'unique_values' ,
0 commit comments