|
19 | 19 | import dpctl |
20 | 20 | import dpctl.tensor as dpt |
21 | 21 | import dpctl.tensor._tensor_impl as ti |
| 22 | +import dpctl.tensor._tensor_reductions_impl as tri |
22 | 23 |
|
23 | 24 | from ._type_utils import _to_device_supported_dtype |
24 | 25 |
|
@@ -220,8 +221,8 @@ def sum(x, axis=None, dtype=None, keepdims=False): |
220 | 221 | axis, |
221 | 222 | dtype, |
222 | 223 | keepdims, |
223 | | - ti._sum_over_axis, |
224 | | - ti._sum_over_axis_dtype_supported, |
| 224 | + tri._sum_over_axis, |
| 225 | + tri._sum_over_axis_dtype_supported, |
225 | 226 | _default_reduction_dtype, |
226 | 227 | _identity=0, |
227 | 228 | ) |
@@ -281,8 +282,8 @@ def prod(x, axis=None, dtype=None, keepdims=False): |
281 | 282 | axis, |
282 | 283 | dtype, |
283 | 284 | keepdims, |
284 | | - ti._prod_over_axis, |
285 | | - ti._prod_over_axis_dtype_supported, |
| 285 | + tri._prod_over_axis, |
| 286 | + tri._prod_over_axis_dtype_supported, |
286 | 287 | _default_reduction_dtype, |
287 | 288 | _identity=1, |
288 | 289 | ) |
@@ -335,8 +336,8 @@ def logsumexp(x, axis=None, dtype=None, keepdims=False): |
335 | 336 | axis, |
336 | 337 | dtype, |
337 | 338 | keepdims, |
338 | | - ti._logsumexp_over_axis, |
339 | | - lambda inp_dt, res_dt, *_: ti._logsumexp_over_axis_dtype_supported( |
| 339 | + tri._logsumexp_over_axis, |
| 340 | + lambda inp_dt, res_dt, *_: tri._logsumexp_over_axis_dtype_supported( |
340 | 341 | inp_dt, res_dt |
341 | 342 | ), |
342 | 343 | _default_reduction_dtype_fp_types, |
@@ -391,8 +392,8 @@ def reduce_hypot(x, axis=None, dtype=None, keepdims=False): |
391 | 392 | axis, |
392 | 393 | dtype, |
393 | 394 | keepdims, |
394 | | - ti._hypot_over_axis, |
395 | | - lambda inp_dt, res_dt, *_: ti._hypot_over_axis_dtype_supported( |
| 395 | + tri._hypot_over_axis, |
| 396 | + lambda inp_dt, res_dt, *_: tri._hypot_over_axis_dtype_supported( |
396 | 397 | inp_dt, res_dt |
397 | 398 | ), |
398 | 399 | _default_reduction_dtype_fp_types, |
@@ -468,7 +469,7 @@ def max(x, axis=None, keepdims=False): |
468 | 469 | entire array, a zero-dimensional array is returned. The returned |
469 | 470 | array has the same data type as `x`. |
470 | 471 | """ |
471 | | - return _comparison_over_axis(x, axis, keepdims, ti._max_over_axis) |
| 472 | + return _comparison_over_axis(x, axis, keepdims, tri._max_over_axis) |
472 | 473 |
|
473 | 474 |
|
474 | 475 | def min(x, axis=None, keepdims=False): |
@@ -496,7 +497,7 @@ def min(x, axis=None, keepdims=False): |
496 | 497 | entire array, a zero-dimensional array is returned. The returned |
497 | 498 | array has the same data type as `x`. |
498 | 499 | """ |
499 | | - return _comparison_over_axis(x, axis, keepdims, ti._min_over_axis) |
| 500 | + return _comparison_over_axis(x, axis, keepdims, tri._min_over_axis) |
500 | 501 |
|
501 | 502 |
|
502 | 503 | def _search_over_axis(x, axis, keepdims, _reduction_fn): |
@@ -577,7 +578,7 @@ def argmax(x, axis=None, keepdims=False): |
577 | 578 | zero-dimensional array is returned. The returned array has the |
578 | 579 | default array index data type for the device of `x`. |
579 | 580 | """ |
580 | | - return _search_over_axis(x, axis, keepdims, ti._argmax_over_axis) |
| 581 | + return _search_over_axis(x, axis, keepdims, tri._argmax_over_axis) |
581 | 582 |
|
582 | 583 |
|
583 | 584 | def argmin(x, axis=None, keepdims=False): |
@@ -609,4 +610,4 @@ def argmin(x, axis=None, keepdims=False): |
609 | 610 | zero-dimensional array is returned. The returned array has the |
610 | 611 | default array index data type for the device of `x`. |
611 | 612 | """ |
612 | | - return _search_over_axis(x, axis, keepdims, ti._argmin_over_axis) |
| 613 | + return _search_over_axis(x, axis, keepdims, tri._argmin_over_axis) |
0 commit comments