|
8 | 8 |
|
9 | 9 | from __future__ import annotations |
10 | 10 |
|
| 11 | +import enum |
11 | 12 | import inspect |
12 | 13 | import math |
13 | 14 | import sys |
@@ -485,6 +486,86 @@ def _check_api_version(api_version: str | None) -> None: |
485 | 486 | ) |
486 | 487 |
|
487 | 488 |
|
| 489 | +class _ClsToXPInfo(enum.Enum): |
| 490 | + SCALAR = 0 |
| 491 | + MAYBE_JAX_ZERO_GRADIENT = 1 |
| 492 | + |
| 493 | + |
| 494 | +@lru_cache(100) |
| 495 | +def _cls_to_namespace( |
| 496 | + cls: type, |
| 497 | + api_version: str | None, |
| 498 | + use_compat: bool | None, |
| 499 | +) -> tuple[Namespace | None, _ClsToXPInfo | None]: |
| 500 | + if use_compat not in (None, True, False): |
| 501 | + raise ValueError("use_compat must be None, True, or False") |
| 502 | + _use_compat = use_compat in (None, True) |
| 503 | + cls_ = cast(Hashable, cls) # Make mypy happy |
| 504 | + |
| 505 | + if ( |
| 506 | + _issubclass_fast(cls_, "numpy", "ndarray") |
| 507 | + or _issubclass_fast(cls_, "numpy", "generic") |
| 508 | + ): |
| 509 | + if use_compat is True: |
| 510 | + _check_api_version(api_version) |
| 511 | + from .. import numpy as xp |
| 512 | + elif use_compat is False: |
| 513 | + import numpy as xp # type: ignore[no-redef] |
| 514 | + else: |
| 515 | + # NumPy 2.0+ have __array_namespace__; however they are not |
| 516 | + # yet fully array API compatible. |
| 517 | + from .. import numpy as xp # type: ignore[no-redef] |
| 518 | + return xp, _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT |
| 519 | + |
| 520 | + # Note: this must happen _after_ the test for np.generic, |
| 521 | + # because np.float64 and np.complex128 are subclasses of float and complex. |
| 522 | + if issubclass(cls, int | float | complex | type(None)): |
| 523 | + return None, _ClsToXPInfo.SCALAR |
| 524 | + |
| 525 | + if _issubclass_fast(cls_, "cupy", "ndarray"): |
| 526 | + if _use_compat: |
| 527 | + _check_api_version(api_version) |
| 528 | + from .. import cupy as xp # type: ignore[no-redef] |
| 529 | + else: |
| 530 | + import cupy as xp # type: ignore[no-redef] |
| 531 | + return xp, None |
| 532 | + |
| 533 | + if _issubclass_fast(cls_, "torch", "Tensor"): |
| 534 | + if _use_compat: |
| 535 | + _check_api_version(api_version) |
| 536 | + from .. import torch as xp # type: ignore[no-redef] |
| 537 | + else: |
| 538 | + import torch as xp # type: ignore[no-redef] |
| 539 | + return xp, None |
| 540 | + |
| 541 | + if _issubclass_fast(cls_, "dask.array", "Array"): |
| 542 | + if _use_compat: |
| 543 | + _check_api_version(api_version) |
| 544 | + from ..dask import array as xp # type: ignore[no-redef] |
| 545 | + else: |
| 546 | + import dask.array as xp # type: ignore[no-redef] |
| 547 | + return xp, None |
| 548 | + |
| 549 | + # Backwards compatibility for jax<0.4.32 |
| 550 | + if _issubclass_fast(cls_, "jax", "Array"): |
| 551 | + return _jax_namespace(api_version, use_compat), None |
| 552 | + |
| 553 | + return None, None |
| 554 | + |
| 555 | + |
| 556 | +def _jax_namespace(api_version: str | None, use_compat: bool | None) -> Namespace: |
| 557 | + if use_compat: |
| 558 | + raise ValueError("JAX does not have an array-api-compat wrapper") |
| 559 | + import jax.numpy as jnp |
| 560 | + if not hasattr(jnp, "__array_namespace_info__"): |
| 561 | + # JAX v0.4.32 and newer implements the array API directly in jax.numpy. |
| 562 | + # For older JAX versions, it is available via jax.experimental.array_api. |
| 563 | + # jnp.Array objects gain the __array_namespace__ method. |
| 564 | + import jax.experimental.array_api # noqa: F401 |
| 565 | + # Test api_version |
| 566 | + return jnp.empty(0).__array_namespace__(api_version=api_version) |
| 567 | + |
| 568 | + |
488 | 569 | def array_namespace( |
489 | 570 | *xs: Array | complex | None, |
490 | 571 | api_version: str | None = None, |
@@ -553,105 +634,40 @@ def your_function(x, y): |
553 | 634 | is_pydata_sparse_array |
554 | 635 |
|
555 | 636 | """ |
556 | | - if use_compat not in [None, True, False]: |
557 | | - raise ValueError("use_compat must be None, True, or False") |
558 | | - |
559 | | - _use_compat = use_compat in [None, True] |
560 | | - |
561 | 637 | namespaces: set[Namespace] = set() |
562 | 638 | for x in xs: |
563 | | - if is_numpy_array(x): |
564 | | - import numpy as np |
565 | | - |
566 | | - from .. import numpy as numpy_namespace |
567 | | - |
568 | | - if use_compat is True: |
569 | | - _check_api_version(api_version) |
570 | | - namespaces.add(numpy_namespace) |
571 | | - elif use_compat is False: |
572 | | - namespaces.add(np) |
573 | | - else: |
574 | | - # numpy 2.0+ have __array_namespace__, however, they are not yet fully array API |
575 | | - # compatible. |
576 | | - namespaces.add(numpy_namespace) |
577 | | - elif is_cupy_array(x): |
578 | | - if _use_compat: |
579 | | - _check_api_version(api_version) |
580 | | - from .. import cupy as cupy_namespace |
581 | | - |
582 | | - namespaces.add(cupy_namespace) |
583 | | - else: |
584 | | - import cupy as cp # pyright: ignore[reportMissingTypeStubs] |
585 | | - |
586 | | - namespaces.add(cp) |
587 | | - elif is_torch_array(x): |
588 | | - if _use_compat: |
589 | | - _check_api_version(api_version) |
590 | | - from .. import torch as torch_namespace |
591 | | - |
592 | | - namespaces.add(torch_namespace) |
593 | | - else: |
594 | | - import torch |
595 | | - |
596 | | - namespaces.add(torch) |
597 | | - elif is_dask_array(x): |
598 | | - if _use_compat: |
599 | | - _check_api_version(api_version) |
600 | | - from ..dask import array as dask_namespace |
601 | | - |
602 | | - namespaces.add(dask_namespace) |
603 | | - else: |
604 | | - import dask.array as da |
605 | | - |
606 | | - namespaces.add(da) |
607 | | - elif is_jax_array(x): |
608 | | - if use_compat is True: |
609 | | - _check_api_version(api_version) |
610 | | - raise ValueError("JAX does not have an array-api-compat wrapper") |
611 | | - elif use_compat is False: |
612 | | - import jax.numpy as jnp |
613 | | - else: |
614 | | - # JAX v0.4.32 and newer implements the array API directly in jax.numpy. |
615 | | - # For older JAX versions, it is available via jax.experimental.array_api. |
616 | | - import jax.numpy |
617 | | - |
618 | | - if hasattr(jax.numpy, "__array_api_version__"): |
619 | | - jnp = jax.numpy |
620 | | - else: |
621 | | - import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports] |
622 | | - namespaces.add(jnp) |
623 | | - elif is_pydata_sparse_array(x): |
624 | | - if use_compat is True: |
625 | | - _check_api_version(api_version) |
626 | | - raise ValueError("`sparse` does not have an array-api-compat wrapper") |
627 | | - else: |
628 | | - import sparse # pyright: ignore[reportMissingTypeStubs] |
629 | | - # `sparse` is already an array namespace. We do not have a wrapper |
630 | | - # submodule for it. |
631 | | - namespaces.add(sparse) |
632 | | - elif hasattr(x, "__array_namespace__"): |
633 | | - if use_compat is True: |
| 639 | + xp, info = _cls_to_namespace(cast(Hashable, type(x)), api_version, use_compat) |
| 640 | + if info is _ClsToXPInfo.SCALAR: |
| 641 | + continue |
| 642 | + |
| 643 | + if ( |
| 644 | + info is _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT |
| 645 | + and _is_jax_zero_gradient_array(x) |
| 646 | + ): |
| 647 | + xp = _jax_namespace(api_version, use_compat) |
| 648 | + |
| 649 | + if xp is None: |
| 650 | + get_ns = getattr(x, "__array_namespace__", None) |
| 651 | + if get_ns is None: |
| 652 | + raise TypeError(f"{type(x).__name__} is not a supported array type") |
| 653 | + if use_compat: |
634 | 654 | raise ValueError( |
635 | 655 | "The given array does not have an array-api-compat wrapper" |
636 | 656 | ) |
637 | | - x = cast("SupportsArrayNamespace[Any]", x) |
638 | | - namespaces.add(x.__array_namespace__(api_version=api_version)) |
639 | | - elif isinstance(x, (bool, int, float, complex, type(None))): |
640 | | - continue |
641 | | - else: |
642 | | - # TODO: Support Python scalars? |
643 | | - raise TypeError(f"{type(x).__name__} is not a supported array type") |
| 657 | + xp = get_ns(api_version=api_version) |
644 | 658 |
|
645 | | - if not namespaces: |
646 | | - raise TypeError("Unrecognized array input") |
| 659 | + namespaces.add(xp) |
647 | 660 |
|
648 | | - if len(namespaces) != 1: |
| 661 | + try: |
| 662 | + (xp,) = namespaces |
| 663 | + return xp |
| 664 | + except ValueError: |
| 665 | + if not namespaces: |
| 666 | + raise TypeError( |
| 667 | + "array_namespace requires at least one non-scalar array input" |
| 668 | + ) |
649 | 669 | raise TypeError(f"Multiple namespaces for array inputs: {namespaces}") |
650 | 670 |
|
651 | | - (xp,) = namespaces |
652 | | - |
653 | | - return xp |
654 | | - |
655 | 671 |
|
656 | 672 | # backwards compatibility alias |
657 | 673 | get_namespace = array_namespace |
|
0 commit comments