@@ -142,12 +142,14 @@ def absolute(x1, **kwargs):
142142
143143 """
144144
145- is_input_dparray = isinstance (x1 , dparray )
146-
147- if not use_origin_backend (x1 ) and is_input_dparray and x1 .ndim != 0 and not kwargs :
148- result = dpnp_absolute (x1 )
145+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
146+ if x1_desc and not kwargs :
147+ if not x1_desc .ndim :
148+ pass
149+ else :
150+ result = dpnp_absolute (x1_desc )
149151
150- return result
152+ return result
151153
152154 return call_origin (numpy .absolute , x1 , ** kwargs )
153155
@@ -236,15 +238,14 @@ def around(x1, decimals=0, out=None):
236238
237239 """
238240
239- if not use_origin_backend (x1 ):
240- if not isinstance (x1 , dparray ):
241- pass
242- elif out is not None :
241+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
242+ if x1_desc :
243+ if out is not None :
243244 pass
244245 elif decimals != 0 :
245246 pass
246247 else :
247- return dpnp_around (x1 , decimals )
248+ return dpnp_around (x1_desc , decimals )
248249
249250 return call_origin (numpy .around , x1 , decimals = decimals , out = out )
250251
@@ -483,7 +484,7 @@ def cumsum(x1, **kwargs):
483484 return call_origin (numpy .cumsum , x1 , ** kwargs )
484485
485486
486- def diff (input , n = 1 , axis = - 1 , prepend = None , append = None ):
487+ def diff (x1 , n = 1 , axis = - 1 , prepend = None , append = None ):
487488 """
488489 Calculate the n-th discrete difference along the given axis.
489490
@@ -496,10 +497,9 @@ def diff(input, n=1, axis=-1, prepend=None, append=None):
496497 Otherwise the function will be executed sequentially on CPU.
497498 """
498499
499- if not use_origin_backend (input ):
500- if not isinstance (input , dparray ):
501- pass
502- elif not isinstance (n , int ):
500+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
501+ if x1_desc :
502+ if not isinstance (n , int ):
503503 pass
504504 elif n < 1 :
505505 pass
@@ -510,9 +510,9 @@ def diff(input, n=1, axis=-1, prepend=None, append=None):
510510 elif append is not None :
511511 pass
512512 else :
513- return dpnp_diff (input , n )
513+ return dpnp_diff (x1 , n )
514514
515- return call_origin (numpy .diff , input , n , axis , prepend , append )
515+ return call_origin (numpy .diff , x1 , n , axis , prepend , append )
516516
517517
518518def divide (x1 , x2 , dtype = None , out = None , where = True , ** kwargs ):
@@ -848,7 +848,7 @@ def fmod(x1, x2, dtype=None, out=None, where=True, **kwargs):
848848 return call_origin (numpy .fmod , x1 , x2 , dtype = dtype , out = out , where = where , ** kwargs )
849849
850850
851- def gradient (y1 , * varargs , ** kwargs ):
851+ def gradient (x1 , * varargs , ** kwargs ):
852852 """
853853 Return the gradient of an array.
854854
@@ -874,20 +874,20 @@ def gradient(y1, *varargs, **kwargs):
874874 [0.5, 0.75, 1.25, 1.75, 2.25, 2.5]
875875
876876 """
877- if not use_origin_backend ( y1 ) and not kwargs :
878- if not isinstance ( y1 , dparray ):
879- pass
880- elif len (varargs ) > 1 :
877+
878+ x1_desc = dpnp . get_dpnp_descriptor ( x1 )
879+ if x1_desc and not kwargs :
880+ if len (varargs ) > 1 :
881881 pass
882882 elif len (varargs ) == 1 and not isinstance (varargs [0 ], int ):
883883 pass
884884 else :
885885 if len (varargs ) == 0 :
886- return dpnp_gradient (y1 )
886+ return dpnp_gradient (x1 )
887887
888- return dpnp_gradient (y1 , varargs [0 ])
888+ return dpnp_gradient (x1 , varargs [0 ])
889889
890- return call_origin (numpy .gradient , y1 , * varargs , ** kwargs )
890+ return call_origin (numpy .gradient , x1 , * varargs , ** kwargs )
891891
892892
893893def maximum (x1 , x2 , dtype = None , out = None , where = True , ** kwargs ):
@@ -1136,11 +1136,9 @@ def nancumprod(x1, **kwargs):
11361136
11371137 """
11381138
1139- if not use_origin_backend (x1 ) and not kwargs :
1140- if not isinstance (x1 , dparray ):
1141- pass
1142- else :
1143- return dpnp_nancumprod (x1 )
1139+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
1140+ if x1_desc and not kwargs :
1141+ return dpnp_nancumprod (x1_desc )
11441142
11451143 return call_origin (numpy .nancumprod , x1 , ** kwargs )
11461144
@@ -1174,11 +1172,9 @@ def nancumsum(x1, **kwargs):
11741172
11751173 """
11761174
1177- if not use_origin_backend (x1 ) and not kwargs :
1178- if not isinstance (x1 , dparray ):
1179- pass
1180- else :
1181- return dpnp_nancumsum (x1 )
1175+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
1176+ if x1_desc and not kwargs :
1177+ return dpnp_nancumsum (x1_desc )
11821178
11831179 return call_origin (numpy .nancumsum , x1 , ** kwargs )
11841180
@@ -1206,9 +1202,8 @@ def nanprod(x1, **kwargs):
12061202
12071203 """
12081204
1209- is_x1_dparray = isinstance (x1 , dparray )
1210-
1211- if (not use_origin_backend (x1 ) and is_x1_dparray and not kwargs ):
1205+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
1206+ if x1_desc and not kwargs :
12121207 return dpnp_nanprod (x1 )
12131208
12141209 return call_origin (numpy .nanprod , x1 , ** kwargs )
@@ -1237,9 +1232,8 @@ def nansum(x1, **kwargs):
12371232
12381233 """
12391234
1240- is_x1_dparray = isinstance (x1 , dparray )
1241-
1242- if (not use_origin_backend (x1 ) and is_x1_dparray and not kwargs ):
1235+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
1236+ if x1_desc and not kwargs :
12431237 return dpnp_nansum (x1 )
12441238
12451239 return call_origin (numpy .nansum , x1 , ** kwargs )
@@ -1360,16 +1354,15 @@ def prod(x1, axis=None, dtype=None, out=None, keepdims=False, initial=None, wher
13601354
13611355 """
13621356
1363- if not use_origin_backend (x1 ):
1364- if not isinstance (x1 , dparray ):
1365- pass
1366- elif out is not None and not isinstance (out , dparray ):
1357+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
1358+ if x1_desc :
1359+ if out is not None and not isinstance (out , dparray ):
13671360 pass
13681361 elif where is not True :
13691362 pass
13701363 else :
1371- result_obj = dpnp_prod (x1 , axis , dtype , out , keepdims , initial , where )
1372- result = dpnp .convert_single_elem_array_to_scalar (result_obj , keepdims )
1364+ result_obj = dpnp_prod (x1_desc , axis , dtype , out , keepdims , initial , where )
1365+ result = dpnp .convert_single_elem_array_to_scalar (result_obj , keepdims )
13731366
13741367 return result
13751368
@@ -1540,23 +1533,22 @@ def sum(x1, axis=None, dtype=None, out=None, keepdims=False, initial=None, where
15401533
15411534 """
15421535
1543- if not use_origin_backend (x1 ):
1544- if not isinstance (x1 , dparray ):
1545- pass
1546- elif out is not None and not isinstance (out , dparray ):
1536+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
1537+ if x1_desc :
1538+ if out is not None and not isinstance (out , dparray ):
15471539 pass
15481540 elif where is not True :
15491541 pass
15501542 else :
1551- result_obj = dpnp_sum (x1 , axis , dtype , out , keepdims , initial , where )
1543+ result_obj = dpnp_sum (x1_desc , axis , dtype , out , keepdims , initial , where )
15521544 result = dpnp .convert_single_elem_array_to_scalar (result_obj , keepdims )
15531545
15541546 return result
15551547
15561548 return call_origin (numpy .sum , x1 , axis = axis , dtype = dtype , out = out , keepdims = keepdims , initial = initial , where = where )
15571549
15581550
1559- def trapz (y , x = None , dx = 1.0 , ** kwargs ):
1551+ def trapz (y , x = None , dx = 1.0 , axis = - 1 ):
15601552 """
15611553 Integrate along the given axis using the composite trapezoidal rule.
15621554
@@ -1583,25 +1575,23 @@ def trapz(y, x=None, dx=1.0, **kwargs):
15831575
15841576 """
15851577
1586- if not use_origin_backend (y ):
1587-
1588- if not isinstance (y , dparray ):
1589- pass
1590- elif not isinstance (x , dparray ) and x is not None :
1578+ y_desc = dpnp .get_dpnp_descriptor (y )
1579+ if y_desc :
1580+ if not isinstance (x , dparray ) and x is not None :
15911581 pass
1592- elif x is not None and y .size != x .size :
1582+ elif x is not None and y_desc .size != x .size :
15931583 pass
1594- elif x is not None and y .shape != x .shape :
1584+ elif x is not None and y_desc .shape != x .shape :
15951585 pass
1596- elif y .ndim > 1 :
1586+ elif y_desc .ndim > 1 :
15971587 pass
15981588 else :
15991589 if x is None :
1600- x = dpnp .empty (0 , dtype = y .dtype )
1590+ x = dpnp .empty (0 , dtype = y_desc .dtype )
16011591
1602- return dpnp_trapz (y , x , dx )
1592+ return dpnp_trapz (y_desc , x , dx )
16031593
1604- return call_origin (numpy .trapz , y , x = x , dx = dx , ** kwargs )
1594+ return call_origin (numpy .trapz , y , x , dx , axis )
16051595
16061596
16071597def true_divide (* args , ** kwargs ):
0 commit comments