Skip to content

Commit 2dcd961

Browse files
authored
OSHMPI: Update after upstream fixes
Authored-by: Lisandro Dalcin <dalcinl@gmail.com>
1 parent e0ed87d commit 2dcd961

File tree

7 files changed

+99
-68
lines changed

7 files changed

+99
-68
lines changed

src/libshmem/config/oshmpi.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,6 @@
2020
#define PySHMEM_HAVE_shmem_alltoallsmem 1
2121
#define PySHMEM_HAVE_shmem_reduce 1
2222
#define PySHMEM_HAVE_shmem_wait_test_many 1
23-
/*#define PySHMEM_HAVE_shmem_pcontrol 1*/
23+
#define PySHMEM_HAVE_shmem_pcontrol 1
2424

2525
#endif

src/libshmem/fallback.h

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -532,15 +532,6 @@ PySHMEM_APPLY_STD_RMA_TYPES(PySHMEM_ALLTOALL)
532532

533533
/* --- */
534534

535-
#define PySHMEM_ALLTOALLS_BIT(N, dest, source, dst, sst, size, elsz) \
536-
do { \
537-
if ((elsz) == (N>>3)) { \
538-
shmem_alltoalls##N(dest, source, dst, sst, size, \
539-
0, 0, shmem_n_pes(), _py_shmem_pSync()) ; \
540-
return 0; \
541-
} \
542-
} while(0)
543-
544535
#if !defined(PySHMEM_HAVE_shmem_alltoallsmem)
545536

546537
static
@@ -561,26 +552,55 @@ int shmem_alltoallsmem(shmem_team_t team,
561552

562553
#endif
563554

555+
#if defined(PySHMEM_HAVE_shmem_alltoalls)
556+
557+
#define PySHMEM_ALLTOALLSMEM_X(N, team, dest, source, dst, sst, size, elsz) \
558+
do { \
559+
if ((elsz) % (N>>3) == 0) { \
560+
ptrdiff_t i, n = (ptrdiff_t) (elsz) / (N>>3); \
561+
for (i = 0; i < n; i++) { \
562+
uint##N##_t *d = (uint##N##_t*) (dest) + i; \
563+
const uint##N##_t *s = (uint##N##_t*) (source) + i; \
564+
int ierr = shmem_uint##N##_alltoalls((team), d, s, \
565+
(dst) * n, (sst) * n, size); \
566+
if (ierr) return ierr; \
567+
} \
568+
return 0; \
569+
} \
570+
} while(0) /**/
571+
572+
#else
573+
574+
#define PySHMEM_ALLTOALLSMEM_X(N, team, dest, source, dst, sst, size, elsz) \
575+
do { \
576+
if ((team) != SHMEM_TEAM_WORLD) return PySHMEM_UNAVAILABLE; \
577+
if ((elsz) % (N>>3) == 0) { \
578+
ptrdiff_t i, n = (ptrdiff_t) (elsz) / (N>>3); \
579+
for (i = 0; i < n; i++) { \
580+
uint##N##_t *d = (uint##N##_t*) (dest) + i; \
581+
const uint##N##_t *s = (const uint##N##_t*) (source) + i; \
582+
shmem_alltoalls##N(d, s, (dst) * n, (sst) * n, (size), \
583+
0, 0, shmem_n_pes(), _py_shmem_pSync()) ; \
584+
} \
585+
return 0; \
586+
} \
587+
} while(0) /**/
588+
589+
#endif
590+
564591
static
565592
int shmem_alltoallsmem_x(shmem_team_t team,
566593
void *dest, const void *source,
567594
ptrdiff_t dst, ptrdiff_t sst,
568595
size_t size, size_t eltsize)
569596
{
597+
PySHMEM_ALLTOALLSMEM_X(64, team, dest, source, dst, sst, size, eltsize);
598+
PySHMEM_ALLTOALLSMEM_X(32, team, dest, source, dst, sst, size, eltsize);
570599
#if defined(PySHMEM_HAVE_shmem_alltoalls)
571-
switch (eltsize) {
572-
case (1): return shmem_uint8_alltoalls (team, (uint8_t*) dest, (uint8_t*) source, dst, sst, size);
573-
case (2): return shmem_uint16_alltoalls(team, (uint16_t*) dest, (uint16_t*) source, dst, sst, size);
574-
case (4): return shmem_uint32_alltoalls(team, (uint32_t*) dest, (uint32_t*) source, dst, sst, size);
575-
case (8): return shmem_uint64_alltoalls(team, (uint64_t*) dest, (uint64_t*) source, dst, sst, size);
576-
}
577-
return PySHMEM_UNAVAILABLE;
578-
#else
579-
if (team != SHMEM_TEAM_WORLD) return PySHMEM_UNAVAILABLE;
580-
PySHMEM_ALLTOALLS_BIT(64, dest, source, dst, sst, size, eltsize);
581-
PySHMEM_ALLTOALLS_BIT(32, dest, source, dst, sst, size, eltsize);
582-
return PySHMEM_UNAVAILABLE;
600+
PySHMEM_ALLTOALLSMEM_X(16, team, dest, source, dst, sst, size, eltsize);
601+
PySHMEM_ALLTOALLSMEM_X(8 , team, dest, source, dst, sst, size, eltsize);
583602
#endif
603+
return PySHMEM_UNAVAILABLE;
584604
}
585605

586606
#if !defined(PySHMEM_HAVE_shmem_alltoalls)

test/test_coll.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,7 @@ def testAllToAllStride(self):
172172
with self.subTest(type=t):
173173
itemsize = np.dtype(t).itemsize
174174
if not shmem_15:
175-
if itemsize not in (4, 8): continue
176-
if itemsize > 8: continue
175+
if itemsize < 4: continue
177176
tst, sst = 3, 5
178177
tgt = shmem.empty((npes, tst), dtype=t)
179178
src = shmem.empty((npes, sst), dtype=t)
@@ -196,8 +195,7 @@ def testAllToAllStrideSize(self):
196195
with self.subTest(type=t):
197196
itemsize = np.dtype(t).itemsize
198197
if not shmem_15:
199-
if itemsize not in (4, 8): continue
200-
if itemsize > 8: continue
198+
if itemsize < 4: continue
201199
tst, sst = 3, 5
202200
tgt = shmem.empty((3, npes, tst), dtype=t)
203201
src = shmem.empty((5, npes, sst), dtype=t)
@@ -214,22 +212,6 @@ def testAllToAllStrideSize(self):
214212
shmem.free(tgt)
215213
shmem.free(src)
216214

217-
def testAllToAllStrideUnsupported(self):
218-
mype = shmem.my_pe()
219-
npes = shmem.n_pes()
220-
t = 'D'
221-
tst, sst = 3, 5
222-
tgt = shmem.empty((3, npes, tst), dtype=t)
223-
src = shmem.empty((5, npes, sst), dtype=t)
224-
tgt[...] = npes
225-
src[...] = -1
226-
src[0, :, 0] = mype
227-
shmem.barrier_all()
228-
self.assertRaises(
229-
NotImplementedError, shmem.alltoalls,
230-
tgt, src, tst=tst, sst=sst, size=1,
231-
)
232-
233215

234216
if __name__ == '__main__':
235217
unittest.main()

test/test_ctx.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -49,36 +49,53 @@ def testWith(self):
4949
with ctx as alias:
5050
self.assertTrue(ctx is alias)
5151

52-
@unittest.skipIf('OSHMPI' in shmem.VENDOR_STRING, 'OSHMPI')
5352
def testWithNew(self):
5453
for ctx in ctxs:
5554
if not ctx: continue
56-
with ctx.create() as newctx:
55+
try:
56+
newctx = shmem.Ctx.create()
57+
except RuntimeError:
58+
continue
59+
with newctx:
5760
self.assertNotEqual(newctx, ctx)
5861
self.assertNotEqual(newctx, shmem.CTX_INVALID)
5962
self.assertEqual(newctx, shmem.CTX_INVALID)
6063

61-
@unittest.skipIf('OSHMPI' in shmem.VENDOR_STRING, 'OSHMPI')
6264
def testCreate(self):
63-
ctx = shmem.CTX_DEFAULT.create()
64-
self.assertNotEqual(ctx, shmem.CTX_DEFAULT)
65-
ctx.destroy()
66-
self.assertEqual(ctx, shmem.CTX_INVALID)
67-
ctx = shmem.CTX_DEFAULT.create(team=shmem.TEAM_WORLD)
68-
self.assertNotEqual(ctx, shmem.CTX_DEFAULT)
69-
ctx.destroy()
70-
self.assertEqual(ctx, shmem.CTX_INVALID)
65+
try:
66+
ctx = shmem.Ctx.create()
67+
except RuntimeError:
68+
pass
69+
else:
70+
self.assertNotEqual(ctx, shmem.CTX_DEFAULT)
71+
ctx.destroy()
72+
self.assertEqual(ctx, shmem.CTX_INVALID)
73+
try:
74+
ctx = shmem.Ctx.create(team=shmem.TEAM_WORLD)
75+
except RuntimeError:
76+
pass
77+
else:
78+
self.assertNotEqual(ctx, shmem.CTX_DEFAULT)
79+
ctx.destroy()
80+
self.assertEqual(ctx, shmem.CTX_INVALID)
7181

72-
@unittest.skipIf('OSHMPI' in shmem.VENDOR_STRING, 'OSHMPI')
7382
@unittest.skipIf('open-mpi' in shmem.VENDOR_STRING, 'open-mpi')
7483
def testCreateOptions(self):
84+
def create(*args, **kwargs):
85+
try:
86+
ctx = shmem.Ctx.create(*args, **kwargs)
87+
except RuntimeError:
88+
pass
89+
else:
90+
self.assertNotEqual(ctx, shmem.CTX_INVALID)
91+
self.assertNotEqual(ctx, shmem.CTX_DEFAULT)
92+
ctx.destroy()
93+
self.assertEqual(ctx, shmem.CTX_INVALID)
94+
7595
for opt in options:
76-
ctx = shmem.CTX_DEFAULT.create(opt)
77-
ctx.destroy()
78-
ctx = shmem.CTX_DEFAULT.create(opt, shmem.TEAM_WORLD)
79-
ctx.destroy()
80-
ctx = shmem.CTX_DEFAULT.create(options=opt, team=shmem.TEAM_WORLD)
81-
ctx.destroy()
96+
create(opt)
97+
create(opt, shmem.TEAM_WORLD)
98+
create(options=opt, team=shmem.TEAM_WORLD)
8299

83100
def testDestroy(self):
84101
for ctx in ctxs:
@@ -97,7 +114,6 @@ def testDestroyAlias(self):
97114
alias.destroy()
98115
self.assertFalse(alias)
99116

100-
@unittest.skipIf('OSHMPI' in shmem.VENDOR_STRING, 'OSHMPI')
101117
def testGetTeam(self):
102118
ctx = shmem.CTX_DEFAULT
103119
team = ctx.get_team()
@@ -114,7 +130,6 @@ def testQuiet(self):
114130
ctx.quiet()
115131
shmem.quiet(ctx)
116132

117-
@unittest.skipIf('OSHMPI' in shmem.VENDOR_STRING, 'OSHMPI')
118133
def testInvalid(self):
119134
ctx = shmem.CTX_INVALID
120135
self.assertRaises(RuntimeError, ctx.get_team)

test/test_signal.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
shmem.SIGNAL_ADD: 'add',
1414
}
1515

16+
@unittest.skipIf('OSHMPI' in shmem.VENDOR_STRING, 'OSHMPI')
1617
@unittest.skipIf(shmem.SIGNAL_SET == shmem.SIGNAL_ADD, 'put-with-signal')
1718
class TestSignal(unittest.TestCase):
1819

test/test_sync.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def testOne(self):
3131
shmem.wait_until(ivar[..., pe], cmp, val)
3232
shmem.free(ivar)
3333

34+
@unittest.skipIf('OSHMPI' in shmem.VENDOR_STRING, 'OSHMPI')
3435
@unittest.skipIf('osss-ucx' in shmem.VENDOR_STRING, 'osss-ucx')
3536
def testAll(self):
3637
mype = shmem.my_pe()
@@ -96,6 +97,7 @@ def testAll(self):
9697
shmem.wait_until_all_vector(ivars, cmp, v2vec)
9798
shmem.free(ivars)
9899

100+
@unittest.skipIf('OSHMPI' in shmem.VENDOR_STRING, 'OSHMPI')
99101
@unittest.skipIf('osss-ucx' in shmem.VENDOR_STRING, 'osss-ucx')
100102
def testAny(self):
101103
mype = shmem.my_pe()
@@ -181,6 +183,7 @@ def testAny(self):
181183
self.assertNotEqual(index, None)
182184
shmem.free(ivars)
183185

186+
@unittest.skipIf('OSHMPI' in shmem.VENDOR_STRING, 'OSHMPI')
184187
@unittest.skipIf('osss-ucx' in shmem.VENDOR_STRING, 'osss-ucx')
185188
def testSome(self):
186189
mype = shmem.my_pe()
@@ -299,6 +302,7 @@ def testOne(self):
299302
self.assertTrue(flag)
300303
shmem.free(ivar)
301304

305+
@unittest.skipIf('OSHMPI' in shmem.VENDOR_STRING, 'OSHMPI')
302306
@unittest.skipIf('osss-ucx' in shmem.VENDOR_STRING, 'osss-ucx')
303307
def testAll(self):
304308
mype = shmem.my_pe()
@@ -374,6 +378,7 @@ def testAll(self):
374378
self.assertTrue(flag)
375379
shmem.free(ivars)
376380

381+
@unittest.skipIf('OSHMPI' in shmem.VENDOR_STRING, 'OSHMPI')
377382
@unittest.skipIf('osss-ucx' in shmem.VENDOR_STRING, 'osss-ucx')
378383
def testAny(self):
379384
mype = shmem.my_pe()
@@ -457,6 +462,7 @@ def testAny(self):
457462
self.assertNotEqual(index, None)
458463
shmem.free(ivars)
459464

465+
@unittest.skipIf('OSHMPI' in shmem.VENDOR_STRING, 'OSHMPI')
460466
@unittest.skipIf('osss-ucx' in shmem.VENDOR_STRING, 'osss-ucx')
461467
def testSome(self):
462468
mype = shmem.my_pe()
@@ -541,6 +547,7 @@ def testSome(self):
541547
shmem.free(ivars)
542548

543549

550+
@unittest.skipIf('OSHMPI' in shmem.VENDOR_STRING, 'OSHMPI')
544551
class TestSignal(unittest.TestCase):
545552

546553
def testWait(self):

test/test_team.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,16 +120,22 @@ def testTranslate(self):
120120
tpe = team.translate_pe(pe, team=team)
121121
self.assertEqual(tpe, pe)
122122

123-
@unittest.skipIf('OSHMPI' in shmem.VENDOR_STRING, 'OSHMPI')
124123
def testCreateCtx(self):
125124
team = shmem.TEAM_WORLD
126-
ctx = team.create_ctx()
127-
ctx.destroy()
128-
for opt in options_ctx:
129-
ctx= team.create_ctx(opt)
125+
try:
126+
ctx = team.create_ctx()
127+
except RuntimeError:
128+
pass
129+
else:
130130
ctx.destroy()
131+
for opt in options_ctx:
132+
try:
133+
ctx= team.create_ctx(opt)
134+
except RuntimeError:
135+
pass
136+
else:
137+
ctx.destroy()
131138

132-
@unittest.skipIf('OSHMPI' in shmem.VENDOR_STRING, 'OSHMPI')
133139
def testSync(self):
134140
team = shmem.TEAM_WORLD
135141
team.sync()

0 commit comments

Comments
 (0)