1616
1717#include "opal/runtime/opal.h"
1818#include "opal/mca/pmix/pmix.h"
19+ #include "ompi/attribute/attribute.h"
1920#include "ompi/message/message.h"
2021#include "ompi/mca/pml/base/pml_base_bsend.h"
2122#include "opal/mca/common/ucx/common_ucx.h"
@@ -190,9 +191,9 @@ int mca_pml_ucx_close(void)
190191int mca_pml_ucx_init (void )
191192{
192193 ucp_worker_params_t params ;
193- ucs_status_t status ;
194194 ucp_worker_attr_t attr ;
195- int rc ;
195+ ucs_status_t status ;
196+ int i , rc ;
196197
197198 PML_UCX_VERBOSE (1 , "mca_pml_ucx_init" );
198199
@@ -209,30 +210,34 @@ int mca_pml_ucx_init(void)
209210 & ompi_pml_ucx .ucp_worker );
210211 if (UCS_OK != status ) {
211212 PML_UCX_ERROR ("Failed to create UCP worker" );
212- return OMPI_ERROR ;
213+ rc = OMPI_ERROR ;
214+ goto err ;
213215 }
214216
215217 attr .field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE ;
216218 status = ucp_worker_query (ompi_pml_ucx .ucp_worker , & attr );
217219 if (UCS_OK != status ) {
218- ucp_worker_destroy (ompi_pml_ucx .ucp_worker );
219- ompi_pml_ucx .ucp_worker = NULL ;
220220 PML_UCX_ERROR ("Failed to query UCP worker thread level" );
221- return OMPI_ERROR ;
221+ rc = OMPI_ERROR ;
222+ goto err_destroy_worker ;
222223 }
223224
224- if (ompi_mpi_thread_multiple && attr .thread_mode != UCS_THREAD_MODE_MULTI ) {
225+ if (ompi_mpi_thread_multiple && ( attr .thread_mode != UCS_THREAD_MODE_MULTI ) ) {
225226 /* UCX does not support multithreading, disqualify current PML for now */
226227 /* TODO: we should let OMPI to fallback to THREAD_SINGLE mode */
227- ucp_worker_destroy (ompi_pml_ucx .ucp_worker );
228- ompi_pml_ucx .ucp_worker = NULL ;
229228 PML_UCX_ERROR ("UCP worker does not support MPI_THREAD_MULTIPLE" );
230- return OMPI_ERROR ;
229+ rc = OMPI_ERR_NOT_SUPPORTED ;
230+ goto err_destroy_worker ;
231231 }
232232
233233 rc = mca_pml_ucx_send_worker_address ();
234234 if (rc < 0 ) {
235- return rc ;
235+ goto err_destroy_worker ;
236+ }
237+
238+ ompi_pml_ucx .datatype_attr_keyval = MPI_KEYVAL_INVALID ;
239+ for (i = 0 ; i < OMPI_DATATYPE_MAX_PREDEFINED ; ++ i ) {
240+ ompi_pml_ucx .predefined_types [i ] = PML_UCX_DATATYPE_INVALID ;
236241 }
237242
238243 /* Initialize the free lists */
@@ -249,14 +254,33 @@ int mca_pml_ucx_init(void)
249254 (void * )ompi_pml_ucx .ucp_context ,
250255 (void * )ompi_pml_ucx .ucp_worker );
251256 return OMPI_SUCCESS ;
257+
258+ err_destroy_worker :
259+ ucp_worker_destroy (ompi_pml_ucx .ucp_worker );
260+ ompi_pml_ucx .ucp_worker = NULL ;
261+ err :
262+ return OMPI_ERROR ;
252263}
253264
254265int mca_pml_ucx_cleanup (void )
255266{
267+ int i ;
268+
256269 PML_UCX_VERBOSE (1 , "mca_pml_ucx_cleanup" );
257270
258271 opal_progress_unregister (mca_pml_ucx_progress );
259272
273+ if (ompi_pml_ucx .datatype_attr_keyval != MPI_KEYVAL_INVALID ) {
274+ ompi_attr_free_keyval (TYPE_ATTR , & ompi_pml_ucx .datatype_attr_keyval , false);
275+ }
276+
277+ for (i = 0 ; i < OMPI_DATATYPE_MAX_PREDEFINED ; ++ i ) {
278+ if (ompi_pml_ucx .predefined_types [i ] != PML_UCX_DATATYPE_INVALID ) {
279+ ucp_dt_destroy (ompi_pml_ucx .predefined_types [i ]);
280+ ompi_pml_ucx .predefined_types [i ] = PML_UCX_DATATYPE_INVALID ;
281+ }
282+ }
283+
260284 ompi_pml_ucx .completed_send_req .req_state = OMPI_REQUEST_INVALID ;
261285 OMPI_REQUEST_FINI (& ompi_pml_ucx .completed_send_req );
262286 OBJ_DESTRUCT (& ompi_pml_ucx .completed_send_req );
@@ -398,6 +422,22 @@ int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
398422
399423int mca_pml_ucx_enable (bool enable )
400424{
425+ ompi_attribute_fn_ptr_union_t copy_fn ;
426+ ompi_attribute_fn_ptr_union_t del_fn ;
427+ int ret ;
428+
429+ /* Create a key for adding custom attributes to datatypes */
430+ copy_fn .attr_datatype_copy_fn =
431+ (MPI_Type_internal_copy_attr_function * )MPI_TYPE_NULL_COPY_FN ;
432+ del_fn .attr_datatype_delete_fn = mca_pml_ucx_datatype_attr_del_fn ;
433+ ret = ompi_attr_create_keyval (TYPE_ATTR , copy_fn , del_fn ,
434+ & ompi_pml_ucx .datatype_attr_keyval , NULL , 0 ,
435+ NULL );
436+ if (ret != OMPI_SUCCESS ) {
437+ PML_UCX_ERROR ("Failed to create keyval for UCX datatypes: %d" , ret );
438+ return ret ;
439+ }
440+
401441 PML_UCX_FREELIST_INIT (& ompi_pml_ucx .persistent_reqs ,
402442 mca_pml_ucx_persistent_request_t ,
403443 128 , -1 , 128 );
0 commit comments