@@ -361,8 +361,10 @@ def std(x: array,
361361 # https://github.com/pytorch/pytorch/issues/61492. We don't try to
362362 # implement it here for now.
363363
364- # if isinstance(correction, float):
365- # correction = int(correction)
364+ if isinstance (correction , float ):
365+ _correction = int (correction )
366+ if correction != _correction :
367+ raise NotImplementedError ("float correction in torch std() is not yet supported" )
366368
367369 # https://github.com/pytorch/pytorch/issues/29137
368370 if axis == ():
@@ -372,10 +374,10 @@ def std(x: array,
372374 if axis is None :
373375 # torch doesn't support keepdims with axis=None
374376 # (https://github.com/pytorch/pytorch/issues/71209)
375- res = torch .std (x , tuple (range (x .ndim )), correction = correction , ** kwargs )
377+ res = torch .std (x , tuple (range (x .ndim )), correction = _correction , ** kwargs )
376378 res = _axis_none_keepdims (res , x .ndim , keepdims )
377379 return res
378- return torch .std (x , axis , correction = correction , keepdims = keepdims , ** kwargs )
380+ return torch .std (x , axis , correction = _correction , keepdims = keepdims , ** kwargs )
379381
380382def var (x : array ,
381383 / ,
@@ -519,6 +521,28 @@ def full(shape: Union[int, Tuple[int, ...]],
519521
520522 return torch .full (shape , fill_value , dtype = dtype , device = device , ** kwargs )
521523
524+ # ones, zeros, and empty do not accept shape as a keyword argument
525+ def ones (shape : Union [int , Tuple [int , ...]],
526+ * ,
527+ dtype : Optional [Dtype ] = None ,
528+ device : Optional [Device ] = None ,
529+ ** kwargs ) -> array :
530+ return torch .ones (shape , dtype = dtype , device = device , ** kwargs )
531+
532+ def zeros (shape : Union [int , Tuple [int , ...]],
533+ * ,
534+ dtype : Optional [Dtype ] = None ,
535+ device : Optional [Device ] = None ,
536+ ** kwargs ) -> array :
537+ return torch .zeros (shape , dtype = dtype , device = device , ** kwargs )
538+
539+ def empty (shape : Union [int , Tuple [int , ...]],
540+ * ,
541+ dtype : Optional [Dtype ] = None ,
542+ device : Optional [Device ] = None ,
543+ ** kwargs ) -> array :
544+ return torch .empty (shape , dtype = dtype , device = device , ** kwargs )
545+
522546# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
523547def expand_dims (x : array , / , * , axis : int = 0 ) -> array :
524548 return torch .unsqueeze (x , axis )
@@ -585,7 +609,7 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
585609 'logaddexp' , 'multiply' , 'not_equal' , 'pow' , 'remainder' ,
586610 'subtract' , 'max' , 'min' , 'sort' , 'prod' , 'sum' , 'any' , 'all' ,
587611 'mean' , 'std' , 'var' , 'concat' , 'squeeze' , 'flip' , 'roll' ,
588- 'nonzero' , 'where' , 'arange' , 'eye' , 'linspace' , 'full' ,
589- 'expand_dims ' , 'astype ' , 'broadcast_arrays ' , 'unique_all ' ,
590- 'unique_counts' , 'unique_inverse' , 'unique_values' ,
612+ 'nonzero' , 'where' , 'arange' , 'eye' , 'linspace' , 'full' , 'ones' ,
613+ 'zeros ' , 'empty ' , 'expand_dims ' , 'astype' , 'broadcast_arrays ' ,
614+ 'unique_all' , ' unique_counts' , 'unique_inverse' , 'unique_values' ,
591615 'matmul' , 'matrix_transpose' , 'vecdot' , 'tensordot' ]
0 commit comments