|
47 | 47 | import dpnp |
48 | 48 |
|
49 | 49 | # pylint: disable=no-name-in-module |
50 | | -from .dpnp_algo import ( |
51 | | - dpnp_correlate, |
52 | | -) |
| 50 | +from .dpnp_algo import dpnp_correlate |
53 | 51 | from .dpnp_array import dpnp_array |
54 | | -from .dpnp_utils import ( |
55 | | - call_origin, |
56 | | - get_usm_allocations, |
57 | | -) |
| 52 | +from .dpnp_utils import call_origin, get_usm_allocations |
58 | 53 | from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call |
59 | | -from .dpnp_utils.dpnp_utils_statistics import ( |
60 | | - dpnp_cov, |
61 | | -) |
| 54 | +from .dpnp_utils.dpnp_utils_statistics import dpnp_cov |
62 | 55 |
|
63 | 56 | __all__ = [ |
64 | 57 | "amax", |
@@ -276,60 +269,61 @@ def average(a, axis=None, weights=None, returned=False, *, keepdims=False): |
276 | 269 | """ |
277 | 270 |
|
278 | 271 | dpnp.check_supported_arrays_type(a) |
| 272 | + usm_type, exec_q = get_usm_allocations([a, weights]) |
| 273 | + |
279 | 274 | if weights is None: |
280 | 275 | avg = dpnp.mean(a, axis=axis, keepdims=keepdims) |
281 | 276 | scl = dpnp.asanyarray( |
282 | 277 | avg.dtype.type(a.size / avg.size), |
283 | | - usm_type=a.usm_type, |
284 | | - sycl_queue=a.sycl_queue, |
| 278 | + usm_type=usm_type, |
| 279 | + sycl_queue=exec_q, |
285 | 280 | ) |
286 | 281 | else: |
287 | | - if not isinstance(weights, (dpnp_array, dpt.usm_ndarray)): |
288 | | - wgt = dpnp.asanyarray( |
289 | | - weights, usm_type=a.usm_type, sycl_queue=a.sycl_queue |
| 282 | + if not dpnp.is_supported_array_type(weights): |
| 283 | + weights = dpnp.asarray( |
| 284 | + weights, usm_type=usm_type, sycl_queue=exec_q |
290 | 285 | ) |
291 | | - else: |
292 | | - get_usm_allocations([a, weights]) |
293 | | - wgt = weights |
294 | 286 |
|
295 | | - if not dpnp.issubdtype(a.dtype, dpnp.inexact): |
| 287 | + a_dtype = a.dtype |
| 288 | + if not dpnp.issubdtype(a_dtype, dpnp.inexact): |
296 | 289 | default_dtype = dpnp.default_float_type(a.device) |
297 | | - result_dtype = dpnp.result_type(a.dtype, wgt.dtype, default_dtype) |
| 290 | + res_dtype = dpnp.result_type(a_dtype, weights.dtype, default_dtype) |
298 | 291 | else: |
299 | | - result_dtype = dpnp.result_type(a.dtype, wgt.dtype) |
| 292 | + res_dtype = dpnp.result_type(a_dtype, weights.dtype) |
300 | 293 |
|
301 | 294 | # Sanity checks |
302 | | - if a.shape != wgt.shape: |
| 295 | + wgt_shape = weights.shape |
| 296 | + a_shape = a.shape |
| 297 | + if a_shape != wgt_shape: |
303 | 298 | if axis is None: |
304 | 299 | raise TypeError( |
305 | 300 | "Axis must be specified when shapes of input array and " |
306 | 301 | "weights differ." |
307 | 302 | ) |
308 | | - if wgt.ndim != 1: |
| 303 | + if weights.ndim != 1: |
309 | 304 | raise TypeError( |
310 | 305 | "1D weights expected when shapes of input array and " |
311 | 306 | "weights differ." |
312 | 307 | ) |
313 | | - if wgt.shape[0] != a.shape[axis]: |
| 308 | + if wgt_shape[0] != a_shape[axis]: |
314 | 309 | raise ValueError( |
315 | 310 | "Length of weights not compatible with specified axis." |
316 | 311 | ) |
317 | 312 |
|
318 | | - # setup wgt to broadcast along axis |
319 | | - wgt = dpnp.broadcast_to(wgt, (a.ndim - 1) * (1,) + wgt.shape) |
320 | | - wgt = wgt.swapaxes(-1, axis) |
| 313 | + # setup weights to broadcast along axis |
| 314 | + weights = dpnp.broadcast_to( |
| 315 | + weights, (a.ndim - 1) * (1,) + wgt_shape |
| 316 | + ) |
| 317 | + weights = weights.swapaxes(-1, axis) |
321 | 318 |
|
322 | | - scl = wgt.sum(axis=axis, dtype=result_dtype, keepdims=keepdims) |
| 319 | + scl = weights.sum(axis=axis, dtype=res_dtype, keepdims=keepdims) |
323 | 320 | if dpnp.any(scl == 0.0): |
324 | 321 | raise ZeroDivisionError("Weights sum to zero, can't be normalized") |
325 | 322 |
|
326 | | - # result_datatype |
327 | | - avg = ( |
328 | | - dpnp.multiply(a, wgt).sum( |
329 | | - axis=axis, dtype=result_dtype, keepdims=keepdims |
330 | | - ) |
331 | | - / scl |
| 323 | + avg = dpnp.multiply(a, weights).sum( |
| 324 | + axis=axis, dtype=res_dtype, keepdims=keepdims |
332 | 325 | ) |
| 326 | + avg /= scl |
333 | 327 |
|
334 | 328 | if returned: |
335 | 329 | if scl.shape != avg.shape: |
@@ -556,7 +550,7 @@ def cov( |
556 | 550 |
|
557 | 551 | """ |
558 | 552 |
|
559 | | - if not isinstance(m, (dpnp_array, dpt.usm_ndarray)): |
| 553 | + if not dpnp.is_supported_array_type(m): |
560 | 554 | pass |
561 | 555 | elif m.ndim > 2: |
562 | 556 | pass |
|
0 commit comments