@@ -17,27 +17,35 @@ ucc_status_t mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct om
1717 ucc_coll_req_h * req ,
1818 mca_coll_ucc_req_t * coll_req )
1919{
20- ucc_datatype_t ucc_sdt , ucc_rdt ;
20+ ucc_datatype_t ucc_sdt = UCC_DT_INT8 , ucc_rdt = UCC_DT_INT8 ;
21+ bool is_inplace = (MPI_IN_PLACE == sbuf );
2122 int comm_rank = ompi_comm_rank (ucc_module -> comm );
2223 int comm_size = ompi_comm_size (ucc_module -> comm );
2324
24- if (!ompi_datatype_is_contiguous_memory_layout (sdtype , scount )) {
25- goto fallback ;
26- }
27- ucc_sdt = ompi_dtype_to_ucc_dtype (sdtype );
2825 if (comm_rank == root ) {
29- if (!ompi_datatype_is_contiguous_memory_layout (rdtype , rcount * comm_size )) {
26+ if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout (sdtype , scount )) ||
27+ !ompi_datatype_is_contiguous_memory_layout (rdtype , rcount * comm_size )) {
3028 goto fallback ;
3129 }
30+
3231 ucc_rdt = ompi_dtype_to_ucc_dtype (rdtype );
33- if ((COLL_UCC_DT_UNSUPPORTED == ucc_rdt ) ||
34- (MPI_IN_PLACE != sbuf && COLL_UCC_DT_UNSUPPORTED == ucc_sdt )) {
32+ if (!is_inplace ) {
33+ ucc_sdt = ompi_dtype_to_ucc_dtype (sdtype );
34+ }
35+
36+ if ((COLL_UCC_DT_UNSUPPORTED == ucc_sdt ) ||
37+ (COLL_UCC_DT_UNSUPPORTED == ucc_rdt )) {
3538 UCC_VERBOSE (5 , "ompi_datatype is not supported: dtype = %s" ,
36- (COLL_UCC_DT_UNSUPPORTED == ucc_rdt ) ?
37- rdtype -> super .name : sdtype -> super .name );
39+ (COLL_UCC_DT_UNSUPPORTED == ucc_sdt ) ?
40+ sdtype -> super .name : rdtype -> super .name );
3841 goto fallback ;
3942 }
4043 } else {
44+ if (!ompi_datatype_is_contiguous_memory_layout (sdtype , scount )) {
45+ goto fallback ;
46+ }
47+
48+ ucc_sdt = ompi_dtype_to_ucc_dtype (sdtype );
4149 if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt ) {
4250 UCC_VERBOSE (5 , "ompi_datatype is not supported: dtype = %s" ,
4351 sdtype -> super .name );
@@ -64,7 +72,7 @@ ucc_status_t mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct om
6472 },
6573 };
6674
67- if (MPI_IN_PLACE == sbuf ) {
75+ if (is_inplace ) {
6876 coll .mask |= UCC_COLL_ARGS_FIELD_FLAGS ;
6977 coll .flags = UCC_COLL_ARGS_FLAG_IN_PLACE ;
7078 }
0 commit comments