|
40 | 40 | """ |
41 | 41 |
|
42 | 42 |
|
43 | | -import numpy |
44 | | - |
45 | 43 | from dpnp.dpnp_algo import * |
46 | 44 | from dpnp.dparray import dparray |
47 | 45 | from dpnp.dpnp_utils import * |
| 46 | + |
48 | 47 | import dpnp |
| 48 | +import numpy |
49 | 49 |
|
50 | 50 |
|
51 | 51 | __all__ = [ |
|
93 | 93 | ] |
94 | 94 |
|
95 | 95 |
|
96 | | -def convert_result_scalar(result, keepdims): |
97 | | - # one element array result should be converted into scalar |
98 | | - # TODO empty shape must be converted into scalar (it is not in test system) |
99 | | - if (len(result.shape) > 0) and (result.size == 1) and (keepdims is False): |
100 | | - return result.dtype.type(result[0]) |
101 | | - else: |
102 | | - return result |
103 | | - |
104 | | - |
105 | 96 | def abs(*args, **kwargs): |
106 | 97 | """ |
107 | 98 | Calculate the absolute value element-wise. |
@@ -1377,8 +1368,10 @@ def prod(x1, axis=None, dtype=None, out=None, keepdims=False, initial=None, wher |
1377 | 1368 | elif where is not True: |
1378 | 1369 | pass |
1379 | 1370 | else: |
1380 | | - result = dpnp_prod(x1, axis, dtype, out, keepdims, initial, where) |
1381 | | - return convert_result_scalar(result, keepdims) |
| 1371 | + result_obj = dpnp_prod(x1, axis, dtype, out, keepdims, initial, where) |
| 1372 | + result = dpnp.convert_single_elem_array_to_scalar(result_obj, keepdims) |
| 1373 | + |
| 1374 | + return result |
1382 | 1375 |
|
1383 | 1376 | return call_origin(numpy.prod, x1, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) |
1384 | 1377 |
|
@@ -1555,8 +1548,10 @@ def sum(x1, axis=None, dtype=None, out=None, keepdims=False, initial=None, where |
1555 | 1548 | elif where is not True: |
1556 | 1549 | pass |
1557 | 1550 | else: |
1558 | | - result = dpnp_sum(x1, axis, dtype, out, keepdims, initial, where) |
1559 | | - return convert_result_scalar(result, keepdims) |
| 1551 | + result_obj = dpnp_sum(x1, axis, dtype, out, keepdims, initial, where) |
| 1552 | + result = dpnp.convert_single_elem_array_to_scalar(result_obj, keepdims) |
| 1553 | + |
| 1554 | + return result |
1560 | 1555 |
|
1561 | 1556 | return call_origin(numpy.sum, x1, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) |
1562 | 1557 |
|
|
0 commit comments