@@ -162,13 +162,10 @@ def average(a, axis=None, weights=None, returned=False):
162162 elif returned :
163163 pass
164164 else :
165- array_avg = dpnp_average (a )
165+ result_obj = dpnp_average (a )
166+ result = dpnp .convert_single_elem_array_to_scalar (result_obj )
166167
167- # scalar returned
168- if array_avg .shape == (1 ,):
169- return array_avg .dtype .type (array_avg [0 ])
170-
171- return array_avg
168+ return result
172169
173170 return call_origin (numpy .average , a , axis , weights , returned )
174171
@@ -335,11 +332,8 @@ def max(input, axis=None, out=None, keepdims=numpy._NoValue, initial=numpy._NoVa
335332 elif where is not numpy ._NoValue :
336333 pass
337334 else :
338- result = dpnp_max (input , axis = axis )
339-
340- # scalar returned
341- if result .shape == (1 ,):
342- return result .dtype .type (result [0 ])
335+ result_obj = dpnp_max (input , axis = axis )
336+ result = dpnp .convert_single_elem_array_to_scalar (result_obj )
343337
344338 return result
345339
@@ -386,11 +380,8 @@ def mean(a, axis=None, **kwargs):
386380 elif a .size == 0 :
387381 pass
388382 else :
389- result = dpnp_mean (a , axis = axis )
390-
391- # scalar returned
392- if result .shape == (1 ,):
393- return result .dtype .type (result [0 ])
383+ result_obj = dpnp_mean (a , axis = axis )
384+ result = dpnp .convert_single_elem_array_to_scalar (result_obj )
394385
395386 return result
396387
@@ -439,11 +430,8 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
439430 elif keepdims :
440431 pass
441432 else :
442- result = dpnp_median (a )
443-
444- # scalar returned
445- if result .shape == (1 ,):
446- return result .dtype .type (result [0 ])
433+ result_obj = dpnp_median (a )
434+ result = dpnp .convert_single_elem_array_to_scalar (result_obj )
447435
448436 return result
449437
@@ -486,11 +474,8 @@ def min(input, axis=None, out=None, keepdims=numpy._NoValue, initial=numpy._NoVa
486474 elif where is not numpy ._NoValue :
487475 pass
488476 else :
489- result = dpnp_min (input , axis = axis )
490-
491- # scalar returned
492- if result .shape == (1 ,):
493- return result .dtype .type (result [0 ])
477+ result_obj = dpnp_min (input , axis = axis )
478+ result = dpnp .convert_single_elem_array_to_scalar (result_obj )
494479
495480 return result
496481
@@ -524,11 +509,8 @@ def nanvar(arr, axis=None, dtype=None, out=None, ddof=0, keepdims=numpy._NoValue
524509 elif keepdims is not numpy ._NoValue :
525510 pass
526511 else :
527- result = dpnp_nanvar (arr , ddof )
528-
529- # scalar returned
530- if result .shape == (1 ,):
531- return result .dtype .type (result [0 ])
512+ result_obj = dpnp_nanvar (arr , ddof )
513+ result = dpnp .convert_single_elem_array_to_scalar (result_obj )
532514
533515 return result
534516
@@ -586,9 +568,8 @@ def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=numpy._NoValue):
586568 elif keepdims is not numpy ._NoValue :
587569 pass
588570 else :
589- result = dpnp_std (a , ddof )
590- if result .shape == (1 ,):
591- return result .dtype .type (result [0 ])
571+ result_obj = dpnp_std (a , ddof )
572+ result = dpnp .convert_single_elem_array_to_scalar (result_obj )
592573
593574 return result
594575
@@ -646,9 +627,8 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=numpy._NoValue):
646627 elif keepdims is not numpy ._NoValue :
647628 pass
648629 else :
649- result = dpnp_var (a , ddof )
650- if result .shape == (1 ,):
651- return result .dtype .type (result [0 ])
630+ result_obj = dpnp_var (a , ddof )
631+ result = dpnp .convert_single_elem_array_to_scalar (result_obj )
652632
653633 return result
654634
0 commit comments