@@ -10,19 +10,19 @@ class DistArray(np.ndarray):
     """Distributed Numpy array
 
     This Numpy array is part of a larger global array. Information about the
-    distribution is contained in the attributes
+    distribution is contained in the attributes.
 
     Parameters
     ----------
     global_shape : sequence of ints
         Shape of non-distributed global array
-    subcomm : None, Subcomm instance or sequence of ints, optional
+    subcomm : None, :class:`.Subcomm` object or sequence of ints, optional
         Describes how to distribute the array
     val : Number or None, optional
         Initialize array with this number if buffer is not given
     dtype : np.dtype, optional
         Type of array
-    buffer : np.ndarray, optional
+    buffer : Numpy array, optional
         Array of correct shape
     alignment : None or int, optional
         Make sure array is aligned in this direction. Note that alignment does
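
A minimal construction sketch for the parameters documented in this hunk. The shape, dtype and alignment values are illustrative, and it is assumed that ``mpi4py_fft`` exports ``DistArray`` at the top level; run under MPI (e.g. mpirun) with mpi4py installed:

    from mpi4py import MPI
    import numpy as np
    from mpi4py_fft import DistArray

    # A scalar (rank-0) field over an illustrative 16**3 global grid,
    # zero-initialized and aligned (kept non-distributed) along the first axis.
    a = DistArray((16, 16, 16), val=0, dtype=np.float64, alignment=0)
    print(MPI.COMM_WORLD.Get_rank(), a.shape, a.alignment, a.rank)
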
@@ -100,7 +100,13 @@ def __array_finalize__(self, obj):
 
     @property
     def alignment(self):
-        """Return alignment of local ``self`` array"""
+        """Return alignment of local ``self`` array
+
+        Note
+        ----
+        For tensors of rank > 0 the array is actually aligned along
+        ``alignment+rank``
+        """
         return self._p0.axis
 
     @property
@@ -130,7 +136,7 @@ def pencil(self):
 
     @property
     def rank(self):
-        """Return rank of ``self``"""
+        """Return tensor rank of ``self``"""
         return self._rank
 
     def __getitem__(self, i):
@@ -191,9 +197,9 @@ def get_global_slice(self, gslice):
         s = self.local_slice()
         sp = np.nonzero([isinstance(x, slice) for x in gslice])[0]
         sf = tuple(np.take(s, sp))
-        N = self.global_shape
-        f.require_dataset('0', shape=tuple(np.take(N, sp)), dtype=self.dtype)
+        f.require_dataset('data', shape=tuple(np.take(self.global_shape, sp)), dtype=self.dtype)
         gslice = list(gslice)
+        # We need to check whether the integer indices in si are on this processor
         si = np.nonzero([isinstance(x, int) and not z == slice(None) for x, z in zip(gslice, s)])[0]
         on_this_proc = True
         for i in si:
@@ -202,12 +208,12 @@ def get_global_slice(self, gslice):
             else:
                 on_this_proc = False
         if on_this_proc:
-            f["0"][sf] = self[tuple(gslice)]
+            f["data"][sf] = self[tuple(gslice)]
         f.close()
         c = None
         if comm.Get_rank() == 0:
             h = h5py.File('tmp.h5', 'r')
-            c = h['0'].__array__()
+            c = h['data'].__array__()
             h.close()
             os.remove('tmp.h5')
         return c
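
A short usage sketch of the renamed ``'data'`` dataset path above, continuing the illustrative ``a`` from the earlier sketch. It assumes h5py with MPI support is available, since the data is staged through a temporary HDF5 file; the gathered global plane is returned on rank 0 and ``None`` on every other rank:

    # Gather the global plane a[0, :, :]; all ranks should take part in the call.
    plane = a.get_global_slice((0, slice(None), slice(None)))
    if MPI.COMM_WORLD.Get_rank() == 0:
        print(plane.shape)  # (16, 16) for the illustrative global shape
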
@@ -279,10 +285,10 @@ def redistribute(self, axis=None, darray=None):
 
         Returns
         -------
-        :class:`.DistArray` : darray
+        DistArray : darray
             The ``self`` array globally redistributed. If keyword ``darray`` is
             None then a new DistArray (aligned along ``axis``) is created
-            and returned
+            and returned. Otherwise the provided darray is returned.
         """
         if axis is None:
             assert isinstance(darray, np.ndarray)
@@ -308,7 +314,7 @@ def redistribute(self, axis=None, darray=None):
         return darray
 
 def newDistArray(pfft, forward_output=True, val=0, rank=0, view=False):
-    """Return a :class:`.DistArray` for provided :class:`.PFFT` object
+    """Return a new :class:`.DistArray` object for the provided :class:`.PFFT` object
 
     Parameters
     ----------
@@ -317,15 +323,21 @@ def newDistArray(pfft, forward_output=True, val=0, rank=0, view=False):
         If False then create newDistArray of shape/type for input to
         forward transform, otherwise create newDistArray of shape/type for
         output from forward transform.
-    val : int or float
+    val : int or float, optional
         Value used to initialize array.
-    rank: int
+    rank : int, optional
         Scalar has rank 0, vector 1 and matrix 2
-    view : bool
+    view : bool, optional
         If True return view of the underlying Numpy array, i.e., return
         cls.view(np.ndarray). Note that the DistArray still will
         be accessible through the base attribute of the view.
 
+    Returns
+    -------
+    DistArray
+        A new :class:`.DistArray` object. Return the ``ndarray`` view if
+        keyword ``view`` is True.
+
     Examples
     --------
     >>> from mpi4py import MPI
@@ -335,17 +347,13 @@ def newDistArray(pfft, forward_output=True, val=0, rank=0, view=False):
     >>> u_hat = newDistArray(FFT, True, rank=1)
 
     """
+    global_shape = pfft.global_shape(forward_output)
+    p0 = pfft.pencil[forward_output]
     if forward_output is True:
-        shape = pfft.forward.output_array.shape
         dtype = pfft.forward.output_array.dtype
-        p0 = pfft.pencil[1]
     else:
-        shape = pfft.forward.input_array.shape
         dtype = pfft.forward.input_array.dtype
-        p0 = pfft.pencil[0]
-    commsizes = [s.Get_size() for s in p0.subcomm]
-    global_shape = tuple([s*p for s, p in zip(shape, commsizes)])
-    global_shape = (len(shape),)*rank + global_shape
+    global_shape = (len(global_shape),)*rank + global_shape
    z = DistArray(global_shape, subcomm=p0.subcomm, val=val, dtype=dtype,
                   rank=rank)
     return z.v if view else z
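
A hedged end-to-end sketch of the simplified helper (the transform shape is illustrative): ``newDistArray`` now obtains the global shape and pencil directly from the PFFT object via ``pfft.global_shape(forward_output)`` and ``pfft.pencil[forward_output]`` instead of reconstructing them from local shapes and subcommunicator sizes:

    from mpi4py import MPI
    from mpi4py_fft import PFFT, newDistArray

    FFT = PFFT(MPI.COMM_WORLD, [8, 8, 8])

    # Input to the forward transform: scalar (rank-0) field, zero-initialized.
    u = newDistArray(FFT, forward_output=False)

    # Output of the forward transform as a rank-1 (vector) field.
    u_hat = newDistArray(FFT, forward_output=True, rank=1)

    # Plain ndarray view; the DistArray remains reachable through .base.
    v = newDistArray(FFT, forward_output=False, view=True)

    # Globally redistribute the input so it is aligned along axis 0.
    u0 = u.redistribute(0)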