@@ -934,6 +934,110 @@ static inline int ompi_osc_ucx_acc_rputget(void *stage_addr, int stage_count,
934934 return ret ;
935935}
936936
937+ static inline int ompi_osc_ucx_check_ops_and_flush (ompi_osc_ucx_module_t * module ,
938+ int target , ptrdiff_t target_disp , int target_count , struct
939+ ompi_datatype_t * target_dt , bool lock_acquired ) {
940+ ptrdiff_t target_lb , target_extent ;
941+ uint64_t base_tmp , tail_tmp ;
942+ int ret = OMPI_SUCCESS ;
943+
944+ if (module -> ctx -> num_incomplete_req_ops > ompi_osc_ucx_outstanding_ops_flush_threshold ) {
945+ ret = opal_common_ucx_ctx_flush (module -> ctx , OPAL_COMMON_UCX_SCOPE_WORKER , 0 );
946+ if (ret != OPAL_SUCCESS ) {
947+ ret = OMPI_ERROR ;
948+ return ret ;
949+ }
950+ opal_mutex_lock (& module -> ctx -> mutex );
951+ /* Check if we need to clear the list */
952+ if (ompi_osc_ucx_outstanding_ops_flush_threshold != 0
953+ && module -> ctx -> num_incomplete_req_ops == 0 ) {
954+ memset (module -> epoc_outstanding_ops_mems , 0 ,
955+ sizeof (ompi_osc_ucx_mem_ranges_t ) *
956+ ompi_osc_ucx_outstanding_ops_flush_threshold );
957+ }
958+ opal_mutex_unlock (& module -> ctx -> mutex );
959+ }
960+
961+ if (ompi_osc_ucx_outstanding_ops_flush_threshold == 0 ) {
962+ return ret ;
963+ }
964+
965+ if (!lock_acquired ) {
966+ /* We are not acquiring the acc lock as we already may have an exclusive lock
967+ * to the target window. However, in the nb acc operation, we must
968+ * preserve the atomicy of back to back calls to same target buffer
969+ * even when acc lock is not available */
970+
971+ /* Calculate the base and range of the target buffer for this call */
972+ ompi_datatype_get_true_extent (target_dt , & target_lb , & target_extent );
973+ uint64_t base = (module -> addrs [target ]) + target_disp * OSC_UCX_GET_DISP (module , target );
974+ uint64_t tail = base + target_extent * target_count ;
975+
976+ assert ((void * )base != NULL );
977+
978+ opal_mutex_lock (& module -> ctx -> mutex );
979+
980+ bool overlap = false;
981+ /* Find overlapping outstanidng acc ops to same target */
982+ size_t i ;
983+ for (i = 0 ; i < ompi_osc_ucx_outstanding_ops_flush_threshold ; i ++ ) {
984+ base_tmp = module -> epoc_outstanding_ops_mems [i ].base ;
985+ tail_tmp = module -> epoc_outstanding_ops_mems [i ].tail ;
986+ if (base_tmp == tail_tmp ) {
987+ continue ;
988+ }
989+ if (!(tail_tmp < base || tail < base_tmp )) {
990+ overlap = true;
991+ break ;
992+ }
993+ }
994+
995+ /* If there are overlaps, then flush */
996+ if (overlap ) {
997+ ret = opal_common_ucx_ctx_flush (module -> ctx , OPAL_COMMON_UCX_SCOPE_WORKER , 0 );
998+ if (ret != OPAL_SUCCESS ) {
999+ ret = OMPI_ERROR ;
1000+ return ret ;
1001+ }
1002+ }
1003+
1004+ /* Add the new base and tail to the list of outstanding
1005+ * ops of this epoc */
1006+ bool op_added = false;
1007+ while (!op_added ) {
1008+ /* Check if we need to clear the list */
1009+ if (module -> ctx -> num_incomplete_req_ops == 0 ) {
1010+ memset (module -> epoc_outstanding_ops_mems , 0 ,
1011+ sizeof (ompi_osc_ucx_mem_ranges_t ) *
1012+ ompi_osc_ucx_outstanding_ops_flush_threshold );
1013+ }
1014+
1015+ for (i = 0 ; i < ompi_osc_ucx_outstanding_ops_flush_threshold ; i ++ ) {
1016+ base_tmp = module -> epoc_outstanding_ops_mems [i ].base ;
1017+ tail_tmp = module -> epoc_outstanding_ops_mems [i ].tail ;
1018+ if (base_tmp == tail_tmp ) {
1019+ module -> epoc_outstanding_ops_mems [i ].base = base ;
1020+ module -> epoc_outstanding_ops_mems [i ].tail = tail ;
1021+ op_added = true;
1022+ break ;
1023+ };
1024+ }
1025+
1026+ if (!op_added ) {
1027+ /* no more space so flush */
1028+ ret = opal_common_ucx_ctx_flush (module -> ctx , OPAL_COMMON_UCX_SCOPE_WORKER , 0 );
1029+ if (ret != OPAL_SUCCESS ) {
1030+ ret = OMPI_ERROR ;
1031+ return ret ;
1032+ }
1033+ }
1034+ }
1035+ opal_mutex_unlock (& module -> ctx -> mutex );
1036+ }
1037+
1038+ return ret ;
1039+ }
1040+
9371041/* Nonblocking variant of accumulate. reduce+put happens inside completion call back
9381042 * of rget */
9391043static int ompi_osc_ucx_get_accumulate_nonblocking (const void * origin_addr , int origin_count ,
@@ -971,12 +1075,10 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
9711075 return ret ;
9721076 }
9731077
974- if (module -> ctx -> num_incomplete_req_ops > ompi_osc_ucx_outstanding_ops_flush_threshold ) {
975- ret = opal_common_ucx_ctx_flush (module -> ctx , OPAL_COMMON_UCX_SCOPE_WORKER , 0 );
976- if (ret != OPAL_SUCCESS ) {
977- ret = OMPI_ERROR ;
978- return ret ;
979- }
1078+ ret = ompi_osc_ucx_check_ops_and_flush (module , target , target_disp , target_count ,
1079+ target_dt , lock_acquired );
1080+ if (ret != OMPI_SUCCESS ) {
1081+ return ret ;
9801082 }
9811083
9821084 CHECK_DYNAMIC_WIN (remote_addr , module , target , ret );
0 commit comments