Skip to content

Commit f3a7899

Browse files
authored
Merge pull request #13374 from mentOS31/add_persistent_collective_calls_for_ucc
COLL/UCC: add persistent collective calls for UCC
2 parents e2a2583 + a8b244e commit f3a7899

18 files changed

+914
-207
lines changed

ompi/mca/coll/ucc/coll_ucc.h

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
/**
22
Copyright (c) 2021 Mellanox Technologies. All rights reserved.
33
Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
4+
Copyright (c) 2025 Fujitsu Limited. All rights reserved.
45
$COPYRIGHT$
56
67
Additional copyrights may follow
@@ -61,6 +62,7 @@ struct mca_coll_ucc_component_t {
6162
ucc_lib_attr_t ucc_lib_attr;
6263
ucc_coll_type_t cts_requested;
6364
ucc_coll_type_t nb_cts_requested;
65+
ucc_coll_type_t ps_cts_requested;
6466
ucc_context_h ucc_context;
6567
opal_free_list_t requests;
6668
};
@@ -132,6 +134,34 @@ struct mca_coll_ucc_module_t {
132134
mca_coll_base_module_t* previous_scatter_module;
133135
mca_coll_base_module_iscatter_fn_t previous_iscatter;
134136
mca_coll_base_module_t* previous_iscatter_module;
137+
mca_coll_base_module_allreduce_init_fn_t previous_allreduce_init;
138+
mca_coll_base_module_t* previous_allreduce_init_module;
139+
mca_coll_base_module_reduce_init_fn_t previous_reduce_init;
140+
mca_coll_base_module_t* previous_reduce_init_module;
141+
mca_coll_base_module_barrier_init_fn_t previous_barrier_init;
142+
mca_coll_base_module_t* previous_barrier_init_module;
143+
mca_coll_base_module_bcast_init_fn_t previous_bcast_init;
144+
mca_coll_base_module_t* previous_bcast_init_module;
145+
mca_coll_base_module_alltoall_init_fn_t previous_alltoall_init;
146+
mca_coll_base_module_t* previous_alltoall_init_module;
147+
mca_coll_base_module_alltoallv_init_fn_t previous_alltoallv_init;
148+
mca_coll_base_module_t* previous_alltoallv_init_module;
149+
mca_coll_base_module_allgather_init_fn_t previous_allgather_init;
150+
mca_coll_base_module_t* previous_allgather_init_module;
151+
mca_coll_base_module_allgatherv_init_fn_t previous_allgatherv_init;
152+
mca_coll_base_module_t* previous_allgatherv_init_module;
153+
mca_coll_base_module_gather_init_fn_t previous_gather_init;
154+
mca_coll_base_module_t* previous_gather_init_module;
155+
mca_coll_base_module_gatherv_init_fn_t previous_gatherv_init;
156+
mca_coll_base_module_t* previous_gatherv_init_module;
157+
mca_coll_base_module_reduce_scatter_block_init_fn_t previous_reduce_scatter_block_init;
158+
mca_coll_base_module_t* previous_reduce_scatter_block_init_module;
159+
mca_coll_base_module_reduce_scatter_init_fn_t previous_reduce_scatter_init;
160+
mca_coll_base_module_t* previous_reduce_scatter_init_module;
161+
mca_coll_base_module_scatterv_init_fn_t previous_scatterv_init;
162+
mca_coll_base_module_t* previous_scatterv_init_module;
163+
mca_coll_base_module_scatter_init_fn_t previous_scatter_init;
164+
mca_coll_base_module_t* previous_scatter_init_module;
135165
};
136166
typedef struct mca_coll_ucc_module_t mca_coll_ucc_module_t;
137167
OBJ_CLASS_DECLARATION(mca_coll_ucc_module_t);
@@ -305,5 +335,78 @@ int mca_coll_ucc_iscatter(const void *sbuf, size_t scount,
305335
ompi_request_t** request,
306336
mca_coll_base_module_t *module);
307337

338+
int mca_coll_ucc_allreduce_init(const void *sbuf, void *rbuf, size_t count,
339+
struct ompi_datatype_t *dtype, struct ompi_op_t *op,
340+
struct ompi_communicator_t *comm, struct ompi_info_t *info,
341+
ompi_request_t **request, mca_coll_base_module_t *module);
342+
343+
int mca_coll_ucc_reduce_init(const void *sbuf, void *rbuf, size_t count,
344+
struct ompi_datatype_t *dtype, struct ompi_op_t *op, int root,
345+
struct ompi_communicator_t *comm, struct ompi_info_t *info,
346+
ompi_request_t **request, mca_coll_base_module_t *module);
347+
348+
int mca_coll_ucc_barrier_init(struct ompi_communicator_t *comm, struct ompi_info_t *info,
349+
ompi_request_t **request, mca_coll_base_module_t *module);
350+
351+
int mca_coll_ucc_bcast_init(void *buff, size_t count, struct ompi_datatype_t *datatype, int root,
352+
struct ompi_communicator_t *comm, struct ompi_info_t *info,
353+
ompi_request_t **request, mca_coll_base_module_t *module);
354+
355+
int mca_coll_ucc_alltoall_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
356+
void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype,
357+
struct ompi_communicator_t *comm, struct ompi_info_t *info,
358+
ompi_request_t **request, mca_coll_base_module_t *module);
359+
360+
int mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_count_array_t scounts,
361+
ompi_disp_array_t sdisps, struct ompi_datatype_t *sdtype,
362+
void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps,
363+
struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm,
364+
struct ompi_info_t *info, ompi_request_t **request,
365+
mca_coll_base_module_t *module);
366+
367+
int mca_coll_ucc_allgather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
368+
void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype,
369+
struct ompi_communicator_t *comm, struct ompi_info_t *info,
370+
ompi_request_t **request, mca_coll_base_module_t *module);
371+
372+
int mca_coll_ucc_allgatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
373+
void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t disps,
374+
struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm,
375+
struct ompi_info_t *info, ompi_request_t **request,
376+
mca_coll_base_module_t *module);
377+
378+
int mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
379+
void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root,
380+
struct ompi_communicator_t *comm, struct ompi_info_t *info,
381+
ompi_request_t **request, mca_coll_base_module_t *module);
382+
383+
int mca_coll_ucc_gatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
384+
void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t disps,
385+
struct ompi_datatype_t *rdtype, int root,
386+
struct ompi_communicator_t *comm, struct ompi_info_t *info,
387+
ompi_request_t **request, mca_coll_base_module_t *module);
388+
389+
int mca_coll_ucc_reduce_scatter_block_init(const void *sbuf, void *rbuf, size_t rcount,
390+
struct ompi_datatype_t *dtype, struct ompi_op_t *op,
391+
struct ompi_communicator_t *comm,
392+
struct ompi_info_t *info, ompi_request_t **request,
393+
mca_coll_base_module_t *module);
394+
395+
int mca_coll_ucc_reduce_scatter_init(const void *sbuf, void *rbuf, ompi_count_array_t rcounts,
396+
struct ompi_datatype_t *dtype, struct ompi_op_t *op,
397+
struct ompi_communicator_t *comm, struct ompi_info_t *info,
398+
ompi_request_t **request, mca_coll_base_module_t *module);
399+
400+
int mca_coll_ucc_scatterv_init(const void *sbuf, ompi_count_array_t scounts,
401+
ompi_disp_array_t disps, struct ompi_datatype_t *sdtype, void *rbuf,
402+
size_t rcount, struct ompi_datatype_t *rdtype, int root,
403+
struct ompi_communicator_t *comm, struct ompi_info_t *info,
404+
ompi_request_t **request, mca_coll_base_module_t *module);
405+
406+
int mca_coll_ucc_scatter_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
407+
void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root,
408+
struct ompi_communicator_t *comm, struct ompi_info_t *info,
409+
ompi_request_t **request, mca_coll_base_module_t *module);
410+
308411
END_C_DECLS
309412
#endif

ompi/mca/coll/ucc/coll_ucc_allgather.c

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22
/**
33
* Copyright (c) 2021 Mellanox Technologies. All rights reserved.
4+
* Copyright (c) 2025 Fujitsu Limited. All rights reserved.
45
* $COPYRIGHT$
56
*
67
* Additional copyrights may follow
@@ -9,15 +10,17 @@
910

1011
#include "coll_ucc_common.h"
1112

12-
static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
13-
void* rbuf, size_t rcount, struct ompi_datatype_t *rdtype,
14-
mca_coll_ucc_module_t *ucc_module,
15-
ucc_coll_req_h *req,
16-
mca_coll_ucc_req_t *coll_req)
13+
static inline ucc_status_t
14+
mca_coll_ucc_allgather_init_common(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
15+
void* rbuf, size_t rcount, struct ompi_datatype_t *rdtype,
16+
bool persistent, mca_coll_ucc_module_t *ucc_module,
17+
ucc_coll_req_h *req,
18+
mca_coll_ucc_req_t *coll_req)
1719
{
1820
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
1921
bool is_inplace = (MPI_IN_PLACE == sbuf);
2022
int comm_size = ompi_comm_size(ucc_module->comm);
23+
uint64_t flags = 0;
2124

2225
if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) ||
2326
!ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) {
@@ -37,9 +40,12 @@ static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t
3740
goto fallback;
3841
}
3942

43+
flags = (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) |
44+
(persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0);
45+
4046
ucc_coll_args_t coll = {
41-
.mask = 0,
42-
.flags = 0,
47+
.mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0,
48+
.flags = flags,
4349
.coll_type = UCC_COLL_TYPE_ALLGATHER,
4450
.src.info = {
4551
.buffer = (void*)sbuf,
@@ -55,10 +61,6 @@ static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t
5561
}
5662
};
5763

58-
if (is_inplace) {
59-
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
60-
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
61-
}
6264
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
6365
return UCC_OK;
6466
fallback:
@@ -74,9 +76,9 @@ int mca_coll_ucc_allgather(const void *sbuf, size_t scount, struct ompi_datatype
7476
ucc_coll_req_h req;
7577

7678
UCC_VERBOSE(3, "running ucc allgather");
77-
COLL_UCC_CHECK(mca_coll_ucc_allgather_init(sbuf, scount, sdtype,
78-
rbuf, rcount, rdtype,
79-
ucc_module, &req, NULL));
79+
COLL_UCC_CHECK(mca_coll_ucc_allgather_init_common(sbuf, scount, sdtype,
80+
rbuf, rcount, rdtype,
81+
false, ucc_module, &req, NULL));
8082
COLL_UCC_POST_AND_CHECK(req);
8183
COLL_UCC_CHECK(coll_ucc_req_wait(req));
8284
return OMPI_SUCCESS;
@@ -98,9 +100,9 @@ int mca_coll_ucc_iallgather(const void *sbuf, size_t scount, struct ompi_datatyp
98100

99101
UCC_VERBOSE(3, "running ucc iallgather");
100102
COLL_UCC_GET_REQ(coll_req);
101-
COLL_UCC_CHECK(mca_coll_ucc_allgather_init(sbuf, scount, sdtype,
102-
rbuf, rcount, rdtype,
103-
ucc_module, &req, coll_req));
103+
COLL_UCC_CHECK(mca_coll_ucc_allgather_init_common(sbuf, scount, sdtype,
104+
rbuf, rcount, rdtype,
105+
false, ucc_module, &req, coll_req));
104106
COLL_UCC_POST_AND_CHECK(req);
105107
*request = &coll_req->super;
106108
return OMPI_SUCCESS;
@@ -112,3 +114,29 @@ int mca_coll_ucc_iallgather(const void *sbuf, size_t scount, struct ompi_datatyp
112114
return ucc_module->previous_iallgather(sbuf, scount, sdtype, rbuf, rcount, rdtype,
113115
comm, request, ucc_module->previous_iallgather_module);
114116
}
117+
118+
int mca_coll_ucc_allgather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
119+
void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype,
120+
struct ompi_communicator_t *comm, struct ompi_info_t *info,
121+
ompi_request_t **request, mca_coll_base_module_t *module)
122+
{
123+
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module;
124+
ucc_coll_req_h req;
125+
mca_coll_ucc_req_t *coll_req = NULL;
126+
127+
COLL_UCC_GET_REQ_PERSISTENT(coll_req);
128+
UCC_VERBOSE(3, "allgather_init init %p", coll_req);
129+
COLL_UCC_CHECK(mca_coll_ucc_allgather_init_common(sbuf, scount, sdtype,
130+
rbuf, rcount, rdtype,
131+
true, ucc_module, &req, coll_req));
132+
*request = &coll_req->super;
133+
return OMPI_SUCCESS;
134+
fallback:
135+
UCC_VERBOSE(3, "running fallback allgather_init");
136+
if (coll_req) {
137+
mca_coll_ucc_req_free((ompi_request_t **) &coll_req);
138+
}
139+
return ucc_module->previous_allgather_init(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm,
140+
info, request,
141+
ucc_module->previous_allgather_init_module);
142+
}

ompi/mca/coll/ucc/coll_ucc_allgatherv.c

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22
/**
33
* Copyright (c) 2021 Mellanox Technologies. All rights reserved.
4+
* Copyright (c) 2025 Fujitsu Limited. All rights reserved.
45
* $COPYRIGHT$
56
*
67
* Additional copyrights may follow
@@ -9,13 +10,14 @@
910

1011
#include "coll_ucc_common.h"
1112

12-
static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, size_t scount,
13-
struct ompi_datatype_t *sdtype,
14-
void* rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps,
15-
struct ompi_datatype_t *rdtype,
16-
mca_coll_ucc_module_t *ucc_module,
17-
ucc_coll_req_h *req,
18-
mca_coll_ucc_req_t *coll_req)
13+
static inline ucc_status_t
14+
mca_coll_ucc_allgatherv_init_common(const void *sbuf, size_t scount,
15+
struct ompi_datatype_t *sdtype,
16+
void* rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps,
17+
struct ompi_datatype_t *rdtype,
18+
bool persistent, mca_coll_ucc_module_t *ucc_module,
19+
ucc_coll_req_h *req,
20+
mca_coll_ucc_req_t *coll_req)
1921
{
2022
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
2123
bool is_inplace = (MPI_IN_PLACE == sbuf);
@@ -36,7 +38,8 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, size_t
3638

3739
flags = (ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) |
3840
(ompi_disp_array_is_64bit(rdisps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0) |
39-
(is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0);
41+
(is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) |
42+
(persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0);
4043

4144
ucc_coll_args_t coll = {
4245
.mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0,
@@ -75,9 +78,9 @@ int mca_coll_ucc_allgatherv(const void *sbuf, size_t scount,
7578

7679
UCC_VERBOSE(3, "running ucc allgatherv");
7780

78-
COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init(sbuf, scount, sdtype,
79-
rbuf, rcounts, rdisps, rdtype,
80-
ucc_module, &req, NULL));
81+
COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init_common(sbuf, scount, sdtype,
82+
rbuf, rcounts, rdisps, rdtype,
83+
false, ucc_module, &req, NULL));
8184
COLL_UCC_POST_AND_CHECK(req);
8285
COLL_UCC_CHECK(coll_ucc_req_wait(req));
8386
return OMPI_SUCCESS;
@@ -102,9 +105,9 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, size_t scount,
102105

103106
UCC_VERBOSE(3, "running ucc iallgatherv");
104107
COLL_UCC_GET_REQ(coll_req);
105-
COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init(sbuf, scount, sdtype,
106-
rbuf, rcounts, rdisps, rdtype,
107-
ucc_module, &req, coll_req));
108+
COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init_common(sbuf, scount, sdtype,
109+
rbuf, rcounts, rdisps, rdtype,
110+
false, ucc_module, &req, coll_req));
108111
COLL_UCC_POST_AND_CHECK(req);
109112
*request = &coll_req->super;
110113
return OMPI_SUCCESS;
@@ -117,3 +120,30 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, size_t scount,
117120
rbuf, rcounts, rdisps, rdtype,
118121
comm, request, ucc_module->previous_iallgatherv_module);
119122
}
123+
124+
int mca_coll_ucc_allgatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
125+
void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps,
126+
struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm,
127+
struct ompi_info_t *info, ompi_request_t **request,
128+
mca_coll_base_module_t *module)
129+
{
130+
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module;
131+
ucc_coll_req_h req;
132+
mca_coll_ucc_req_t *coll_req = NULL;
133+
134+
COLL_UCC_GET_REQ_PERSISTENT(coll_req);
135+
UCC_VERBOSE(3, "allgatherv_init init %p", coll_req);
136+
COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init_common(sbuf, scount, sdtype,
137+
rbuf, rcounts, rdisps, rdtype,
138+
true, ucc_module, &req, coll_req));
139+
*request = &coll_req->super;
140+
return OMPI_SUCCESS;
141+
fallback:
142+
UCC_VERBOSE(3, "running fallback allgatherv_init");
143+
if (coll_req) {
144+
mca_coll_ucc_req_free((ompi_request_t **) &coll_req);
145+
}
146+
return ucc_module->previous_allgatherv_init(sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype,
147+
comm, info, request,
148+
ucc_module->previous_allgatherv_init_module);
149+
}

0 commit comments

Comments
 (0)