@@ -144,6 +144,7 @@ def _check_device(device):
144144 if device not in ["cpu" , None ]:
145145 raise ValueError (f"Unsupported device { device !r} " )
146146
147+ # asarray also adds the copy keyword
147148def asarray (
148149 obj : Union [
149150 ndarray ,
@@ -336,6 +337,23 @@ def prod(
336337 dtype = np .float64
337338 return np .prod (x , dtype = dtype , axis = axis , keepdims = keepdims )
338339
340+ # ceil, floor, and trunc return integers for integer inputs
341+
342+ def ceil (x : ndarray , / ) -> ndarray :
343+ if np .issubdtype (x .dtype , np .integer ):
344+ return x
345+ return np .ceil (x )
346+
347+ def floor (x : ndarray , / ) -> ndarray :
348+ if np .issubdtype (x .dtype , np .integer ):
349+ return x
350+ return np .floor (x )
351+
352+ def trunc (x : ndarray , / ) -> ndarray :
353+ if np .issubdtype (x .dtype , np .integer ):
354+ return x
355+ return np .trunc (x )
356+
339357# from numpy import * doesn't overwrite these builtin names
340358from numpy import abs , max , min , round
341359
@@ -347,4 +365,4 @@ def prod(
347365 'round' , 'std' , 'var' , 'permute_dims' , 'asarray' , 'arange' ,
348366 'empty' , 'empty_like' , 'eye' , 'full' , 'full_like' , 'linspace' ,
349367 'ones' , 'ones_like' , 'zeros' , 'zeros_like' , 'reshape' , 'argsort' ,
350- 'sort' , 'sum' , 'prod' ]
368+ 'sort' , 'sum' , 'prod' , 'ceil' , 'floor' , 'trunc' ]
0 commit comments