@@ -184,7 +184,7 @@ mca_coll_han_allreduce_intra(const void *sbuf,
184184 mca_coll_task_t * t_next_seg = OBJ_NEW (mca_coll_task_t );
185185 /* Setup up t_next_seg task arguments */
186186 t -> cur_task = t_next_seg ;
187- t -> sbuf = (char * ) t -> sbuf + extent * t -> seg_count ;
187+ t -> sbuf = (t -> sbuf == MPI_IN_PLACE ) ? MPI_IN_PLACE : ( char * ) t -> sbuf + extent * t -> seg_count ;
188188 t -> rbuf = (char * ) t -> rbuf + extent * t -> seg_count ;
189189 t -> cur_seg = t -> cur_seg + 1 ;
190190 /* Init t_next_seg task */
@@ -262,11 +262,26 @@ int mca_coll_han_allreduce_t1_task(void *task_args)
262262 if (t -> cur_seg == t -> num_segments - 2 && t -> last_seg_count != t -> seg_count ) {
263263 tmp_count = t -> last_seg_count ;
264264 }
265- t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + extent * t -> seg_count ,
266- (char * ) t -> rbuf + extent * t -> seg_count , tmp_count ,
267- t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
268- t -> low_comm -> c_coll -> coll_reduce_module );
269265
266+ if (t -> sbuf == MPI_IN_PLACE ) {
267+ if (!t -> noop ) {
268+ t -> low_comm -> c_coll -> coll_reduce (MPI_IN_PLACE ,
269+ (char * ) t -> rbuf + extent * t -> seg_count , tmp_count ,
270+ t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
271+ t -> low_comm -> c_coll -> coll_reduce_module );
272+ } else {
273+ t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> rbuf + extent * t -> seg_count ,
274+ NULL , tmp_count ,
275+ t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
276+ t -> low_comm -> c_coll -> coll_reduce_module );
277+
278+ }
279+ } else {
280+ t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + extent * t -> seg_count ,
281+ (char * ) t -> rbuf + extent * t -> seg_count , tmp_count ,
282+ t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
283+ t -> low_comm -> c_coll -> coll_reduce_module );
284+ }
270285 }
271286 if (!t -> noop ) {
272287 ompi_request_wait (& ireduce_req , MPI_STATUS_IGNORE );
@@ -321,10 +336,26 @@ int mca_coll_han_allreduce_t2_task(void *task_args)
321336 if (t -> cur_seg == t -> num_segments - 3 && t -> last_seg_count != t -> seg_count ) {
322337 tmp_count = t -> last_seg_count ;
323338 }
324- t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + 2 * extent * t -> seg_count ,
325- (char * ) t -> rbuf + 2 * extent * t -> seg_count , tmp_count ,
326- t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
327- t -> low_comm -> c_coll -> coll_reduce_module );
339+
340+ if (t -> sbuf == MPI_IN_PLACE ) {
341+ if (!t -> noop ) {
342+ t -> low_comm -> c_coll -> coll_reduce (MPI_IN_PLACE ,
343+ (char * ) t -> rbuf + 2 * extent * t -> seg_count , tmp_count ,
344+ t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
345+ t -> low_comm -> c_coll -> coll_reduce_module );
346+ } else {
347+ t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> rbuf + 2 * extent * t -> seg_count ,
348+ NULL , tmp_count ,
349+ t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
350+ t -> low_comm -> c_coll -> coll_reduce_module );
351+
352+ }
353+ } else {
354+ t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + 2 * extent * t -> seg_count ,
355+ (char * ) t -> rbuf + 2 * extent * t -> seg_count , tmp_count ,
356+ t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
357+ t -> low_comm -> c_coll -> coll_reduce_module );
358+ }
328359 }
329360 if (!t -> noop && req_count > 0 ) {
330361 ompi_request_wait_all (req_count , reqs , MPI_STATUSES_IGNORE );
@@ -385,10 +416,25 @@ int mca_coll_han_allreduce_t3_task(void *task_args)
385416 if (t -> cur_seg == t -> num_segments - 4 && t -> last_seg_count != t -> seg_count ) {
386417 tmp_count = t -> last_seg_count ;
387418 }
388- t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + 3 * extent * t -> seg_count ,
389- (char * ) t -> rbuf + 3 * extent * t -> seg_count , tmp_count ,
390- t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
391- t -> low_comm -> c_coll -> coll_reduce_module );
419+
420+ if (t -> sbuf == MPI_IN_PLACE ) {
421+ if (!t -> noop ) {
422+ t -> low_comm -> c_coll -> coll_reduce (MPI_IN_PLACE ,
423+ (char * ) t -> rbuf + 3 * extent * t -> seg_count , tmp_count ,
424+ t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
425+ t -> low_comm -> c_coll -> coll_reduce_module );
426+ } else {
427+ t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> rbuf + 3 * extent * t -> seg_count ,
428+ NULL , tmp_count ,
429+ t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
430+ t -> low_comm -> c_coll -> coll_reduce_module );
431+ }
432+ } else {
433+ t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + 3 * extent * t -> seg_count ,
434+ (char * ) t -> rbuf + 3 * extent * t -> seg_count , tmp_count ,
435+ t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
436+ t -> low_comm -> c_coll -> coll_reduce_module );
437+ }
392438 }
393439 /* lb of cur_seg */
394440 if (t -> cur_seg == t -> num_segments - 1 && t -> last_seg_count != t -> seg_count ) {
0 commit comments