@@ -487,9 +487,9 @@ static int ompi_osc_ucx_shared_query_peer(ompi_osc_ucx_module_t *module, int pee
487487 if (UCS_OK != ucp_rkey_ptr (rkey , module -> addrs [peer ], & addr_p )) {
488488 return OMPI_ERR_NOT_AVAILABLE ;
489489 }
490- * size = module -> sizes [peer ];
491- * ((void * * ) baseptr ) = ( void * ) module -> shmem_addrs [ peer ] ;
492- * disp_unit = module -> disp_units [peer ];
490+ * size = module -> same_size ? module -> size : module -> sizes [peer ];
491+ * ((void * * ) baseptr ) = addr_p ;
492+ * disp_unit = ( module -> disp_unit < 0 ) ? module -> disp_units [peer ] : module -> disp_unit ;
493493
494494 return OMPI_SUCCESS ;
495495}
@@ -554,8 +554,9 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, pt
554554 int flavor , int * model ) {
555555 ompi_osc_ucx_module_t * module = NULL ;
556556 char * name = NULL ;
557- long values [2 ];
557+ long values [4 ];
558558 int ret = OMPI_SUCCESS ;
559+ int val_count = 0 ;
559560 int i , comm_size = ompi_comm_size (comm );
560561 bool env_initialized = false;
561562 void * state_base = NULL ;
@@ -679,42 +680,70 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, pt
679680 module -> acc_single_intrinsic = check_config_value_bool ("acc_single_intrinsic" , info );
680681 module -> skip_sync_check = false;
681682
682- /**
683- * TODO: we need to collect the shared memory information from all processes
684- * on the same node. This includes the size and base address, which needs
685- * to be passed to ucp_rkey_ptr().
686- */
687- module -> shmem_info = NULL ;
688-
689683 /* share everyone's displacement units. Only do an allgather if
690684 strictly necessary, since it requires O(p) state. */
691685 values [0 ] = disp_unit ;
692686 values [1 ] = - disp_unit ;
687+ values [2 ] = size ;
688+ values [3 ] = - (int64_t )size ;
693689
694- ret = module -> comm -> c_coll -> coll_allreduce (MPI_IN_PLACE , values , 2 , MPI_LONG ,
690+ ret = module -> comm -> c_coll -> coll_allreduce (MPI_IN_PLACE , values , 4 , MPI_LONG ,
695691 MPI_MIN , module -> comm ,
696692 module -> comm -> c_coll -> coll_allreduce_module );
697693 if (OMPI_SUCCESS != ret ) {
698694 goto error ;
699695 }
700696
701- if (values [0 ] == - values [1 ]) { /* everyone has the same disp_unit, we do not need O(p) space */
697+ bool same_disp_unit = (values [0 ] == - values [1 ]);
698+ bool same_size = (values [2 ] == - values [3 ]);
699+
700+ if (same_disp_unit ) { /* everyone has the same disp_unit, we do not need O(p) space */
702701 module -> disp_unit = disp_unit ;
703- } else { /* different disp_unit sizes, allocate O(p) space to store them */
704- module -> disp_unit = -1 ;
705- module -> disp_units = calloc (comm_size , sizeof (ptrdiff_t ));
706- if (module -> disp_units == NULL ) {
707- ret = OMPI_ERR_TEMP_OUT_OF_RESOURCE ;
708- goto error ;
709- }
702+ module -> disp_units = NULL ;
703+ values [val_count ++ ] = disp_unit ;
704+ }
705+
706+ if (same_size ) {
707+ module -> same_size = true;
708+ module -> sizes = NULL ;
709+ values [val_count ++ ] = size ;
710+ }
711+
712+ if (!same_disp_unit || !same_size ) {
710713
711- ret = module -> comm -> c_coll -> coll_allgather (& disp_unit , sizeof (ptrdiff_t ), MPI_BYTE ,
712- module -> disp_units , sizeof (ptrdiff_t ) , MPI_BYTE ,
713- module -> comm ,
714- module -> comm -> c_coll -> coll_allgather_module );
714+ ret = module -> comm -> c_coll -> coll_allgather (values , val_count * sizeof (long ), MPI_BYTE ,
715+ ( void * ) my_info , sizeof (long ) * val_count , MPI_BYTE ,
716+ module -> comm ,
717+ module -> comm -> c_coll -> coll_allgather_module );
715718 if (OMPI_SUCCESS != ret ) {
716719 goto error ;
717720 }
721+
722+ if (!same_disp_unit ) { /* everyone has a different disp_unit */
723+ module -> disp_unit = -1 ;
724+ module -> disp_units = calloc (comm_size , sizeof (ptrdiff_t ));
725+ if (module -> disp_units == NULL ) {
726+ ret = OMPI_ERR_TEMP_OUT_OF_RESOURCE ;
727+ goto error ;
728+ }
729+ for (i = 0 ; i < comm_size ; i ++ ) {
730+ module -> disp_units [i ] = (ptrdiff_t )values [i * val_count ];
731+ }
732+ }
733+
734+ if (!same_size ) { /* everyone has the same disp_unit, we do not need O(p) space */
735+ module -> same_size = false;
736+ module -> sizes = calloc (comm_size , sizeof (size_t ));
737+ if (module -> sizes == NULL ) {
738+ ret = OMPI_ERR_TEMP_OUT_OF_RESOURCE ;
739+ goto error ;
740+ }
741+
742+ for (i = 0 ; i < comm_size ; i ++ ) {
743+ module -> sizes [i ] = (size_t )values [i * val_count + val_count - 1 ];
744+ }
745+ }
746+
718747 }
719748
720749 ret = opal_common_ucx_wpctx_create (mca_osc_ucx_component .wpool , comm_size ,
@@ -1261,6 +1290,9 @@ int ompi_osc_ucx_free(struct ompi_win_t *win) {
12611290 if (module -> disp_units ) {
12621291 free (module -> disp_units );
12631292 }
1293+ if (module -> sizes ) {
1294+ free (module -> sizes );
1295+ }
12641296 ompi_comm_free (& module -> comm );
12651297
12661298 free (module );
0 commit comments