Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions array_api_compat/common/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .._internal import get_xp
from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
from ._typing import Array, DType, Namespace
from ._typing import Array, DType, JustFloat, JustInt, Namespace


# These are in the main NumPy namespace but not in numpy.linalg
Expand Down Expand Up @@ -139,7 +139,7 @@ def matrix_norm(
xp: Namespace,
*,
keepdims: bool = False,
ord: float | Literal["fro", "nuc"] | None = "fro",
ord: JustInt | JustFloat | Literal["fro", "nuc"] | None = "fro",
) -> Array:
return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)

Expand All @@ -155,7 +155,7 @@ def vector_norm(
*,
axis: int | tuple[int, ...] | None = None,
keepdims: bool = False,
ord: float = 2,
ord: JustInt | JustFloat = 2,
) -> Array:
# xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
# when axis=None and the input is 2-D, so to force a vector norm, we make
Expand Down
44 changes: 43 additions & 1 deletion array_api_compat/common/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@

from collections.abc import Mapping
from types import ModuleType as Namespace
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, TypeVar
from typing import (
TYPE_CHECKING,
Literal,
Protocol,
TypeAlias,
TypedDict,
TypeVar,
final,
)

if TYPE_CHECKING:
from _typeshed import Incomplete
Expand All @@ -21,6 +29,37 @@
_T_co = TypeVar("_T_co", covariant=True)


# These "Just" types are equivalent to the `Just` type from the `optype` library,
# apart from them not being `@runtime_checkable`.
# - docs: https://github.com/jorenham/optype/blob/master/README.md#just
# - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py
@final
class JustInt(Protocol):
@property
def __class__(self, /) -> type[int]: ...
@__class__.setter
def __class__(self, value: type[int], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]


@final
class JustFloat(Protocol):
@property
def __class__(self, /) -> type[float]: ...
@__class__.setter
def __class__(self, value: type[float], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]


@final
class JustComplex(Protocol):
@property
def __class__(self, /) -> type[complex]: ...
@__class__.setter
def __class__(self, value: type[complex], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]


#


class NestedSequence(Protocol[_T_co]):
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
def __len__(self, /) -> int: ...
Expand Down Expand Up @@ -140,6 +179,9 @@ class DTypesAll(DTypesBool, DTypesNumeric):
"Device",
"HasShape",
"Namespace",
"JustInt",
"JustFloat",
"JustComplex",
"NestedSequence",
"SupportsArrayNamespace",
"SupportsBufferProtocol",
Expand Down
8 changes: 6 additions & 2 deletions array_api_compat/torch/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot
from ._typing import Array, DType
from ..common._typing import JustInt, JustFloat

# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
Expand Down Expand Up @@ -84,8 +85,8 @@ def vector_norm(
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
# float stands for inf | -inf, which are not valid for Literal
ord: Union[int, float] = 2,
# JustFloat stands for inf | -inf, which are not valid for Literal
ord: JustInt | JustFloat = 2,
**kwargs,
) -> Array:
# torch.vector_norm incorrectly treats axis=() the same as axis=None
Expand Down Expand Up @@ -115,3 +116,6 @@ def vector_norm(
_all_ignore = ['torch_linalg', 'sum']

del linalg_all

def __dir__() -> list[str]:
return __all__
Loading