88 * Copyright (c) 2019 Triad National Security, LLC. All rights
99 * reserved.
1010 * Copyright (c) 2019-2021 Google, LLC. All rights reserved.
11+ * Copyright (c) 2021 IBM Corporation. All rights reserved.
1112 * $COPYRIGHT$
1213 *
1314 * Additional copyrights may follow
1920#include "osc_rdma_request.h"
2021#include "osc_rdma_comm.h"
2122
23+ #include "ompi/mca/osc/base/base.h"
2224#include "ompi/mca/osc/base/osc_base_obj_convert.h"
2325
2426static inline void ompi_osc_rdma_peer_accumulate_cleanup (ompi_osc_rdma_module_t * module , ompi_osc_rdma_peer_t * peer , bool lock_acquired )
@@ -133,6 +135,22 @@ static int ompi_osc_rdma_op_mapping[OMPI_OP_NUM_OF_TYPES + 1] = {
133135 [OMPI_OP_REPLACE ] = MCA_BTL_ATOMIC_SWAP ,
134136};
135137
138+ /* set the appropriate flags for this atomic */
139+ static inline int ompi_osc_rdma_set_btl_flags (ompi_osc_rdma_module_t * module , ompi_datatype_t * dt , ptrdiff_t extent ) {
140+
141+ int flags = 0 ;
142+
143+ if (4 == extent ) {
144+ flags = MCA_BTL_ATOMIC_FLAG_32BIT ;
145+ }
146+
147+ if (OMPI_DATATYPE_FLAG_DATA_FLOAT & dt -> super .flags ) {
148+ flags |= MCA_BTL_ATOMIC_FLAG_FLOAT ;
149+ }
150+
151+ return flags ;
152+ }
153+
136154static int ompi_osc_rdma_fetch_and_op_atomic (ompi_osc_rdma_sync_t * sync , const void * origin_addr , void * result_addr , ompi_datatype_t * dt ,
137155 ptrdiff_t extent , ompi_osc_rdma_peer_t * peer , uint64_t target_address ,
138156 mca_btl_base_registration_handle_t * target_handle , ompi_op_t * op , ompi_osc_rdma_request_t * req )
@@ -151,10 +169,7 @@ static int ompi_osc_rdma_fetch_and_op_atomic (ompi_osc_rdma_sync_t *sync, const
151169
152170 btl_op = ompi_osc_rdma_op_mapping [op -> op_type ];
153171
154- flags = (4 == extent ) ? MCA_BTL_ATOMIC_FLAG_32BIT : 0 ;
155- if (OMPI_DATATYPE_FLAG_DATA_FLOAT & dt -> super .flags ) {
156- flags |= MCA_BTL_ATOMIC_FLAG_FLOAT ;
157- }
172+ flags = ompi_osc_rdma_set_btl_flags (module , dt , extent );
158173
159174 OSC_RDMA_VERBOSE (MCA_BASE_VERBOSE_TRACE , "initiating fetch-and-op using %d-bit btl atomics. origin: 0x%" PRIx64 ,
160175 (4 == extent ) ? 32 : 64 , * ((int64_t * ) origin_addr ));
@@ -239,10 +254,7 @@ static int ompi_osc_rdma_acc_single_atomic (ompi_osc_rdma_sync_t *sync, const vo
239254 origin = (8 == extent ) ? ((uint64_t * ) origin_addr )[0 ] : ((uint32_t * ) origin_addr )[0 ];
240255
241256 /* set the appropriate flags for this atomic */
242- flags = (4 == extent ) ? MCA_BTL_ATOMIC_FLAG_32BIT : 0 ;
243- if (OMPI_DATATYPE_FLAG_DATA_FLOAT & dt -> super .flags ) {
244- flags |= MCA_BTL_ATOMIC_FLAG_FLOAT ;
245- }
257+ flags = ompi_osc_rdma_set_btl_flags (module , dt , extent );
246258
247259 btl_op = ompi_osc_rdma_op_mapping [op -> op_type ];
248260
@@ -328,19 +340,21 @@ static inline int ompi_osc_rdma_gacc_contig (ompi_osc_rdma_sync_t *sync, const v
328340 ompi_datatype_t * target_datatype , ompi_op_t * op , ompi_osc_rdma_request_t * request )
329341{
330342 ompi_osc_rdma_module_t * module = sync -> module ;
331- unsigned long len = target_count * target_datatype -> super .size ;
343+ size_t target_dtype_size = target_datatype -> super .size ;
344+ unsigned long len = target_count * target_dtype_size ;
332345 char * ptr = NULL ;
333346 int ret ;
334347
335- request -> len = target_datatype -> super . size * module -> network_amo_max_count ;
348+ request -> len = target_dtype_size * module -> network_amo_max_count ;
336349
337350 OSC_RDMA_VERBOSE (MCA_BASE_VERBOSE_TRACE , "initiating accumulate on contiguous region of %lu bytes to remote address %" PRIx64
338351 ", sync %p" , len , target_address , (void * ) sync );
339352
340353 /* if the datatype is small enough (and the count is 1) then try to directly use the hardware to execute
341354 * the atomic operation. this should be safe in all cases as either 1) the user has assured us they will
342355 * never use atomics with count > 1, 2) we have the accumulate lock, or 3) we have an exclusive lock */
343- if ((target_datatype -> super .size <= 8 ) && (((unsigned long ) target_count ) <= module -> network_amo_max_count )) {
356+ if ((target_dtype_size <= 8 ) && (((unsigned long ) target_count ) <= module -> network_amo_max_count ) &&
357+ ompi_osc_base_is_atomic_size_supported (target_address , target_dtype_size )) {
344358 ret = ompi_osc_rdma_gacc_amo (module , sync , source , result , result_count , result_datatype , result_convertor ,
345359 peer , target_address , target_handle , target_count , target_datatype , op , request );
346360 if (OPAL_LIKELY (OMPI_SUCCESS == ret )) {
@@ -659,7 +673,8 @@ static inline int ompi_osc_rdma_cas_atomic (ompi_osc_rdma_sync_t *sync, const vo
659673
660674 compare = (8 == size ) ? ((int64_t * ) compare_addr )[0 ] : ((int32_t * ) compare_addr )[0 ];
661675 source = (8 == size ) ? ((int64_t * ) source_addr )[0 ] : ((int32_t * ) source_addr )[0 ];
662- flags = (4 == size ) ? MCA_BTL_ATOMIC_FLAG_32BIT : 0 ;
676+
677+ flags = ompi_osc_rdma_set_btl_flags (module , datatype , size );
663678
664679 OSC_RDMA_VERBOSE (MCA_BASE_VERBOSE_TRACE , "initiating compare-and-swap using %d-bit btl atomics. compare: 0x%"
665680 PRIx64 ", origin: 0x%" PRIx64 , (int ) size * 8 , * ((int64_t * ) compare_addr ), * ((int64_t * ) source_addr ));
@@ -830,10 +845,12 @@ int ompi_osc_rdma_compare_and_swap (const void *origin_addr, const void *compare
830845 * user has indicated that they will only use the same op (or same op and no op) for
831846 * operations on overlapping memory ranges. that indicates it is safe to go ahead and
832847 * use network atomic operations. */
833- ret = ompi_osc_rdma_cas_atomic (sync , origin_addr , compare_addr , result_addr , dt ,
834- peer , target_address , target_handle , lock_acquired );
835- if (OMPI_SUCCESS == ret ) {
836- return OMPI_SUCCESS ;
848+ if (ompi_osc_base_is_atomic_size_supported (target_address , dt -> super .size )) {
849+ ret = ompi_osc_rdma_cas_atomic (sync , origin_addr , compare_addr , result_addr , dt ,
850+ peer , target_address , target_handle , lock_acquired );
851+ if (OMPI_SUCCESS == ret ) {
852+ return OMPI_SUCCESS ;
853+ }
837854 }
838855 }
839856
0 commit comments