@@ -44,7 +44,7 @@ def cumulative_sum(
4444 if axis < 0 :
4545 axis += x .ndim
4646 x = concat ([zeros (x .shape [:axis ] + (1 ,) + x .shape [axis + 1 :], dtype = dt ), x ], axis = axis )
47- return Array ._new (np .cumsum (x ._array , axis = axis , dtype = dtype ))
47+ return Array ._new (np .cumsum (x ._array , axis = axis , dtype = dtype ), device = x . device )
4848
4949def max (
5050 x : Array ,
@@ -55,7 +55,7 @@ def max(
5555) -> Array :
5656 if x .dtype not in _real_numeric_dtypes :
5757 raise TypeError ("Only real numeric dtypes are allowed in max" )
58- return Array ._new (np .max (x ._array , axis = axis , keepdims = keepdims ))
58+ return Array ._new (np .max (x ._array , axis = axis , keepdims = keepdims ), device = x . device )
5959
6060
6161def mean (
@@ -67,7 +67,7 @@ def mean(
6767) -> Array :
6868 if x .dtype not in _real_floating_dtypes :
6969 raise TypeError ("Only real floating-point dtypes are allowed in mean" )
70- return Array ._new (np .mean (x ._array , axis = axis , keepdims = keepdims ))
70+ return Array ._new (np .mean (x ._array , axis = axis , keepdims = keepdims ), device = x . device )
7171
7272
7373def min (
@@ -79,7 +79,7 @@ def min(
7979) -> Array :
8080 if x .dtype not in _real_numeric_dtypes :
8181 raise TypeError ("Only real numeric dtypes are allowed in min" )
82- return Array ._new (np .min (x ._array , axis = axis , keepdims = keepdims ))
82+ return Array ._new (np .min (x ._array , axis = axis , keepdims = keepdims ), device = x . device )
8383
8484
8585def prod (
@@ -104,7 +104,7 @@ def prod(
104104 dtype = np .complex128
105105 else :
106106 dtype = dtype ._np_dtype
107- return Array ._new (np .prod (x ._array , dtype = dtype , axis = axis , keepdims = keepdims ))
107+ return Array ._new (np .prod (x ._array , dtype = dtype , axis = axis , keepdims = keepdims ), device = x . device )
108108
109109
110110def std (
@@ -118,7 +118,7 @@ def std(
118118 # Note: the keyword argument correction is different here
119119 if x .dtype not in _real_floating_dtypes :
120120 raise TypeError ("Only real floating-point dtypes are allowed in std" )
121- return Array ._new (np .std (x ._array , axis = axis , ddof = correction , keepdims = keepdims ))
121+ return Array ._new (np .std (x ._array , axis = axis , ddof = correction , keepdims = keepdims ), device = x . device )
122122
123123
124124def sum (
@@ -143,7 +143,7 @@ def sum(
143143 dtype = np .complex128
144144 else :
145145 dtype = dtype ._np_dtype
146- return Array ._new (np .sum (x ._array , axis = axis , dtype = dtype , keepdims = keepdims ))
146+ return Array ._new (np .sum (x ._array , axis = axis , dtype = dtype , keepdims = keepdims ), device = x . device )
147147
148148
149149def var (
@@ -157,4 +157,4 @@ def var(
157157 # Note: the keyword argument correction is different here
158158 if x .dtype not in _real_floating_dtypes :
159159 raise TypeError ("Only real floating-point dtypes are allowed in var" )
160- return Array ._new (np .var (x ._array , axis = axis , ddof = correction , keepdims = keepdims ))
160+ return Array ._new (np .var (x ._array , axis = axis , ddof = correction , keepdims = keepdims ), device = x . device )
0 commit comments