Skip to content

Commit 47344d7

Browse files
committed
Make returned DistributedArray own data
1 parent 1b60ad0 commit 47344d7

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

mpi4py_fft/distributedarray.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ class DistributedArray(np.ndarray):
1313
Shape of non-distributed global array
1414
subcomm : None, Subcomm instance or sequence of ints
1515
Describes how to distribute the array
16-
val : int
17-
Initialize array with this value if buffer is not given
16+
val : int or None
17+
Initialize array with this int if buffer is not given
1818
dtype : np.dtype
1919
Type of array
2020
buffer : np.ndarray
@@ -23,7 +23,7 @@ class DistributedArray(np.ndarray):
2323
Make sure array is aligned in this direction
2424
2525
"""
26-
def __new__(cls, global_shape, subcomm=None, val=0, dtype=np.float,
26+
def __new__(cls, global_shape, subcomm=None, val=None, dtype=np.float,
2727
buffer=None, alignment=None):
2828
if isinstance(subcomm, Subcomm):
2929
pass
@@ -49,7 +49,7 @@ def __new__(cls, global_shape, subcomm=None, val=0, dtype=np.float,
4949
alignment = np.flatnonzero(np.array(sizes) == 1)[-1]
5050
p0 = Pencil(subcomm, global_shape, axis=alignment)
5151
obj = np.ndarray.__new__(cls, p0.subshape, dtype=dtype, buffer=buffer)
52-
if buffer is None:
52+
if buffer is None and isinstance(val, int):
5353
obj.fill(val)
5454
obj.p0 = p0
5555
obj.global_shape = global_shape
@@ -79,15 +79,14 @@ def redistribute(self, axis):
7979
Returns
8080
-------
8181
DistributedArray
82-
self array globally redistributed along new axis
82+
New DistributedArray globally redistributed along axis
8383
8484
"""
8585
p1 = self.p0.pencil(axis)
8686
transfer = self.p0.transfer(p1, self.dtype)
87-
z0 = np.zeros(p1.subshape, dtype=self.dtype)
87+
z0 = DistributedArray(self.global_shape,
88+
subcomm=p1.subcomm,
89+
dtype=self.dtype,
90+
alignment=axis)
8891
transfer.forward(self, z0)
89-
return DistributedArray(self.global_shape,
90-
subcomm=p1.subcomm,
91-
dtype=self.dtype,
92-
alignment=axis,
93-
buffer=z0)
92+
return z0

0 commit comments

Comments
 (0)