@@ -532,15 +532,6 @@ PySHMEM_APPLY_STD_RMA_TYPES(PySHMEM_ALLTOALL)
532532
533533/* --- */
534534
535- #define PySHMEM_ALLTOALLS_BIT (N , dest , source , dst , sst , size , elsz ) \
536- do { \
537- if ((elsz) == (N>>3)) { \
538- shmem_alltoalls##N(dest, source, dst, sst, size, \
539- 0, 0, shmem_n_pes(), _py_shmem_pSync()) ; \
540- return 0; \
541- } \
542- } while(0)
543-
544535#if !defined(PySHMEM_HAVE_shmem_alltoallsmem )
545536
546537static
@@ -561,26 +552,55 @@ int shmem_alltoallsmem(shmem_team_t team,
561552
562553#endif
563554
555+ #if defined(PySHMEM_HAVE_shmem_alltoalls )
556+
557+ #define PySHMEM_ALLTOALLSMEM_X (N , team , dest , source , dst , sst , size , elsz ) \
558+ do { \
559+ if ((elsz) % (N>>3) == 0) { \
560+ ptrdiff_t i, n = (ptrdiff_t) (elsz) / (N>>3); \
561+ for (i = 0; i < n; i++) { \
562+ uint##N##_t *d = (uint##N##_t*) (dest) + i; \
563+ const uint##N##_t *s = (uint##N##_t*) (source) + i; \
564+ int ierr = shmem_uint##N##_alltoalls((team), d, s, \
565+ (dst) * n, (sst) * n, size); \
566+ if (ierr) return ierr; \
567+ } \
568+ return 0; \
569+ } \
570+ } while(0) /**/
571+
572+ #else
573+
574+ #define PySHMEM_ALLTOALLSMEM_X (N , team , dest , source , dst , sst , size , elsz ) \
575+ do { \
576+ if ((team) != SHMEM_TEAM_WORLD) return PySHMEM_UNAVAILABLE; \
577+ if ((elsz) % (N>>3) == 0) { \
578+ ptrdiff_t i, n = (ptrdiff_t) (elsz) / (N>>3); \
579+ for (i = 0; i < n; i++) { \
580+ uint##N##_t *d = (uint##N##_t*) (dest) + i; \
581+ const uint##N##_t *s = (const uint##N##_t*) (source) + i; \
582+ shmem_alltoalls##N(d, s, (dst) * n, (sst) * n, (size), \
583+ 0, 0, shmem_n_pes(), _py_shmem_pSync()) ; \
584+ } \
585+ return 0; \
586+ } \
587+ } while(0) /**/
588+
589+ #endif
590+
564591static
565592int shmem_alltoallsmem_x (shmem_team_t team ,
566593 void * dest , const void * source ,
567594 ptrdiff_t dst , ptrdiff_t sst ,
568595 size_t size , size_t eltsize )
569596{
597+ PySHMEM_ALLTOALLSMEM_X (64 , team , dest , source , dst , sst , size , eltsize );
598+ PySHMEM_ALLTOALLSMEM_X (32 , team , dest , source , dst , sst , size , eltsize );
570599#if defined(PySHMEM_HAVE_shmem_alltoalls )
571- switch (eltsize ) {
572- case (1 ): return shmem_uint8_alltoalls (team , (uint8_t * ) dest , (uint8_t * ) source , dst , sst , size );
573- case (2 ): return shmem_uint16_alltoalls (team , (uint16_t * ) dest , (uint16_t * ) source , dst , sst , size );
574- case (4 ): return shmem_uint32_alltoalls (team , (uint32_t * ) dest , (uint32_t * ) source , dst , sst , size );
575- case (8 ): return shmem_uint64_alltoalls (team , (uint64_t * ) dest , (uint64_t * ) source , dst , sst , size );
576- }
577- return PySHMEM_UNAVAILABLE ;
578- #else
579- if (team != SHMEM_TEAM_WORLD ) return PySHMEM_UNAVAILABLE ;
580- PySHMEM_ALLTOALLS_BIT (64 , dest , source , dst , sst , size , eltsize );
581- PySHMEM_ALLTOALLS_BIT (32 , dest , source , dst , sst , size , eltsize );
582- return PySHMEM_UNAVAILABLE ;
600+ PySHMEM_ALLTOALLSMEM_X (16 , team , dest , source , dst , sst , size , eltsize );
601+ PySHMEM_ALLTOALLSMEM_X (8 , team , dest , source , dst , sst , size , eltsize );
583602#endif
603+ return PySHMEM_UNAVAILABLE ;
584604}
585605
586606#if !defined(PySHMEM_HAVE_shmem_alltoalls )
0 commit comments