1818#include "oshmem/proc/proc.h"
1919#include "atomic_ucx.h"
2020
#if HAVE_DECL_UCP_ATOMIC_OP_NBX
/*
 * Read-only request parameters handed to ucp_atomic_op_nbx(), one entry per
 * supported operand size. Entry 0 describes a 4-byte contiguous datatype and
 * entry 1 an 8-byte one, so "size >> 3" maps an operand size to its entry
 * (4 >> 3 == 0, 8 >> 3 == 1). Any other size would index the wrong entry or
 * run past the end of the table, so callers must only use size 4 or 8.
 *
 * The table is const because it is shared by every caller and is never
 * modified after initialization.
 */
static const ucp_request_param_t mca_spml_ucp_request_params[] = {
    {.op_attr_mask = UCP_OP_ATTR_FIELD_DATATYPE, .datatype = ucp_dt_make_contig(4)},
    {.op_attr_mask = UCP_OP_ATTR_FIELD_DATATYPE, .datatype = ucp_dt_make_contig(8)}
};
#endif
31+
2132/*
2233 * Initial query function that is invoked during initialization, allowing
2334 * this module to indicate what level of thread support it provides.
@@ -38,20 +49,37 @@ int mca_atomic_ucx_op(shmem_ctx_t ctx,
3849 uint64_t value ,
3950 size_t size ,
4051 int pe ,
52+ #if HAVE_DECL_UCP_ATOMIC_OP_NBX
53+ ucp_atomic_op_t op )
54+ #else
4155 ucp_atomic_post_op_t op )
56+ #endif
4257{
4358 ucs_status_t status ;
4459 spml_ucx_mkey_t * ucx_mkey ;
4560 uint64_t rva ;
4661 mca_spml_ucx_ctx_t * ucx_ctx = (mca_spml_ucx_ctx_t * )ctx ;
62+ #if HAVE_DECL_UCP_ATOMIC_OP_NBX
63+ ucs_status_ptr_t status_ptr ;
64+ #endif
4765
4866 assert ((8 == size ) || (4 == size ));
4967
5068 ucx_mkey = mca_spml_ucx_get_mkey (ctx , pe , target , (void * )& rva , mca_spml_self );
69+
70+ #if HAVE_DECL_UCP_ATOMIC_OP_NBX
71+ status_ptr = ucp_atomic_op_nbx (ucx_ctx -> ucp_peers [pe ].ucp_conn ,
72+ op , & value , 1 , rva , ucx_mkey -> rkey ,
73+ & mca_spml_ucp_request_params [size >> 3 ]);
74+ if (OPAL_LIKELY (!UCS_PTR_IS_ERR (status_ptr ))) {
75+ mca_spml_ucx_remote_op_posted (ucx_ctx , pe );
76+ }
77+ status = UCS_PTR_STATUS (status_ptr );
78+ #else
5179 status = ucp_atomic_post (ucx_ctx -> ucp_peers [pe ].ucp_conn ,
5280 op , value , size , rva ,
5381 ucx_mkey -> rkey );
54-
82+ #endif
5583 if (OPAL_LIKELY (UCS_OK == status )) {
5684 mca_spml_ucx_remote_op_posted (ucx_ctx , pe );
5785 }
@@ -66,22 +94,41 @@ int mca_atomic_ucx_fop(shmem_ctx_t ctx,
6694 uint64_t value ,
6795 size_t size ,
6896 int pe ,
97+ #if HAVE_DECL_UCP_ATOMIC_OP_NBX
98+ ucp_atomic_op_t op )
99+ #else
69100 ucp_atomic_fetch_op_t op )
101+ #endif
70102{
71103 ucs_status_ptr_t status_ptr ;
72104 spml_ucx_mkey_t * ucx_mkey ;
73105 uint64_t rva ;
74106 mca_spml_ucx_ctx_t * ucx_ctx = (mca_spml_ucx_ctx_t * )ctx ;
107+ #if HAVE_DECL_UCP_ATOMIC_OP_NBX
108+ ucp_request_param_t param = {
109+ .op_attr_mask = UCP_OP_ATTR_FIELD_DATATYPE |
110+ UCP_OP_ATTR_FIELD_REPLY_BUFFER ,
111+ .datatype = ucp_dt_make_contig (size ),
112+ .reply_buffer = prev
113+ };
114+ #endif
75115
76116 assert ((8 == size ) || (4 == size ));
77117
78118 ucx_mkey = mca_spml_ucx_get_mkey (ctx , pe , target , (void * )& rva , mca_spml_self );
119+ #if HAVE_DECL_UCP_ATOMIC_OP_NBX
120+ status_ptr = ucp_atomic_op_nbx (ucx_ctx -> ucp_peers [pe ].ucp_conn , op , & value , 1 ,
121+ rva , ucx_mkey -> rkey , & param );
122+ return opal_common_ucx_wait_request (status_ptr , ucx_ctx -> ucp_worker [0 ],
123+ "ucp_atomic_op_nbx" );
124+ #else
79125 status_ptr = ucp_atomic_fetch_nb (ucx_ctx -> ucp_peers [pe ].ucp_conn ,
80126 op , value , prev , size ,
81127 rva , ucx_mkey -> rkey ,
82128 opal_common_ucx_empty_complete_cb );
83129 return opal_common_ucx_wait_request (status_ptr , ucx_ctx -> ucp_worker [0 ],
84130 "ucp_atomic_fetch_nb" );
131+ #endif
85132}
86133
87134static int mca_atomic_ucx_add (shmem_ctx_t ctx ,
@@ -90,7 +137,11 @@ static int mca_atomic_ucx_add(shmem_ctx_t ctx,
90137 size_t size ,
91138 int pe )
92139{
140+ #if HAVE_DECL_UCP_ATOMIC_OP_NBX
141+ return mca_atomic_ucx_op (ctx , target , value , size , pe , UCP_ATOMIC_OP_ADD );
142+ #else
93143 return mca_atomic_ucx_op (ctx , target , value , size , pe , UCP_ATOMIC_POST_OP_ADD );
144+ #endif
94145}
95146
96147static int mca_atomic_ucx_and (shmem_ctx_t ctx ,
@@ -99,7 +150,9 @@ static int mca_atomic_ucx_and(shmem_ctx_t ctx,
99150 size_t size ,
100151 int pe )
101152{
102- #if HAVE_DECL_UCP_ATOMIC_POST_OP_AND
153+ #if HAVE_DECL_UCP_ATOMIC_OP_NBX
154+ return mca_atomic_ucx_op (ctx , target , value , size , pe , UCP_ATOMIC_OP_AND );
155+ #elif HAVE_DECL_UCP_ATOMIC_POST_OP_AND
103156 return mca_atomic_ucx_op (ctx , target , value , size , pe , UCP_ATOMIC_POST_OP_AND );
104157#else
105158 return OSHMEM_ERR_NOT_IMPLEMENTED ;
@@ -112,7 +165,9 @@ static int mca_atomic_ucx_or(shmem_ctx_t ctx,
112165 size_t size ,
113166 int pe )
114167{
115- #if HAVE_DECL_UCP_ATOMIC_POST_OP_OR
168+ #if HAVE_DECL_UCP_ATOMIC_OP_NBX
169+ return mca_atomic_ucx_op (ctx , target , value , size , pe , UCP_ATOMIC_OP_OR );
170+ #elif HAVE_DECL_UCP_ATOMIC_POST_OP_OR
116171 return mca_atomic_ucx_op (ctx , target , value , size , pe , UCP_ATOMIC_POST_OP_OR );
117172#else
118173 return OSHMEM_ERR_NOT_IMPLEMENTED ;
@@ -125,7 +180,9 @@ static int mca_atomic_ucx_xor(shmem_ctx_t ctx,
125180 size_t size ,
126181 int pe )
127182{
128- #if HAVE_DECL_UCP_ATOMIC_POST_OP_XOR
183+ #if HAVE_DECL_UCP_ATOMIC_OP_NBX
184+ return mca_atomic_ucx_op (ctx , target , value , size , pe , UCP_ATOMIC_OP_XOR );
185+ #elif HAVE_DECL_UCP_ATOMIC_POST_OP_XOR
129186 return mca_atomic_ucx_op (ctx , target , value , size , pe , UCP_ATOMIC_POST_OP_XOR );
130187#else
131188 return OSHMEM_ERR_NOT_IMPLEMENTED ;
@@ -139,7 +196,11 @@ static int mca_atomic_ucx_fadd(shmem_ctx_t ctx,
139196 size_t size ,
140197 int pe )
141198{
199+ #if HAVE_DECL_UCP_ATOMIC_OP_NBX
200+ return mca_atomic_ucx_fop (ctx , target , prev , value , size , pe , UCP_ATOMIC_OP_ADD );
201+ #else
142202 return mca_atomic_ucx_fop (ctx , target , prev , value , size , pe , UCP_ATOMIC_FETCH_OP_FADD );
203+ #endif
143204}
144205
145206static int mca_atomic_ucx_fand (shmem_ctx_t ctx ,
@@ -149,7 +210,9 @@ static int mca_atomic_ucx_fand(shmem_ctx_t ctx,
149210 size_t size ,
150211 int pe )
151212{
152- #if HAVE_DECL_UCP_ATOMIC_FETCH_OP_FAND
213+ #if HAVE_DECL_UCP_ATOMIC_OP_NBX
214+ return mca_atomic_ucx_fop (ctx , target , prev , value , size , pe , UCP_ATOMIC_OP_AND );
215+ #elif HAVE_DECL_UCP_ATOMIC_FETCH_OP_FAND
153216 return mca_atomic_ucx_fop (ctx , target , prev , value , size , pe , UCP_ATOMIC_FETCH_OP_FAND );
154217#else
155218 return OSHMEM_ERR_NOT_IMPLEMENTED ;
@@ -163,7 +226,9 @@ static int mca_atomic_ucx_for(shmem_ctx_t ctx,
163226 size_t size ,
164227 int pe )
165228{
166- #if HAVE_DECL_UCP_ATOMIC_FETCH_OP_FOR
229+ #if HAVE_DECL_UCP_ATOMIC_OP_NBX
230+ return mca_atomic_ucx_fop (ctx , target , prev , value , size , pe , UCP_ATOMIC_OP_OR );
231+ #elif HAVE_DECL_UCP_ATOMIC_FETCH_OP_FOR
167232 return mca_atomic_ucx_fop (ctx , target , prev , value , size , pe , UCP_ATOMIC_FETCH_OP_FOR );
168233#else
169234 return OSHMEM_ERR_NOT_IMPLEMENTED ;
@@ -177,7 +242,9 @@ static int mca_atomic_ucx_fxor(shmem_ctx_t ctx,
177242 size_t size ,
178243 int pe )
179244{
180- #if HAVE_DECL_UCP_ATOMIC_FETCH_OP_FXOR
245+ #if HAVE_DECL_UCP_ATOMIC_OP_NBX
246+ return mca_atomic_ucx_fop (ctx , target , prev , value , size , pe , UCP_ATOMIC_OP_XOR );
247+ #elif HAVE_DECL_UCP_ATOMIC_FETCH_OP_FXOR
181248 return mca_atomic_ucx_fop (ctx , target , prev , value , size , pe , UCP_ATOMIC_FETCH_OP_FXOR );
182249#else
183250 return OSHMEM_ERR_NOT_IMPLEMENTED ;
@@ -191,7 +258,11 @@ static int mca_atomic_ucx_swap(shmem_ctx_t ctx,
191258 size_t size ,
192259 int pe )
193260{
261+ #if HAVE_DECL_UCP_ATOMIC_OP_NBX
262+ return mca_atomic_ucx_fop (ctx , target , prev , value , size , pe , UCP_ATOMIC_OP_SWAP );
263+ #else
194264 return mca_atomic_ucx_fop (ctx , target , prev , value , size , pe , UCP_ATOMIC_FETCH_OP_SWAP );
265+ #endif
195266}
196267
197268
0 commit comments