1+ # pyright: reportPrivateUsage=false
12from __future__ import annotations
23
3- from typing import Optional , Union
4+ from builtins import bool as py_bool
5+ from typing import TYPE_CHECKING , cast
6+
7+ import numpy as np
48
59from .._internal import get_xp
610from ..common import _aliases
711from ..common ._typing import NestedSequence , SupportsBufferProtocol
812from ._info import __array_namespace_info__
913from ._typing import Array , Device , DType
1014
11- import numpy as np
15+ if TYPE_CHECKING :
16+ from typing import Any , Literal , TypeAlias
17+
18+ from typing_extensions import Buffer , TypeIs
19+
20+ _Copy : TypeAlias = py_bool | Literal [2 ] | np ._CopyMode
1221
1322bool = np .bool_
1423
6372sign = get_xp (np )(_aliases .sign )
6473
6574
66- def _supports_buffer_protocol (obj ):
75+ def _supports_buffer_protocol (obj : object ) -> TypeIs [ Buffer ]: # pyright: ignore[reportUnusedFunction]
6776 try :
68- memoryview (obj )
77+ memoryview (obj ) # pyright: ignore[reportArgumentType]
6978 except TypeError :
7079 return False
7180 return True
@@ -76,15 +85,13 @@ def _supports_buffer_protocol(obj):
7685# complicated enough that it's easier to define it separately for each module
7786# rather than trying to combine everything into one function in common/
7887def asarray (
79- obj : (
80- Array | bool | complex | NestedSequence [bool | complex ] | SupportsBufferProtocol
81- ),
88+ obj : Array | complex | NestedSequence [complex ] | SupportsBufferProtocol ,
8289 / ,
8390 * ,
84- dtype : Optional [ DType ] = None ,
85- device : Optional [ Device ] = None ,
86- copy : "Optional[Union[bool, np._CopyMode]]" = None ,
87- ** kwargs ,
91+ dtype : DType | None = None ,
92+ device : Device | None = None ,
93+ copy : _Copy | None = None ,
94+ ** kwargs : Any ,
8895) -> Array :
8996 """
9097 Array API compatibility wrapper for asarray().
@@ -108,24 +115,28 @@ def asarray(
108115 if copy is False :
109116 raise NotImplementedError ("asarray(copy=False) requires a newer version of NumPy." )
110117
111- return np .array (obj , copy = copy , dtype = dtype , ** kwargs )
118+ return np .array (obj , copy = copy , dtype = dtype , ** kwargs ) # pyright: ignore
112119
113120
114121def astype (
115122 x : Array ,
116123 dtype : DType ,
117124 / ,
118125 * ,
119- copy : bool = True ,
120- device : Optional [ Device ] = None ,
126+ copy : py_bool = True ,
127+ device : Device | None = None ,
121128) -> Array :
122129 return x .astype (dtype = dtype , copy = copy )
123130
124131
125132# count_nonzero returns a python int for axis=None and keepdims=False
126133# https://github.com/numpy/numpy/issues/17562
127- def count_nonzero (x : Array , axis = None , keepdims = False ) -> Array :
128- result = np .count_nonzero (x , axis = axis , keepdims = keepdims )
134+ def count_nonzero (
135+ x : Array ,
136+ axis : int | tuple [int , ...] | None = None ,
137+ keepdims : py_bool = False ,
138+ ) -> Array :
139+ result = cast ("Any" , np .count_nonzero (x , axis = axis , keepdims = keepdims )) # pyright: ignore
129140 if axis is None and not keepdims :
130141 return np .asarray (result )
131142 return result
@@ -148,10 +159,25 @@ def count_nonzero(x: Array, axis=None, keepdims=False) -> Array:
148159else :
149160 unstack = get_xp (np )(_aliases .unstack )
150161
151- __all__ = _aliases .__all__ + ['__array_namespace_info__' , 'asarray' , 'astype' ,
152- 'acos' , 'acosh' , 'asin' , 'asinh' , 'atan' ,
153- 'atan2' , 'atanh' , 'bitwise_left_shift' ,
154- 'bitwise_invert' , 'bitwise_right_shift' ,
155- 'bool' , 'concat' , 'count_nonzero' , 'pow' ]
162+ __all__ = [
163+ "__array_namespace_info__" ,
164+ "asarray" ,
165+ "astype" ,
166+ "acos" ,
167+ "acosh" ,
168+ "asin" ,
169+ "asinh" ,
170+ "atan" ,
171+ "atan2" ,
172+ "atanh" ,
173+ "bitwise_left_shift" ,
174+ "bitwise_invert" ,
175+ "bitwise_right_shift" ,
176+ "bool" ,
177+ "concat" ,
178+ "count_nonzero" ,
179+ "pow" ,
180+ ]
181+ __all__ += _aliases .__all__
156182
157183_all_ignore = ['np' , 'get_xp' ]
0 commit comments