Skip to content

Commit 9305110

Browse files
author
hasegawa.kento
committed
COLL/UCC: add persistent collective calls
Signed-off-by: hasegawa.kento <hasegawa.kento@fujitsu.com>
1 parent 2dec57d commit 9305110

17 files changed

+657
-142
lines changed

ompi/mca/coll/ucc/coll_ucc.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,34 @@ struct mca_coll_ucc_module_t {
133133
mca_coll_base_module_t* previous_scatter_module;
134134
mca_coll_base_module_iscatter_fn_t previous_iscatter;
135135
mca_coll_base_module_t* previous_iscatter_module;
136+
mca_coll_base_module_allreduce_init_fn_t previous_allreduce_init;
137+
mca_coll_base_module_t *previous_allreduce_init_module;
138+
mca_coll_base_module_reduce_init_fn_t previous_reduce_init;
139+
mca_coll_base_module_t *previous_reduce_init_module;
140+
mca_coll_base_module_barrier_init_fn_t previous_barrier_init;
141+
mca_coll_base_module_t *previous_barrier_init_module;
142+
mca_coll_base_module_bcast_init_fn_t previous_bcast_init;
143+
mca_coll_base_module_t *previous_bcast_init_module;
144+
mca_coll_base_module_alltoall_init_fn_t previous_alltoall_init;
145+
mca_coll_base_module_t *previous_alltoall_init_module;
146+
mca_coll_base_module_alltoallv_init_fn_t previous_alltoallv_init;
147+
mca_coll_base_module_t *previous_alltoallv_init_module;
148+
mca_coll_base_module_allgather_init_fn_t previous_allgather_init;
149+
mca_coll_base_module_t *previous_allgather_init_module;
150+
mca_coll_base_module_allgatherv_init_fn_t previous_allgatherv_init;
151+
mca_coll_base_module_t *previous_allgatherv_init_module;
152+
mca_coll_base_module_gather_init_fn_t previous_gather_init;
153+
mca_coll_base_module_t *previous_gather_init_module;
154+
mca_coll_base_module_gatherv_init_fn_t previous_gatherv_init;
155+
mca_coll_base_module_t *previous_gatherv_init_module;
156+
mca_coll_base_module_reduce_scatter_block_init_fn_t previous_reduce_scatter_block_init;
157+
mca_coll_base_module_t *previous_reduce_scatter_block_init_module;
158+
mca_coll_base_module_reduce_scatter_init_fn_t previous_reduce_scatter_init;
159+
mca_coll_base_module_t *previous_reduce_scatter_init_module;
160+
mca_coll_base_module_scatterv_init_fn_t previous_scatterv_init;
161+
mca_coll_base_module_t *previous_scatterv_init_module;
162+
mca_coll_base_module_scatter_init_fn_t previous_scatter_init;
163+
mca_coll_base_module_t *previous_scatter_init_module;
136164
};
137165
typedef struct mca_coll_ucc_module_t mca_coll_ucc_module_t;
138166
OBJ_CLASS_DECLARATION(mca_coll_ucc_module_t);

ompi/mca/coll/ucc/coll_ucc_allgather.c

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010

1111
#include "coll_ucc_common.h"
1212

13-
static inline ucc_status_t mca_coll_ucc_allgather_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
14-
void* rbuf, size_t rcount, struct ompi_datatype_t *rdtype,
15-
mca_coll_ucc_module_t *ucc_module,
16-
ucc_coll_req_h *req,
17-
mca_coll_ucc_req_t *coll_req)
13+
static inline ucc_status_t
14+
mca_coll_ucc_allgather_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
15+
void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype,
16+
bool persistent, mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req,
17+
mca_coll_ucc_req_t *coll_req)
1818
{
1919
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
2020
bool is_inplace = (MPI_IN_PLACE == sbuf);
@@ -60,6 +60,10 @@ static inline ucc_status_t mca_coll_ucc_allgather_iniz(const void *sbuf, size_t
6060
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
6161
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
6262
}
63+
if (true == persistent) {
64+
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
65+
coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT;
66+
}
6367
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
6468
return UCC_OK;
6569
fallback:
@@ -75,8 +79,7 @@ int mca_coll_ucc_allgather(const void *sbuf, size_t scount, struct ompi_datatype
7579
ucc_coll_req_h req;
7680

7781
UCC_VERBOSE(3, "running ucc allgather");
78-
COLL_UCC_CHECK(mca_coll_ucc_allgather_iniz(sbuf, scount, sdtype,
79-
rbuf, rcount, rdtype,
82+
COLL_UCC_CHECK(mca_coll_ucc_allgather_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, false,
8083
ucc_module, &req, NULL));
8184
COLL_UCC_POST_AND_CHECK(req);
8285
COLL_UCC_CHECK(coll_ucc_req_wait(req));
@@ -99,8 +102,7 @@ int mca_coll_ucc_iallgather(const void *sbuf, size_t scount, struct ompi_datatyp
99102

100103
UCC_VERBOSE(3, "running ucc iallgather");
101104
COLL_UCC_GET_REQ(coll_req);
102-
COLL_UCC_CHECK(mca_coll_ucc_allgather_iniz(sbuf, scount, sdtype,
103-
rbuf, rcount, rdtype,
105+
COLL_UCC_CHECK(mca_coll_ucc_allgather_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, false,
104106
ucc_module, &req, coll_req));
105107
COLL_UCC_POST_AND_CHECK(req);
106108
*request = &coll_req->super;
@@ -113,3 +115,28 @@ int mca_coll_ucc_iallgather(const void *sbuf, size_t scount, struct ompi_datatyp
113115
return ucc_module->previous_iallgather(sbuf, scount, sdtype, rbuf, rcount, rdtype,
114116
comm, request, ucc_module->previous_iallgather_module);
115117
}
118+
119+
/*
 * Persistent allgather setup for the UCC collective module.
 *
 * Obtains a request with COLL_UCC_GET_REQ_PC, builds the UCC collective through
 * mca_coll_ucc_allgather_iniz() with its `persistent` argument set to true, and
 * returns the request to the caller through *request. No COLL_UCC_POST_AND_CHECK
 * is issued here, in contrast to mca_coll_ucc_iallgather().
 *
 * sbuf/scount/sdtype, rbuf/rcount/rdtype: forwarded unchanged to the iniz helper
 *                (or to the previous module on fallback).
 * comm, info:    only used when forwarding to the previous module.
 * request:       output; set to the embedded `super` request on the UCC path.
 * module:        cast to mca_coll_ucc_module_t to reach the saved previous_* hooks.
 *
 * Returns OMPI_SUCCESS on the UCC path, otherwise the return value of
 * ucc_module->previous_allgather_init().
 */
int mca_coll_ucc_allgather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
120+
void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype,
121+
struct ompi_communicator_t *comm, struct ompi_info_t *info,
122+
ompi_request_t **request, mca_coll_base_module_t *module)
123+
{
124+
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module;
125+
ucc_coll_req_h req;
126+
/* Starts as NULL so the fallback path can tell whether a request was obtained. */
mca_coll_ucc_req_t *coll_req = NULL;
127+
128+
/* NOTE(review): macro is defined outside this view; presumably jumps to `fallback` on failure -- confirm. */
COLL_UCC_GET_REQ_PC(coll_req);
129+
UCC_VERBOSE(3, "allgather_init init %p", coll_req);
130+
/* `true` is the helper's `persistent` argument. */
COLL_UCC_CHECK(mca_coll_ucc_allgather_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, true,
131+
ucc_module, &req, coll_req));
132+
*request = &coll_req->super;
133+
return OMPI_SUCCESS;
134+
fallback:
135+
UCC_VERBOSE(3, "running fallback allgather_init");
136+
/* Release the request (if one was obtained) before delegating. */
if (coll_req) {
137+
/* NOTE(review): the cast assumes `super` is the first member of mca_coll_ucc_req_t -- confirm. */
mca_coll_ucc_req_free((ompi_request_t **) &coll_req);
138+
}
139+
/* Hand the call to the allgather_init implementation saved in the module. */
return ucc_module->previous_allgather_init(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm,
140+
info, request,
141+
ucc_module->previous_allgather_init_module);
142+
}

ompi/mca/coll/ucc/coll_ucc_allgatherv.c

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,12 @@
1010

1111
#include "coll_ucc_common.h"
1212

13-
static inline ucc_status_t mca_coll_ucc_allgatherv_iniz(const void *sbuf, size_t scount,
14-
struct ompi_datatype_t *sdtype,
15-
void* rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps,
16-
struct ompi_datatype_t *rdtype,
17-
mca_coll_ucc_module_t *ucc_module,
18-
ucc_coll_req_h *req,
19-
mca_coll_ucc_req_t *coll_req)
13+
static inline ucc_status_t
14+
mca_coll_ucc_allgatherv_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
15+
void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps,
16+
struct ompi_datatype_t *rdtype, bool persistent,
17+
mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req,
18+
mca_coll_ucc_req_t *coll_req)
2019
{
2120
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
2221
bool is_inplace = (MPI_IN_PLACE == sbuf);
@@ -58,6 +57,10 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_iniz(const void *sbuf, size_t
5857
}
5958
};
6059

60+
if (true == persistent) {
61+
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
62+
coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT;
63+
}
6164
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
6265
return UCC_OK;
6366
fallback:
@@ -76,9 +79,8 @@ int mca_coll_ucc_allgatherv(const void *sbuf, size_t scount,
7679

7780
UCC_VERBOSE(3, "running ucc allgatherv");
7881

79-
COLL_UCC_CHECK(mca_coll_ucc_allgatherv_iniz(sbuf, scount, sdtype,
80-
rbuf, rcounts, rdisps, rdtype,
81-
ucc_module, &req, NULL));
82+
COLL_UCC_CHECK(mca_coll_ucc_allgatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype,
83+
false, ucc_module, &req, NULL));
8284
COLL_UCC_POST_AND_CHECK(req);
8385
COLL_UCC_CHECK(coll_ucc_req_wait(req));
8486
return OMPI_SUCCESS;
@@ -103,9 +105,8 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, size_t scount,
103105

104106
UCC_VERBOSE(3, "running ucc iallgatherv");
105107
COLL_UCC_GET_REQ(coll_req);
106-
COLL_UCC_CHECK(mca_coll_ucc_allgatherv_iniz(sbuf, scount, sdtype,
107-
rbuf, rcounts, rdisps, rdtype,
108-
ucc_module, &req, coll_req));
108+
COLL_UCC_CHECK(mca_coll_ucc_allgatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype,
109+
false, ucc_module, &req, coll_req));
109110
COLL_UCC_POST_AND_CHECK(req);
110111
*request = &coll_req->super;
111112
return OMPI_SUCCESS;
@@ -118,3 +119,29 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, size_t scount,
118119
rbuf, rcounts, rdisps, rdtype,
119120
comm, request, ucc_module->previous_iallgatherv_module);
120121
}
122+
123+
/*
 * Persistent allgatherv setup for the UCC collective module.
 *
 * Obtains a request with COLL_UCC_GET_REQ_PC, builds the UCC collective through
 * mca_coll_ucc_allgatherv_iniz() with its `persistent` argument set to true, and
 * returns the request to the caller through *request. No COLL_UCC_POST_AND_CHECK
 * is issued here, in contrast to mca_coll_ucc_iallgatherv().
 *
 * sbuf/scount/sdtype, rbuf/rcounts/rdisps/rdtype: forwarded unchanged to the iniz
 *                helper (or to the previous module on fallback).
 * comm, info:    only used when forwarding to the previous module.
 * request:       output; set to the embedded `super` request on the UCC path.
 * module:        cast to mca_coll_ucc_module_t to reach the saved previous_* hooks.
 *
 * Returns OMPI_SUCCESS on the UCC path, otherwise the return value of
 * ucc_module->previous_allgatherv_init().
 */
int mca_coll_ucc_allgatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
124+
void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps,
125+
struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm,
126+
struct ompi_info_t *info, ompi_request_t **request,
127+
mca_coll_base_module_t *module)
128+
{
129+
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module;
130+
ucc_coll_req_h req;
131+
/* Starts as NULL so the fallback path can tell whether a request was obtained. */
mca_coll_ucc_req_t *coll_req = NULL;
132+
133+
/* NOTE(review): macro is defined outside this view; presumably jumps to `fallback` on failure -- confirm. */
COLL_UCC_GET_REQ_PC(coll_req);
134+
UCC_VERBOSE(3, "allgatherv_init init %p", coll_req);
135+
/* `true` is the helper's `persistent` argument. */
COLL_UCC_CHECK(mca_coll_ucc_allgatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype,
136+
true, ucc_module, &req, coll_req));
137+
*request = &coll_req->super;
138+
return OMPI_SUCCESS;
139+
fallback:
140+
UCC_VERBOSE(3, "running fallback allgatherv_init");
141+
/* Release the request (if one was obtained) before delegating. */
if (coll_req) {
142+
/* NOTE(review): the cast assumes `super` is the first member of mca_coll_ucc_req_t -- confirm. */
mca_coll_ucc_req_free((ompi_request_t **) &coll_req);
143+
}
144+
/* Hand the call to the allgatherv_init implementation saved in the module. */
return ucc_module->previous_allgatherv_init(sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype,
145+
comm, info, request,
146+
ucc_module->previous_allgatherv_init_module);
147+
}

ompi/mca/coll/ucc/coll_ucc_allreduce.c

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212

1313
static inline ucc_status_t mca_coll_ucc_allreduce_iniz(const void *sbuf, void *rbuf, size_t count,
1414
struct ompi_datatype_t *dtype,
15-
struct ompi_op_t *op, mca_coll_ucc_module_t *ucc_module,
15+
struct ompi_op_t *op, bool persistent,
16+
mca_coll_ucc_module_t *ucc_module,
1617
ucc_coll_req_h *req,
1718
mca_coll_ucc_req_t *coll_req)
1819
{
@@ -53,6 +54,10 @@ static inline ucc_status_t mca_coll_ucc_allreduce_iniz(const void *sbuf, void *r
5354
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
5455
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
5556
}
57+
if (true == persistent) {
58+
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
59+
coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT;
60+
}
5661
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
5762
return UCC_OK;
5863
fallback:
@@ -68,8 +73,8 @@ int mca_coll_ucc_allreduce(const void *sbuf, void *rbuf, size_t count,
6873
ucc_coll_req_h req;
6974

7075
UCC_VERBOSE(3, "running ucc allreduce");
71-
COLL_UCC_CHECK(mca_coll_ucc_allreduce_iniz(sbuf, rbuf, count, dtype, op,
72-
ucc_module, &req, NULL));
76+
COLL_UCC_CHECK(
77+
mca_coll_ucc_allreduce_iniz(sbuf, rbuf, count, dtype, op, false, ucc_module, &req, NULL));
7378
COLL_UCC_POST_AND_CHECK(req);
7479
COLL_UCC_CHECK(coll_ucc_req_wait(req));
7580
return OMPI_SUCCESS;
@@ -91,8 +96,8 @@ int mca_coll_ucc_iallreduce(const void *sbuf, void *rbuf, size_t count,
9196

9297
UCC_VERBOSE(3, "running ucc iallreduce");
9398
COLL_UCC_GET_REQ(coll_req);
94-
COLL_UCC_CHECK(mca_coll_ucc_allreduce_iniz(sbuf, rbuf, count, dtype, op,
95-
ucc_module, &req, coll_req));
99+
COLL_UCC_CHECK(mca_coll_ucc_allreduce_iniz(sbuf, rbuf, count, dtype, op, false, ucc_module,
100+
&req, coll_req));
96101
COLL_UCC_POST_AND_CHECK(req);
97102
*request = &coll_req->super;
98103
return OMPI_SUCCESS;
@@ -104,3 +109,27 @@ int mca_coll_ucc_iallreduce(const void *sbuf, void *rbuf, size_t count,
104109
return ucc_module->previous_iallreduce(sbuf, rbuf, count, dtype, op,
105110
comm, request, ucc_module->previous_iallreduce_module);
106111
}
112+
113+
/*
 * Persistent allreduce setup for the UCC collective module.
 *
 * Obtains a request with COLL_UCC_GET_REQ_PC, builds the UCC collective through
 * mca_coll_ucc_allreduce_iniz() with its `persistent` argument set to true, and
 * returns the request to the caller through *request. No COLL_UCC_POST_AND_CHECK
 * is issued here, in contrast to mca_coll_ucc_iallreduce().
 *
 * sbuf, rbuf, count, dtype, op: forwarded unchanged to the iniz helper (or to the
 *                previous module on fallback).
 * comm, info:    only used when forwarding to the previous module.
 * request:       output; set to the embedded `super` request on the UCC path.
 * module:        cast to mca_coll_ucc_module_t to reach the saved previous_* hooks.
 *
 * Returns OMPI_SUCCESS on the UCC path, otherwise the return value of
 * ucc_module->previous_allreduce_init().
 */
int mca_coll_ucc_allreduce_init(const void *sbuf, void *rbuf, size_t count,
114+
struct ompi_datatype_t *dtype, struct ompi_op_t *op,
115+
struct ompi_communicator_t *comm, struct ompi_info_t *info,
116+
ompi_request_t **request, mca_coll_base_module_t *module)
117+
{
118+
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module;
119+
ucc_coll_req_h req;
120+
/* Starts as NULL so the fallback path can tell whether a request was obtained. */
mca_coll_ucc_req_t *coll_req = NULL;
121+
122+
/* NOTE(review): macro is defined outside this view; presumably jumps to `fallback` on failure -- confirm. */
COLL_UCC_GET_REQ_PC(coll_req);
123+
UCC_VERBOSE(3, "allreduce_init init %p", coll_req);
124+
/* `true` is the helper's `persistent` argument. */
COLL_UCC_CHECK(mca_coll_ucc_allreduce_iniz(sbuf, rbuf, count, dtype, op, true, ucc_module, &req,
125+
coll_req));
126+
*request = &coll_req->super;
127+
return OMPI_SUCCESS;
128+
fallback:
129+
UCC_VERBOSE(3, "running fallback allreduce_init");
130+
/* Release the request (if one was obtained) before delegating. */
if (coll_req) {
131+
/* NOTE(review): the cast assumes `super` is the first member of mca_coll_ucc_req_t -- confirm. */
mca_coll_ucc_req_free((ompi_request_t **) &coll_req);
132+
}
133+
/* Hand the call to the allreduce_init implementation saved in the module. */
return ucc_module->previous_allreduce_init(sbuf, rbuf, count, dtype, op, comm, info, request,
134+
ucc_module->previous_allreduce_init_module);
135+
}

ompi/mca/coll/ucc/coll_ucc_alltoall.c

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010

1111
#include "coll_ucc_common.h"
1212

13-
static inline ucc_status_t mca_coll_ucc_alltoall_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
14-
void* rbuf, size_t rcount, struct ompi_datatype_t *rdtype,
15-
mca_coll_ucc_module_t *ucc_module,
16-
ucc_coll_req_h *req,
17-
mca_coll_ucc_req_t *coll_req)
13+
static inline ucc_status_t
14+
mca_coll_ucc_alltoall_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
15+
void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype,
16+
bool persistent, mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req,
17+
mca_coll_ucc_req_t *coll_req)
1818
{
1919
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
2020
bool is_inplace = (MPI_IN_PLACE == sbuf);
@@ -60,6 +60,10 @@ static inline ucc_status_t mca_coll_ucc_alltoall_iniz(const void *sbuf, size_t s
6060
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
6161
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
6262
}
63+
if (true == persistent) {
64+
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
65+
coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT;
66+
}
6367
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
6468
return UCC_OK;
6569
fallback:
@@ -75,8 +79,7 @@ int mca_coll_ucc_alltoall(const void *sbuf, size_t scount, struct ompi_datatype_
7579
ucc_coll_req_h req;
7680

7781
UCC_VERBOSE(3, "running ucc alltoall");
78-
COLL_UCC_CHECK(mca_coll_ucc_alltoall_iniz(sbuf, scount, sdtype,
79-
rbuf, rcount, rdtype,
82+
COLL_UCC_CHECK(mca_coll_ucc_alltoall_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, false,
8083
ucc_module, &req, NULL));
8184
COLL_UCC_POST_AND_CHECK(req);
8285
COLL_UCC_CHECK(coll_ucc_req_wait(req));
@@ -99,8 +102,7 @@ int mca_coll_ucc_ialltoall(const void *sbuf, size_t scount, struct ompi_datatype
99102

100103
UCC_VERBOSE(3, "running ucc ialltoall");
101104
COLL_UCC_GET_REQ(coll_req);
102-
COLL_UCC_CHECK(mca_coll_ucc_alltoall_iniz(sbuf, scount, sdtype,
103-
rbuf, rcount, rdtype,
105+
COLL_UCC_CHECK(mca_coll_ucc_alltoall_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, false,
104106
ucc_module, &req, coll_req));
105107
COLL_UCC_POST_AND_CHECK(req);
106108
*request = &coll_req->super;
@@ -113,3 +115,28 @@ int mca_coll_ucc_ialltoall(const void *sbuf, size_t scount, struct ompi_datatype
113115
return ucc_module->previous_ialltoall(sbuf, scount, sdtype, rbuf, rcount, rdtype,
114116
comm, request, ucc_module->previous_ialltoall_module);
115117
}
118+
119+
/*
 * Persistent alltoall setup for the UCC collective module.
 *
 * Obtains a request with COLL_UCC_GET_REQ_PC, builds the UCC collective through
 * mca_coll_ucc_alltoall_iniz() with its `persistent` argument set to true, and
 * returns the request to the caller through *request. No COLL_UCC_POST_AND_CHECK
 * is issued here, in contrast to mca_coll_ucc_ialltoall().
 *
 * sbuf/scount/sdtype, rbuf/rcount/rdtype: forwarded unchanged to the iniz helper
 *                (or to the previous module on fallback).
 * comm, info:    only used when forwarding to the previous module.
 * request:       output; set to the embedded `super` request on the UCC path.
 * module:        cast to mca_coll_ucc_module_t to reach the saved previous_* hooks.
 *
 * Returns OMPI_SUCCESS on the UCC path, otherwise the return value of
 * ucc_module->previous_alltoall_init().
 */
int mca_coll_ucc_alltoall_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
120+
void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype,
121+
struct ompi_communicator_t *comm, struct ompi_info_t *info,
122+
ompi_request_t **request, mca_coll_base_module_t *module)
123+
{
124+
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module;
125+
ucc_coll_req_h req;
126+
/* Starts as NULL so the fallback path can tell whether a request was obtained. */
mca_coll_ucc_req_t *coll_req = NULL;
127+
128+
/* NOTE(review): macro is defined outside this view; presumably jumps to `fallback` on failure -- confirm. */
COLL_UCC_GET_REQ_PC(coll_req);
129+
UCC_VERBOSE(3, "alltoall_init init %p", coll_req);
130+
/* `true` is the helper's `persistent` argument. */
COLL_UCC_CHECK(mca_coll_ucc_alltoall_iniz(sbuf, scount, sdtype, rbuf, rcount, rdtype, true,
131+
ucc_module, &req, coll_req));
132+
*request = &coll_req->super;
133+
return OMPI_SUCCESS;
134+
fallback:
135+
UCC_VERBOSE(3, "running fallback alltoall_init");
136+
/* Release the request (if one was obtained) before delegating. */
if (coll_req) {
137+
/* NOTE(review): the cast assumes `super` is the first member of mca_coll_ucc_req_t -- confirm. */
mca_coll_ucc_req_free((ompi_request_t **) &coll_req);
138+
}
139+
/* Hand the call to the alltoall_init implementation saved in the module. */
return ucc_module->previous_alltoall_init(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm,
140+
info, request,
141+
ucc_module->previous_alltoall_init_module);
142+
}

0 commit comments

Comments
 (0)