@@ -310,6 +310,32 @@ def sort(
310310 res = np .flip (res , axis = axis )
311311 return res
312312
313+ # sum() and prod() should always upcast when dtype=None
314+ def sum (
315+ x : ndarray ,
316+ / ,
317+ * ,
318+ axis : Optional [Union [int , Tuple [int , ...]]] = None ,
319+ dtype : Optional [Dtype ] = None ,
320+ keepdims : bool = False ,
321+ ) -> ndarray :
322+ # `np.sum` already upcasts integers, but not floats
323+ if dtype is None and x .dtype == np .float32 :
324+ dtype = np .float64
325+ return np .sum (x , axis = axis , dtype = dtype , keepdims = keepdims )
326+
327+ def prod (
328+ x : ndarray ,
329+ / ,
330+ * ,
331+ axis : Optional [Union [int , Tuple [int , ...]]] = None ,
332+ dtype : Optional [Dtype ] = None ,
333+ keepdims : bool = False ,
334+ ) -> ndarray :
335+ if dtype is None and x .dtype == np .float32 :
336+ dtype = np .float64
337+ return np .prod (x , dtype = dtype , axis = axis , keepdims = keepdims )
338+
313339# from numpy import * doesn't overwrite these builtin names
314340from numpy import abs , max , min , round
315341
@@ -321,4 +347,4 @@ def sort(
321347 'round' , 'std' , 'var' , 'permute_dims' , 'asarray' , 'arange' ,
322348 'empty' , 'empty_like' , 'eye' , 'full' , 'full_like' , 'linspace' ,
323349 'ones' , 'ones_like' , 'zeros' , 'zeros_like' , 'reshape' , 'argsort' ,
324- 'sort' ]
350+ 'sort' , 'sum' , 'prod' ]
0 commit comments