Skip to content

Commit 4171fdb

Browse files
author
hasegawa.kento
committed
COLL/UCC: add <coll>_init to -mca coll_ucc_cts parser
Signed-off-by: hasegawa.kento <hasegawa.kento@fujitsu.com>
1 parent 9305110 commit 4171fdb

File tree

3 files changed

+102
-1
lines changed

3 files changed

+102
-1
lines changed

ompi/mca/coll/ucc/coll_ucc.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ struct mca_coll_ucc_component_t {
6262
ucc_lib_attr_t ucc_lib_attr;
6363
ucc_coll_type_t cts_requested;
6464
ucc_coll_type_t nb_cts_requested;
65+
ucc_coll_type_t pc_cts_requested;
6566
ucc_context_h ucc_context;
6667
opal_free_list_t requests;
6768
};

ompi/mca/coll/ucc/coll_ucc_component.c

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
* Copyright (c) 2021 Mellanox Technologies. All rights reserved.
44
* Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
55
* Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
6+
* Copyright (c) 2025 Fujitsu Limited. All rights reserved.
67
* $COPYRIGHT$
78
*
89
* Additional copyrights may follow
@@ -143,6 +144,63 @@ static ucc_coll_type_t mca_coll_ucc_str_to_type(const char *str)
143144
return UCC_COLL_TYPE_LAST;
144145
}
145146

147+
/* is a persistent collective */
148+
static inline int mca_coll_ucc_init_cts_is_p(const char *cp, char *bp, size_t bz)
149+
{
150+
size_t len = strlen(cp), len_suffix = sizeof("_init") - 1;
151+
152+
if ((bz > 0) && (bp != 0)) {
153+
bp[0] = '\0';
154+
}
155+
/* check if it is a persistent collective */
156+
if (len > len_suffix) {
157+
size_t blen = len - len_suffix;
158+
const char *cp_suffix = &cp[blen];
159+
160+
if (0 == strcmp(cp_suffix, "_init")) {
161+
if ((bz > 0) && (bp != 0)) {
162+
if (blen >= bz) {
163+
return 0 /* XXX internal error */;
164+
}
165+
strncpy(bp, cp, blen);
166+
bp[blen] = '\0';
167+
}
168+
return 1 /* true */;
169+
}
170+
}
171+
return 0 /* false */;
172+
}
173+
174+
/* is an alias (special) name */
175+
static inline int mca_coll_ucc_init_cts_is_a(const char *cp, bool disable,
176+
mca_coll_ucc_component_t *cm)
177+
{
178+
if (0 == strcmp(cp, "colls_b")) { /* all blocking colls */
179+
if (disable) {
180+
cm->cts_requested &= ~COLL_UCC_CTS;
181+
} else {
182+
cm->cts_requested |= COLL_UCC_CTS;
183+
}
184+
return 1 /* true */;
185+
} else if ((0 == strcmp(cp, "colls_i")) || (0 == strcmp(cp, "colls_nb"))) {
186+
/* all non-blocking colls */
187+
if (disable) {
188+
cm->nb_cts_requested &= ~COLL_UCC_CTS;
189+
} else {
190+
cm->nb_cts_requested |= COLL_UCC_CTS;
191+
}
192+
return 1 /* true */;
193+
} else if (0 == strcmp(cp, "colls_p")) { /* all persistent colls */
194+
if (disable) {
195+
cm->pc_cts_requested &= ~COLL_UCC_CTS;
196+
} else {
197+
cm->pc_cts_requested |= COLL_UCC_CTS;
198+
}
199+
return 1 /* true */;
200+
}
201+
return 0 /* false */;
202+
}
203+
146204
static void mca_coll_ucc_init_default_cts(void)
147205
{
148206
mca_coll_ucc_component_t *cm = &mca_coll_ucc_component;
@@ -157,18 +215,33 @@ static void mca_coll_ucc_init_default_cts(void)
157215
n_cts = opal_argv_count(cts);
158216
cm->cts_requested = disable ? COLL_UCC_CTS : 0;
159217
cm->nb_cts_requested = disable ? COLL_UCC_CTS : 0;
218+
cm->pc_cts_requested = 0; /* XXX PC currently disabled by default */
160219
for (i = 0; i < n_cts; i++) {
220+
char l_str[64]; /* XXX sizeof("reduce_scatter_block") */
221+
size_t l_stz = sizeof(l_str);
222+
223+
if (0 < mca_coll_ucc_init_cts_is_a(cts[i], disable, cm)) {
224+
continue;
225+
}
161226
if (('i' == cts[i][0]) || ('I' == cts[i][0])) {
162227
/* non blocking collective setting */
163228
str = cts[i] + 1;
164229
ct = &cm->nb_cts_requested;
230+
} else if (0 < mca_coll_ucc_init_cts_is_p(cts[i], l_str, l_stz)) {
231+
/* persistent collective setting */
232+
str = l_str;
233+
ct = &cm->pc_cts_requested;
165234
} else {
166235
str = cts[i];
167236
ct = &cm->cts_requested;
168237
}
169238
c = mca_coll_ucc_str_to_type(str);
170239
if (UCC_COLL_TYPE_LAST == c) {
171-
*ct = COLL_UCC_CTS;
240+
if (&cm->pc_cts_requested != ct) {
241+
*ct = COLL_UCC_CTS;
242+
} else {
243+
*ct = 0; /* XXX PC currently disabled by default */
244+
}
172245
break;
173246
}
174247
if (disable) {

ompi/mca/coll/ucc/coll_ucc_module.c

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,12 @@ static inline ucc_ep_map_t get_rank_map(struct ompi_communicator_t *comm)
425425
MCA_COLL_INSTALL_API(__comm, i##__api, mca_coll_ucc_i##__api, &__ucc_module->super, "ucc"); \
426426
(__ucc_module)->super.coll_i##__api = mca_coll_ucc_i##__api; \
427427
} \
428+
if (mca_coll_ucc_component.pc_cts_requested & UCC_COLL_TYPE_##__COLL) \
429+
{ \
430+
MCA_COLL_SAVE_API(__comm, __api##_init, (__ucc_module)->previous_##__api##_init, (__ucc_module)->previous_##__api##_init_module, "ucc"); \
431+
MCA_COLL_INSTALL_API(__comm, __api##_init, mca_coll_ucc_##__api##_init, &__ucc_module->super, "ucc"); \
432+
(__ucc_module)->super.coll_##__api##_init = mca_coll_ucc_##__api##_init; \
433+
} \
428434
} \
429435
} while (0)
430436

@@ -559,11 +565,32 @@ mca_coll_ucc_module_disable(mca_coll_base_module_t *module,
559565
UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce);
560566
UCC_UNINSTALL_COLL_API(comm, ucc_module, ireduce);
561567
UCC_UNINSTALL_COLL_API(comm, ucc_module, gather);
568+
/* UCC_UNINSTALL_COLL_API(comm, ucc_module, igather); */
562569
UCC_UNINSTALL_COLL_API(comm, ucc_module, gatherv);
570+
/* UCC_UNINSTALL_COLL_API(comm, ucc_module, igatherv); */
563571
UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce_scatter_block);
572+
/* UCC_UNINSTALL_COLL_API(comm, ucc_module, ireduce_scatter_block); */
564573
UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce_scatter);
574+
/* UCC_UNINSTALL_COLL_API(comm, ucc_module, ireduce_scatter); */
565575
UCC_UNINSTALL_COLL_API(comm, ucc_module, scatter);
576+
/* UCC_UNINSTALL_COLL_API(comm, ucc_module, iscatter); */
566577
UCC_UNINSTALL_COLL_API(comm, ucc_module, scatterv);
578+
/* UCC_UNINSTALL_COLL_API(comm, ucc_module, iscatterv); */
579+
580+
UCC_UNINSTALL_COLL_API(comm, ucc_module, allreduce_init);
581+
UCC_UNINSTALL_COLL_API(comm, ucc_module, barrier_init);
582+
UCC_UNINSTALL_COLL_API(comm, ucc_module, bcast_init);
583+
UCC_UNINSTALL_COLL_API(comm, ucc_module, alltoall_init);
584+
UCC_UNINSTALL_COLL_API(comm, ucc_module, alltoallv_init);
585+
UCC_UNINSTALL_COLL_API(comm, ucc_module, allgather_init);
586+
UCC_UNINSTALL_COLL_API(comm, ucc_module, allgatherv_init);
587+
UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce_init);
588+
UCC_UNINSTALL_COLL_API(comm, ucc_module, gather_init);
589+
UCC_UNINSTALL_COLL_API(comm, ucc_module, gatherv_init);
590+
UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce_scatter_block_init);
591+
UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce_scatter_init);
592+
UCC_UNINSTALL_COLL_API(comm, ucc_module, scatter_init);
593+
UCC_UNINSTALL_COLL_API(comm, ucc_module, scatterv_init);
567594

568595
return OMPI_SUCCESS;
569596
}

0 commit comments

Comments
 (0)