1414
1515from ._helpers import _check_device , _is_numpy_array , get_namespace
1616
17- # Basic renames
18- def acos (x , / , xp ):
19- return xp .arccos (x )
20-
21- def acosh (x , / , xp ):
22- return xp .arccosh (x )
23-
24- def asin (x , / , xp ):
25- return xp .arcsin (x )
26-
27- def asinh (x , / , xp ):
28- return xp .arcsinh (x )
29-
30- def atan (x , / , xp ):
31- return xp .arctan (x )
32-
33- def atan2 (x1 , x2 , / , xp ):
34- return xp .arctan2 (x1 , x2 )
35-
36- def atanh (x , / , xp ):
37- return xp .arctanh (x )
38-
39- def bitwise_left_shift (x1 , x2 , / , xp ):
40- return xp .left_shift (x1 , x2 )
41-
42- def bitwise_invert (x , / , xp ):
43- return xp .invert (x )
44-
45- def bitwise_right_shift (x1 , x2 , / , xp ):
46- return xp .right_shift (x1 , x2 )
47-
48- def concat (arrays : Union [Tuple [ndarray , ...], List [ndarray ]], / , xp , * , axis : Optional [int ] = 0 ) -> ndarray :
49- return xp .concatenate (arrays , axis = axis )
50-
51- def pow (x1 , x2 , / , xp ):
52- return xp .power (x1 , x2 )
53-
5417# These functions are modified from the NumPy versions.
5518
5619def arange (
@@ -62,25 +25,28 @@ def arange(
6225 xp ,
6326 dtype : Optional [Dtype ] = None ,
6427 device : Optional [Device ] = None ,
28+ ** kwargs
6529) -> ndarray :
6630 _check_device (xp , device )
67- return xp .arange (start , stop = stop , step = step , dtype = dtype )
31+ return xp .arange (start , stop = stop , step = step , dtype = dtype , ** kwargs )
6832
6933def empty (
7034 shape : Union [int , Tuple [int , ...]],
7135 xp ,
7236 * ,
7337 dtype : Optional [Dtype ] = None ,
7438 device : Optional [Device ] = None ,
39+ ** kwargs
7540) -> ndarray :
7641 _check_device (xp , device )
77- return xp .empty (shape , dtype = dtype )
42+ return xp .empty (shape , dtype = dtype , ** kwargs )
7843
7944def empty_like (
80- x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
45+ x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None ,
46+ ** kwargs
8147) -> ndarray :
8248 _check_device (xp , device )
83- return xp .empty_like (x , dtype = dtype )
49+ return xp .empty_like (x , dtype = dtype , ** kwargs )
8450
8551def eye (
8652 n_rows : int ,
@@ -91,9 +57,10 @@ def eye(
9157 k : int = 0 ,
9258 dtype : Optional [Dtype ] = None ,
9359 device : Optional [Device ] = None ,
60+ ** kwargs ,
9461) -> ndarray :
9562 _check_device (xp , device )
96- return xp .eye (n_rows , M = n_cols , k = k , dtype = dtype )
63+ return xp .eye (n_rows , M = n_cols , k = k , dtype = dtype , ** kwargs )
9764
9865def full (
9966 shape : Union [int , Tuple [int , ...]],
@@ -102,9 +69,10 @@ def full(
10269 * ,
10370 dtype : Optional [Dtype ] = None ,
10471 device : Optional [Device ] = None ,
72+ ** kwargs ,
10573) -> ndarray :
10674 _check_device (xp , device )
107- return xp .full (shape , fill_value , dtype = dtype )
75+ return xp .full (shape , fill_value , dtype = dtype , ** kwargs )
10876
10977def full_like (
11078 x : ndarray ,
@@ -114,9 +82,10 @@ def full_like(
11482 xp ,
11583 dtype : Optional [Dtype ] = None ,
11684 device : Optional [Device ] = None ,
85+ ** kwargs ,
11786) -> ndarray :
11887 _check_device (xp , device )
119- return xp .full_like (x , fill_value , dtype = dtype )
88+ return xp .full_like (x , fill_value , dtype = dtype , ** kwargs )
12089
12190def linspace (
12291 start : Union [int , float ],
@@ -128,41 +97,46 @@ def linspace(
12897 dtype : Optional [Dtype ] = None ,
12998 device : Optional [Device ] = None ,
13099 endpoint : bool = True ,
100+ ** kwargs ,
131101) -> ndarray :
132102 _check_device (xp , device )
133- return xp .linspace (start , stop , num , dtype = dtype , endpoint = endpoint )
103+ return xp .linspace (start , stop , num , dtype = dtype , endpoint = endpoint , ** kwargs )
134104
135105def ones (
136106 shape : Union [int , Tuple [int , ...]],
137107 xp ,
138108 * ,
139109 dtype : Optional [Dtype ] = None ,
140110 device : Optional [Device ] = None ,
111+ ** kwargs ,
141112) -> ndarray :
142113 _check_device (xp , device )
143- return xp .ones (shape , dtype = dtype )
114+ return xp .ones (shape , dtype = dtype , ** kwargs )
144115
145116def ones_like (
146- x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
117+ x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None ,
118+ ** kwargs ,
147119) -> ndarray :
148120 _check_device (xp , device )
149- return xp .ones_like (x , dtype = dtype )
121+ return xp .ones_like (x , dtype = dtype , ** kwargs )
150122
151123def zeros (
152124 shape : Union [int , Tuple [int , ...]],
153125 xp ,
154126 * ,
155127 dtype : Optional [Dtype ] = None ,
156128 device : Optional [Device ] = None ,
129+ ** kwargs ,
157130) -> ndarray :
158131 _check_device (xp , device )
159- return xp .zeros (shape , dtype = dtype )
132+ return xp .zeros (shape , dtype = dtype , ** kwargs )
160133
161134def zeros_like (
162- x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
135+ x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None ,
136+ ** kwargs ,
163137) -> ndarray :
164138 _check_device (xp , device )
165- return xp .zeros_like (x , dtype = dtype )
139+ return xp .zeros_like (x , dtype = dtype , ** kwargs )
166140
167141# np.unique() is split into four functions in the array API:
168142# unique_all, unique_counts, unique_inverse, and unique_values (this is done
@@ -256,8 +230,9 @@ def std(
256230 axis : Optional [Union [int , Tuple [int , ...]]] = None ,
257231 correction : Union [int , float ] = 0.0 , # correction instead of ddof
258232 keepdims : bool = False ,
233+ ** kwargs ,
259234) -> ndarray :
260- return xp .std (x , axis = axis , ddof = correction , keepdims = keepdims )
235+ return xp .std (x , axis = axis , ddof = correction , keepdims = keepdims , ** kwargs )
261236
262237def var (
263238 x : ndarray ,
@@ -267,8 +242,9 @@ def var(
267242 axis : Optional [Union [int , Tuple [int , ...]]] = None ,
268243 correction : Union [int , float ] = 0.0 , # correction instead of ddof
269244 keepdims : bool = False ,
245+ ** kwargs ,
270246) -> ndarray :
271- return xp .var (x , axis = axis , ddof = correction , keepdims = keepdims )
247+ return xp .var (x , axis = axis , ddof = correction , keepdims = keepdims , ** kwargs )
272248
273249# Unlike transpose(), the axes argument to permute_dims() is required.
274250def permute_dims (x : ndarray , / , axes : Tuple [int , ...], xp ) -> ndarray :
@@ -292,6 +268,7 @@ def _asarray(
292268 device : Optional [Device ] = None ,
293269 copy : "Optional[Union[bool, np._CopyMode]]" = None ,
294270 namespace = None ,
271+ ** kwargs ,
295272) -> ndarray :
296273 """
297274 Array API compatibility wrapper for asarray().
@@ -333,33 +310,39 @@ def _asarray(
333310 return xp .array (obj , copy = True , dtype = dtype )
334311 return obj
335312
336- return xp .asarray (obj , dtype = dtype )
313+ return xp .asarray (obj , dtype = dtype , ** kwargs )
337314
338315# xp.reshape calls the keyword argument 'newshape' instead of 'shape'
339- def reshape (x : ndarray , / , shape : Tuple [int , ...], xp , copy : Optional [bool ] = None ) -> ndarray :
316+ def reshape (x : ndarray ,
317+ / ,
318+ shape : Tuple [int , ...],
319+ xp , copy : Optional [bool ] = None ,
320+ ** kwargs ) -> ndarray :
340321 if copy is True :
341322 x = x .copy ()
342323 elif copy is False :
343324 x .shape = shape
344325 return x
345- return xp .reshape (x , shape )
326+ return xp .reshape (x , shape , ** kwargs )
346327
347328# The descending keyword is new in sort and argsort, and 'kind' replaced with
348329# 'stable'
349330def argsort (
350- x : ndarray , / , xp , * , axis : int = - 1 , descending : bool = False , stable : bool = True
331+ x : ndarray , / , xp , * , axis : int = - 1 , descending : bool = False , stable : bool = True ,
332+ ** kwargs ,
351333) -> ndarray :
352334 # Note: this keyword argument is different, and the default is different.
353335 kind = "stable" if stable else "quicksort"
354336 if not descending :
355- res = xp .argsort (x , axis = axis , kind = kind )
337+ res = xp .argsort (x , axis = axis , kind = kind , ** kwargs )
356338 else :
357339 # As NumPy has no native descending sort, we imitate it here. Note that
358340 # simply flipping the results of xp.argsort(x, ...) would not
359341 # respect the relative order like it would in native descending sorts.
360342 res = xp .flip (
361343 xp .argsort (xp .flip (x , axis = axis ), axis = axis , kind = kind ),
362344 axis = axis ,
345+ ** kwargs ,
363346 )
364347 # Rely on flip()/argsort() to validate axis
365348 normalised_axis = axis if axis >= 0 else x .ndim + axis
@@ -368,11 +351,12 @@ def argsort(
368351 return res
369352
370353def sort (
371- x : ndarray , / , xp , * , axis : int = - 1 , descending : bool = False , stable : bool = True
354+ x : ndarray , / , xp , * , axis : int = - 1 , descending : bool = False , stable : bool = True ,
355+ ** kwargs ,
372356) -> ndarray :
373357 # Note: this keyword argument is different, and the default is different.
374358 kind = "stable" if stable else "quicksort"
375- res = xp .sort (x , axis = axis , kind = kind )
359+ res = xp .sort (x , axis = axis , kind = kind , ** kwargs )
376360 if descending :
377361 res = xp .flip (res , axis = axis )
378362 return res
@@ -386,11 +370,12 @@ def sum(
386370 axis : Optional [Union [int , Tuple [int , ...]]] = None ,
387371 dtype : Optional [Dtype ] = None ,
388372 keepdims : bool = False ,
373+ ** kwargs ,
389374) -> ndarray :
390375 # `xp.sum` already upcasts integers, but not floats
391376 if dtype is None and x .dtype == xp .float32 :
392377 dtype = xp .float64
393- return xp .sum (x , axis = axis , dtype = dtype , keepdims = keepdims )
378+ return xp .sum (x , axis = axis , dtype = dtype , keepdims = keepdims , ** kwargs )
394379
395380def prod (
396381 x : ndarray ,
@@ -400,32 +385,30 @@ def prod(
400385 axis : Optional [Union [int , Tuple [int , ...]]] = None ,
401386 dtype : Optional [Dtype ] = None ,
402387 keepdims : bool = False ,
388+ ** kwargs ,
403389) -> ndarray :
404390 if dtype is None and x .dtype == xp .float32 :
405391 dtype = xp .float64
406- return xp .prod (x , dtype = dtype , axis = axis , keepdims = keepdims )
392+ return xp .prod (x , dtype = dtype , axis = axis , keepdims = keepdims , ** kwargs )
407393
408394# ceil, floor, and trunc return integers for integer inputs
409395
410- def ceil (x : ndarray , / , xp ) -> ndarray :
396+ def ceil (x : ndarray , / , xp , ** kwargs ) -> ndarray :
411397 if xp .issubdtype (x .dtype , xp .integer ):
412398 return x
413- return xp .ceil (x )
399+ return xp .ceil (x , ** kwargs )
414400
415- def floor (x : ndarray , / , xp ) -> ndarray :
401+ def floor (x : ndarray , / , xp , ** kwargs ) -> ndarray :
416402 if xp .issubdtype (x .dtype , xp .integer ):
417403 return x
418- return xp .floor (x )
404+ return xp .floor (x , ** kwargs )
419405
420- def trunc (x : ndarray , / , xp ) -> ndarray :
406+ def trunc (x : ndarray , / , xp , ** kwargs ) -> ndarray :
421407 if xp .issubdtype (x .dtype , xp .integer ):
422408 return x
423- return xp .trunc (x )
424-
425- __all__ = ['acos' , 'acosh' , 'asin' , 'asinh' , 'atan' , 'atan2' , 'atanh' ,
426- 'bitwise_left_shift' , 'bitwise_invert' , 'bitwise_right_shift' ,
427- 'concat' , 'pow' , 'UniqueAllResult' , 'UniqueCountsResult' ,
428- 'UniqueInverseResult' , 'unique_all' , 'unique_counts' ,
429- 'unique_inverse' , 'unique_values' , 'astype' , 'std' , 'var' ,
430- 'permute_dims' , 'reshape' , 'argsort' , 'sort' , 'sum' , 'prod' ,
431- 'ceil' , 'floor' , 'trunc' ]
409+ return xp .trunc (x , ** kwargs )
410+
411+ __all__ = ['UniqueAllResult' , 'UniqueCountsResult' , 'UniqueInverseResult' ,
412+ 'unique_all' , 'unique_counts' , 'unique_inverse' , 'unique_values' ,
413+ 'astype' , 'std' , 'var' , 'permute_dims' , 'reshape' , 'argsort' ,
414+ 'sort' , 'sum' , 'prod' , 'ceil' , 'floor' , 'trunc' ]
0 commit comments