77from typing import TYPE_CHECKING
88if TYPE_CHECKING :
99 from typing import Optional , Tuple , Union
10- from numpy import ndarray , dtype
10+ from . _typing import ndarray , Device , Dtype , NestedSequence , SupportsBufferProtocol
1111
1212from typing import NamedTuple
1313
@@ -107,7 +107,7 @@ def unique_values(x: ndarray, /) -> ndarray:
107107 equal_nan = False ,
108108 )
109109
110- def astype (x : ndarray , dtype : dtype , / , * , copy : bool = True ) -> ndarray :
110+ def astype (x : ndarray , dtype : Dtype , / , * , copy : bool = True ) -> ndarray :
111111 if not copy and dtype == x .dtype :
112112 return x
113113 return x .astype (dtype = dtype , copy = copy )
@@ -138,6 +138,136 @@ def var(
138138def permute_dims (x : ndarray , / , axes : Tuple [int , ...]) -> ndarray :
139139 return np .transpose (x , axes )
140140
141+ # Creation functions add the device keyword (which does nothing for NumPy)
142+
143+ def _check_device (device ):
144+ if device not in ["cpu" , None ]:
145+ raise ValueError (f"Unsupported device { device !r} " )
146+
147+ def asarray (
148+ obj : Union [
149+ ndarray ,
150+ bool ,
151+ int ,
152+ float ,
153+ NestedSequence [bool | int | float ],
154+ SupportsBufferProtocol ,
155+ ],
156+ / ,
157+ * ,
158+ dtype : Optional [Dtype ] = None ,
159+ device : Optional [Device ] = None ,
160+ copy : Optional [Union [bool , np ._CopyMode ]] = None ,
161+ ) -> ndarray :
162+ _check_device (device )
163+ if copy in (False , np ._CopyMode .IF_NEEDED ):
164+ # copy=False is not yet implemented in np.asarray
165+ raise NotImplementedError ("copy=False is not yet implemented" )
166+ return np .asarray (obj , dtype = dtype )
167+
168+ def arange (
169+ start : Union [int , float ],
170+ / ,
171+ stop : Optional [Union [int , float ]] = None ,
172+ step : Union [int , float ] = 1 ,
173+ * ,
174+ dtype : Optional [Dtype ] = None ,
175+ device : Optional [Device ] = None ,
176+ ) -> ndarray :
177+ _check_device (device )
178+ return np .arange (start , stop = stop , step = step , dtype = dtype )
179+
180+ def empty (
181+ shape : Union [int , Tuple [int , ...]],
182+ * ,
183+ dtype : Optional [Dtype ] = None ,
184+ device : Optional [Device ] = None ,
185+ ) -> ndarray :
186+ _check_device (device )
187+ return np .empty (shape , dtype = dtype )
188+
189+ def empty_like (
190+ x : ndarray , / , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
191+ ) -> ndarray :
192+ _check_device (device )
193+ return np .empty_like (x , dtype = dtype )
194+
195+ def eye (
196+ n_rows : int ,
197+ n_cols : Optional [int ] = None ,
198+ / ,
199+ * ,
200+ k : int = 0 ,
201+ dtype : Optional [Dtype ] = None ,
202+ device : Optional [Device ] = None ,
203+ ) -> ndarray :
204+ _check_device (device )
205+ return np .eye (n_rows , M = n_cols , k = k , dtype = dtype )
206+
207+ def full (
208+ shape : Union [int , Tuple [int , ...]],
209+ fill_value : Union [int , float ],
210+ * ,
211+ dtype : Optional [Dtype ] = None ,
212+ device : Optional [Device ] = None ,
213+ ) -> ndarray :
214+ _check_device (device )
215+ return np .full (shape , fill_value , dtype = dtype )
216+
217+ def full_like (
218+ x : ndarray ,
219+ / ,
220+ fill_value : Union [int , float ],
221+ * ,
222+ dtype : Optional [Dtype ] = None ,
223+ device : Optional [Device ] = None ,
224+ ) -> ndarray :
225+ _check_device (device )
226+ return np .full_like (x , fill_value , dtype = dtype )
227+
228+ def linspace (
229+ start : Union [int , float ],
230+ stop : Union [int , float ],
231+ / ,
232+ num : int ,
233+ * ,
234+ dtype : Optional [Dtype ] = None ,
235+ device : Optional [Device ] = None ,
236+ endpoint : bool = True ,
237+ ) -> ndarray :
238+ _check_device (device )
239+ return np .linspace (start , stop , num , dtype = dtype , endpoint = endpoint )
240+
241+ def ones (
242+ shape : Union [int , Tuple [int , ...]],
243+ * ,
244+ dtype : Optional [Dtype ] = None ,
245+ device : Optional [Device ] = None ,
246+ ) -> ndarray :
247+ _check_device (device )
248+ return np .ones (shape , dtype = dtype )
249+
250+ def ones_like (
251+ x : ndarray , / , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
252+ ) -> ndarray :
253+ _check_device (device )
254+ return np .ones_like (x , dtype = dtype )
255+
256+ def zeros (
257+ shape : Union [int , Tuple [int , ...]],
258+ * ,
259+ dtype : Optional [Dtype ] = None ,
260+ device : Optional [Device ] = None ,
261+ ) -> ndarray :
262+ _check_device (device )
263+ return np .zeros (shape , dtype = dtype )
264+
265+ def zeros_like (
266+ x : ndarray , / , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
267+ ) -> ndarray :
268+ _check_device (device )
269+ return np .zeros_like (x , dtype = dtype )
270+
141271# from numpy import * doesn't overwrite these builtin names
142272from numpy import abs , max , min , round
143273
@@ -146,4 +276,6 @@ def permute_dims(x: ndarray, /, axes: Tuple[int, ...]) -> ndarray:
146276 'bool' , 'concat' , 'pow' , 'UniqueAllResult' , 'UniqueCountsResult' ,
147277 'UniqueInverseResult' , 'unique_all' , 'unique_counts' ,
148278 'unique_inverse' , 'unique_values' , 'astype' , 'abs' , 'max' , 'min' ,
149- 'round' , 'std' , 'var' , 'permute_dims' ]
279+ 'round' , 'std' , 'var' , 'permute_dims' , 'asarray' , 'arange' ,
280+ 'empty' , 'empty_like' , 'eye' , 'full' , 'full_like' , 'linspace' ,
281+ 'ones' , 'ones_like' , 'zeros' , 'zeros_like' ]
0 commit comments