66 * Copyright (c) 2014-2015 NVIDIA Corporation. All rights reserved.
77 * Copyright (c) 2022 Amazon.com, Inc. or its affiliates. All Rights reserved.
88 * Copyright (c) 2024 Triad National Security, LLC. All rights reserved.
9+ * Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights reserved.
910 * $COPYRIGHT$
1011 *
1112 * Additional copyrights may follow
@@ -39,12 +40,13 @@ mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count,
3940 int rank = ompi_comm_rank (comm );
4041 ptrdiff_t gap ;
4142 char * rbuf1 = NULL , * sbuf1 = NULL , * rbuf2 = NULL ;
43+ int rbuf_dev , sbuf_dev ;
4244 size_t bufsize ;
4345 int rc ;
4446
4547 bufsize = opal_datatype_span (& dtype -> super , count , & gap );
4648
47- rc = mca_coll_accelerator_check_buf ((void * )sbuf );
49+ rc = mca_coll_accelerator_check_buf ((void * )sbuf , & sbuf_dev );
4850 if (rc < 0 ) {
4951 return rc ;
5052 }
@@ -53,11 +55,12 @@ mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count,
5355 if (NULL == sbuf1 ) {
5456 return OMPI_ERR_OUT_OF_RESOURCE ;
5557 }
56- mca_coll_accelerator_memcpy (sbuf1 , sbuf , bufsize );
58+ mca_coll_accelerator_memcpy (sbuf1 , MCA_ACCELERATOR_NO_DEVICE_ID , sbuf , sbuf_dev , bufsize ,
59+ MCA_ACCELERATOR_TRANSFER_DTOH );
5760 sbuf = sbuf1 - gap ;
5861 }
5962
60- rc = mca_coll_accelerator_check_buf (rbuf );
63+ rc = mca_coll_accelerator_check_buf (rbuf , & rbuf_dev );
6164 if (rc < 0 ) {
6265 return rc ;
6366 }
@@ -67,7 +70,8 @@ mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count,
6770 if (NULL != sbuf1 ) free (sbuf1 );
6871 return OMPI_ERR_OUT_OF_RESOURCE ;
6972 }
70- mca_coll_accelerator_memcpy (rbuf1 , rbuf , bufsize );
73+ mca_coll_accelerator_memcpy (rbuf1 , MCA_ACCELERATOR_NO_DEVICE_ID , rbuf , rbuf_dev , bufsize ,
74+ MCA_ACCELERATOR_TRANSFER_DTOH );
7175 rbuf2 = rbuf ; /* save away original buffer */
7276 rbuf = rbuf1 - gap ;
7377 }
@@ -80,7 +84,8 @@ mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count,
8084 }
8185 if (NULL != rbuf1 ) {
8286 rbuf = rbuf2 ;
83- mca_coll_accelerator_memcpy (rbuf , rbuf1 , bufsize );
87+ mca_coll_accelerator_memcpy (rbuf , rbuf_dev , rbuf1 , MCA_ACCELERATOR_NO_DEVICE_ID , bufsize ,
88+ MCA_ACCELERATOR_TRANSFER_HTOD );
8489 free (rbuf1 );
8590 }
8691 return rc ;
@@ -94,12 +99,13 @@ mca_coll_accelerator_reduce_local(const void *sbuf, void *rbuf, size_t count,
9499{
95100 ptrdiff_t gap ;
96101 char * rbuf1 = NULL , * sbuf1 = NULL , * rbuf2 = NULL ;
102+ int sbuf_dev , rbuf_dev ;
97103 size_t bufsize ;
98104 int rc ;
99105
100106 bufsize = opal_datatype_span (& dtype -> super , count , & gap );
101107
102- rc = mca_coll_accelerator_check_buf ((void * )sbuf );
108+ rc = mca_coll_accelerator_check_buf ((void * )sbuf , & sbuf_dev );
103109 if (rc < 0 ) {
104110 return rc ;
105111 }
@@ -109,11 +115,12 @@ mca_coll_accelerator_reduce_local(const void *sbuf, void *rbuf, size_t count,
109115 if (NULL == sbuf1 ) {
110116 return OMPI_ERR_OUT_OF_RESOURCE ;
111117 }
112- mca_coll_accelerator_memcpy (sbuf1 , sbuf , bufsize );
118+ mca_coll_accelerator_memcpy (sbuf1 , MCA_ACCELERATOR_NO_DEVICE_ID , sbuf , sbuf_dev , bufsize ,
119+ MCA_ACCELERATOR_TRANSFER_DTOH );
113120 sbuf = sbuf1 - gap ;
114121 }
115122
116- rc = mca_coll_accelerator_check_buf (rbuf );
123+ rc = mca_coll_accelerator_check_buf (rbuf , & rbuf_dev );
117124 if (rc < 0 ) {
118125 return rc ;
119126 }
@@ -124,7 +131,8 @@ mca_coll_accelerator_reduce_local(const void *sbuf, void *rbuf, size_t count,
124131 if (NULL != sbuf1 ) free (sbuf1 );
125132 return OMPI_ERR_OUT_OF_RESOURCE ;
126133 }
127- mca_coll_accelerator_memcpy (rbuf1 , rbuf , bufsize );
134+ mca_coll_accelerator_memcpy (rbuf1 , MCA_ACCELERATOR_NO_DEVICE_ID , rbuf , rbuf_dev , bufsize ,
135+ MCA_ACCELERATOR_TRANSFER_DTOH );
128136 rbuf2 = rbuf ; /* save away original buffer */
129137 rbuf = rbuf1 - gap ;
130138 }
@@ -137,7 +145,8 @@ mca_coll_accelerator_reduce_local(const void *sbuf, void *rbuf, size_t count,
137145 }
138146 if (NULL != rbuf1 ) {
139147 rbuf = rbuf2 ;
140- mca_coll_accelerator_memcpy (rbuf , rbuf1 , bufsize );
148+ mca_coll_accelerator_memcpy (rbuf , rbuf_dev , rbuf1 , MCA_ACCELERATOR_NO_DEVICE_ID , bufsize ,
149+ MCA_ACCELERATOR_TRANSFER_HTOD );
141150 free (rbuf1 );
142151 }
143152 return rc ;
0 commit comments