|
| 1 | +import numpy as np |
| 2 | +from mpi4py import MPI |
| 3 | +from .pencil import Pencil, Subcomm |
| 4 | + |
| 5 | +comm = MPI.COMM_WORLD |
| 6 | + |
| 7 | +class DistributedArray(np.ndarray): |
| 8 | + """Distributed Numpy array |
| 9 | +
|
| 10 | + Parameters |
| 11 | + ---------- |
| 12 | + global_shape : sequence of ints |
| 13 | + Shape of non-distributed global array |
| 14 | + subcomm : None, Subcomm instance or sequence of ints |
| 15 | + Describes how to distribute the array |
| 16 | + val : int |
| 17 | + Initialize array with this value if buffer is not given |
| 18 | + dtype : np.dtype |
| 19 | + Type of array |
| 20 | + buffer : np.ndarray |
| 21 | + Array of correct shape |
| 22 | + alignment : None or int |
| 23 | + Make sure array is aligned in this direction |
| 24 | +
|
| 25 | + """ |
| 26 | + def __new__(cls, global_shape, subcomm=None, val=0, dtype=np.float, |
| 27 | + buffer=None, alignment=None): |
| 28 | + if isinstance(subcomm, Subcomm): |
| 29 | + pass |
| 30 | + else: |
| 31 | + if isinstance(subcomm, (tuple, list)): |
| 32 | + assert len(subcomm) == len(global_shape) |
| 33 | + # Do nothing if already containing communicators. A tuple of subcommunicators is not necessarily a Subcomm |
| 34 | + if not np.all([isinstance(s, MPI.Cartcomm) for s in subcomm]): |
| 35 | + subcomm = Subcomm(comm, subcomm) |
| 36 | + else: |
| 37 | + assert subcomm is None |
| 38 | + if alignment is not None: |
| 39 | + subcomm = [0] * len(global_shape) |
| 40 | + subcomm[alignment] = 1 |
| 41 | + subcomm = Subcomm(comm, subcomm) |
| 42 | + |
| 43 | + sizes = [s.Get_size() for s in subcomm] |
| 44 | + if alignment is not None: |
| 45 | + assert isinstance(alignment, int) |
| 46 | + assert sizes[alignment] == 1 |
| 47 | + else: |
| 48 | + # Decide that alignment is the last axis with size 1 |
| 49 | + alignment = np.flatnonzero(np.array(sizes) == 1)[-1] |
| 50 | + p0 = Pencil(subcomm, global_shape, axis=alignment) |
| 51 | + obj = np.ndarray.__new__(cls, p0.subshape, dtype=dtype, buffer=buffer) |
| 52 | + if buffer is None: |
| 53 | + obj.fill(val) |
| 54 | + obj.p0 = p0 |
| 55 | + obj.global_shape = global_shape |
| 56 | + return obj |
| 57 | + |
| 58 | + def alignment(self): |
| 59 | + return self.p0.axis |
| 60 | + |
| 61 | + def __array_finalize__(self, obj): |
| 62 | + if obj is None: |
| 63 | + return |
| 64 | + self.p0 = getattr(obj, 'p0', None) |
| 65 | + self.global_shape = getattr(obj, 'global_shape', None) |
| 66 | + |
| 67 | + def redistribute(self, axis): |
| 68 | + """Global redistribution of array into alignment in ``axis`` |
| 69 | +
|
| 70 | + Parameters |
| 71 | + ---------- |
| 72 | + axis : int |
| 73 | + Align array along this axis |
| 74 | +
|
| 75 | + Returns |
| 76 | + ------- |
| 77 | + DistributedArray |
| 78 | + self array globally redistributed along new axis |
| 79 | +
|
| 80 | + """ |
| 81 | + p1 = self.p0.pencil(axis) |
| 82 | + transfer = self.p0.transfer(p1, self.dtype) |
| 83 | + z0 = np.zeros(p1.subshape, dtype=self.dtype) |
| 84 | + transfer.forward(self, z0) |
| 85 | + return DistributedArray(self.global_shape, |
| 86 | + subcomm=p1.subcomm, |
| 87 | + dtype=self.dtype, |
| 88 | + alignment=axis, |
| 89 | + buffer=z0) |
0 commit comments