|
12 | 12 | from typing import NamedTuple |
13 | 13 | import inspect |
14 | 14 |
|
15 | | -from ._helpers import array_namespace, _check_device, device, is_cupy_namespace |
| 15 | +from ._helpers import ( |
| 16 | + array_namespace, |
| 17 | + _check_device, |
| 18 | + device as _get_device, |
| 19 | + is_cupy_namespace as _is_cupy_namespace |
| 20 | +) |
16 | 21 |
|
17 | 22 | # These functions are modified from the NumPy versions. |
18 | 23 |
|
@@ -287,7 +292,7 @@ def cumulative_sum( |
287 | 292 | initial_shape = list(x.shape) |
288 | 293 | initial_shape[axis] = 1 |
289 | 294 | res = xp.concatenate( |
290 | | - [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res], |
| 295 | + [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res], |
291 | 296 | axis=axis, |
292 | 297 | ) |
293 | 298 | return res |
@@ -317,7 +322,7 @@ def cumulative_prod( |
317 | 322 | initial_shape = list(x.shape) |
318 | 323 | initial_shape[axis] = 1 |
319 | 324 | res = xp.concatenate( |
320 | | - [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=device(res)), res], |
| 325 | + [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res], |
321 | 326 | axis=axis, |
322 | 327 | ) |
323 | 328 | return res |
@@ -369,7 +374,7 @@ def _isscalar(a): |
369 | 374 | if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max: |
370 | 375 | max = None |
371 | 376 |
|
372 | | - dev = device(x) |
| 377 | + dev = _get_device(x) |
373 | 378 | if out is None: |
374 | 379 | out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev) |
375 | 380 | out[()] = x |
@@ -579,3 +584,5 @@ def sign(x: ndarray, /, xp, **kwargs) -> ndarray: |
579 | 584 | 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', |
580 | 585 | 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', |
581 | 586 | 'unstack', 'sign'] |
| 587 | + |
| 588 | +_all_ignore = ['inspect', 'array_namespace', 'NamedTuple'] |
0 commit comments