Skip to content

Commit cd3ddd1

Browse files
committed
Fixing newDarray with rank
1 parent 60d5d4b commit cd3ddd1

File tree

1 file changed

+1
-8
lines changed

1 file changed

+1
-8
lines changed

mpi4py_fft/distributedarray.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -302,14 +302,7 @@ def newDarray(pfft, forward_output=True, val=0, rank=0):
302302
p0 = pfft.pencil[0]
303303
commsizes = [s.Get_size() for s in p0.subcomm]
304304
global_shape = tuple([s*p for s, p in zip(shape, commsizes)])
305-
306-
if rank == 1:
307-
global_shape = (len(shape),) + global_shape
308-
elif rank == 2:
309-
global_shape = (len(shape), len(shape)) + global_shape
310-
else:
311-
assert rank == 0
312-
305+
global_shape = (len(shape),)*rank + global_shape
313306
return DistributedArray(global_shape, subcomm=p0.subcomm, val=val,
314307
dtype=dtype, rank=rank)
315308

0 commit comments

Comments
 (0)