@@ -13,8 +13,8 @@ class DistributedArray(np.ndarray):
1313 Shape of non-distributed global array
1414 subcomm : None, Subcomm instance or sequence of ints
1515 Describes how to distribute the array
16- val : int
17- Initialize array with this value if buffer is not given
16+ val : int or None
17+ Initialize array with this int if buffer is not given
1818 dtype : np.dtype
1919 Type of array
2020 buffer : np.ndarray
@@ -23,7 +23,7 @@ class DistributedArray(np.ndarray):
2323 Make sure array is aligned in this direction
2424
2525 """
26- def __new__ (cls , global_shape , subcomm = None , val = 0 , dtype = np .float ,
26+ def __new__ (cls , global_shape , subcomm = None , val = None , dtype = np .float ,
2727 buffer = None , alignment = None ):
2828 if isinstance (subcomm , Subcomm ):
2929 pass
@@ -49,7 +49,7 @@ def __new__(cls, global_shape, subcomm=None, val=0, dtype=np.float,
4949 alignment = np .flatnonzero (np .array (sizes ) == 1 )[- 1 ]
5050 p0 = Pencil (subcomm , global_shape , axis = alignment )
5151 obj = np .ndarray .__new__ (cls , p0 .subshape , dtype = dtype , buffer = buffer )
52- if buffer is None :
52+ if buffer is None and isinstance ( val , int ) :
5353 obj .fill (val )
5454 obj .p0 = p0
5555 obj .global_shape = global_shape
@@ -79,15 +79,14 @@ def redistribute(self, axis):
7979 Returns
8080 -------
8181 DistributedArray
82- self array globally redistributed along new axis
82+ New DistributedArray globally redistributed along axis
8383
8484 """
8585 p1 = self .p0 .pencil (axis )
8686 transfer = self .p0 .transfer (p1 , self .dtype )
87- z0 = np .zeros (p1 .subshape , dtype = self .dtype )
87+ z0 = DistributedArray (self .global_shape ,
88+ subcomm = p1 .subcomm ,
89+ dtype = self .dtype ,
90+ alignment = axis )
8891 transfer .forward (self , z0 )
89- return DistributedArray (self .global_shape ,
90- subcomm = p1 .subcomm ,
91- dtype = self .dtype ,
92- alignment = axis ,
93- buffer = z0 )
92+ return z0
0 commit comments