@@ -273,72 +273,71 @@ def get_pencil_and_transfer(self, axis):
273273 p1 = self ._p0 .pencil (axis )
274274 return p1 , self ._p0 .transfer (p1 , self .dtype )
275275
276- def redistribute (self , axis = None , darray = None ):
276+ def redistribute (self , axis = None , out = None ):
277277 """Global redistribution of local ``self`` array
278278
279- Note
280- ----
281- Use either ``axis`` or ``darray``, not both.
282-
283279 Parameters
284280 ----------
285281 axis : int, optional
286282 Align local ``self`` array along this axis
287- darray : :class:`.DistArray`, optional
283+ out : :class:`.DistArray`, optional
288284 Copy data to this array of possibly different alignment
289285
290286 Returns
291287 -------
292- DistArray : darray
293- The ``self`` array globally redistributed. If keyword ``darray `` is
288+ DistArray : out
289+ The ``self`` array globally redistributed. If keyword ``out `` is
294290 None then a new DistArray (aligned along ``axis``) is created
295- and returned. Otherwise the provided darray is returned.
291+ and returned. Otherwise the provided out array is returned.
296292 """
297293 # Take care of some trivial cases first
298294 if axis == self .alignment :
299295 return self
300296
297+ if axis is not None and isinstance (out , DistArray ):
298+ assert axis == out .alignment
299+
301300 # Check if self is already aligned along axis. In that case just switch
302301 # axis of pencil (both axes are undivided) and return
303302 if axis is not None :
304303 if self .commsizes [self .rank + axis ] == 1 :
305304 self ._p0 .axis = axis
306305 return self
307306
308- if axis is None : # darray interface
309- assert isinstance (darray , np . ndarray )
310- assert self .global_shape == darray .global_shape
311- axis = darray .alignment
312- if self .commsizes == darray .commsizes :
307+ if out is not None :
308+ assert isinstance (out , DistArray )
309+ assert self .global_shape == out .global_shape
310+ axis = out .alignment
311+ if self .commsizes == out .commsizes :
313312 # Just a copy required. Should probably not be here
314- darray [:] = self
315- return darray
313+ out [:] = self
314+ return out
316315
317316 # Check that arrays are compatible
318317 for i in range (len (self ._p0 .shape )):
319- if i != self ._p0 .axis and i != darray ._p0 .axis :
320- assert self ._p0 .subcomm [i ] == darray ._p0 .subcomm [i ]
321- assert self ._p0 .subshape [i ] == darray ._p0 .subshape [i ]
318+ if i != self ._p0 .axis and i != out ._p0 .axis :
319+ assert self ._p0 .subcomm [i ] == out ._p0 .subcomm [i ]
320+ assert self ._p0 .subshape [i ] == out ._p0 .subshape [i ]
322321
323322 p1 , transfer = self .get_pencil_and_transfer (axis )
324- if darray is None :
325- darray = DistArray (self .global_shape ,
323+ if out is None :
324+ out = DistArray (self .global_shape ,
326325 subcomm = p1 .subcomm ,
327326 dtype = self .dtype ,
328327 alignment = axis ,
329328 rank = self .rank )
330329
331330 if self .rank == 0 :
332- transfer .forward (self , darray )
331+ transfer .forward (self , out )
333332 elif self .rank == 1 :
334333 for i in range (self .shape [0 ]):
335- transfer .forward (self [i ], darray [i ])
334+ transfer .forward (self [i ], out [i ])
336335 elif self .rank == 2 :
337336 for i in range (self .shape [0 ]):
338337 for j in range (self .shape [1 ]):
339- transfer .forward (self [i , j ], darray [i , j ])
338+ transfer .forward (self [i , j ], out [i , j ])
340339
341- return darray
340+ return out
342341
343342def newDistArray (pfft , forward_output = True , val = 0 , rank = 0 , view = False ):
344343 """Return a new :class:`.DistArray` object for provided :class:`.PFFT` object
0 commit comments