@@ -212,10 +212,10 @@ void merge_impl(const std::size_t offset,
212212}
213213
214214template <typename Iter, typename Compare>
215- void insertion_sort_impl (Iter first,
216- const std::size_t begin,
217- const std::size_t end,
218- Compare comp)
215+ void insertion_sort_impl (Iter && first,
216+ std::size_t begin,
217+ std::size_t end,
218+ Compare && comp)
219219{
220220 for (std::size_t i = begin + 1 ; i < end; ++i) {
221221 const auto val_i = first[i];
@@ -231,31 +231,14 @@ void insertion_sort_impl(Iter first,
231231}
232232
233233template <typename Iter, typename Compare>
234- void bubble_sort_impl (Iter first,
235- const std::size_t begin,
236- const std::size_t end,
237- Compare comp)
234+ void leaf_sort_impl (Iter && first,
235+ std::size_t begin,
236+ std::size_t end,
237+ Compare && comp)
238238{
239- if (begin < end) {
240- for (std::size_t i = begin; i < end; ++i) {
241- // Handle intermediate items
242- for (std::size_t idx = i + 1 ; idx < end; ++idx) {
243- if (comp (first[idx], first[i])) {
244- std::swap (first[i], first[idx]);
245- }
246- }
247- }
248- }
249- }
250-
251- template <typename Iter, typename Compare>
252- void leaf_sort_impl (Iter first,
253- const std::size_t begin,
254- const std::size_t end,
255- Compare comp)
256- {
257- return insertion_sort_impl<Iter, Compare>(
258- std::move (first), std::move (begin), std::move (end), std::move (comp));
239+ return insertion_sort_impl<Iter, Compare>(std::forward<Iter>(first),
240+ std::move (begin), std::move (end),
241+ std::forward<Compare>(comp));
259242}
260243
261244template <typename Iter> struct GetValueType
@@ -356,7 +339,7 @@ sort_base_step_contig_impl(sycl::queue &q,
356339 using KernelName = sort_base_step_contig_krn<inpT, outT, Comp>;
357340
358341 const std::size_t n_segments =
359- quotient_ceil<std:: size_t > (sort_nelems, conseq_nelems_sorted);
342+ quotient_ceil (sort_nelems, conseq_nelems_sorted);
360343
361344 sycl::event base_sort = q.submit ([&](sycl::handler &cgh) {
362345 cgh.depends_on (depends);
@@ -375,8 +358,7 @@ sort_base_step_contig_impl(sycl::queue &q,
375358 iter_offset + segment_id * conseq_nelems_sorted;
376359 const std::size_t end_id =
377360 iter_offset +
378- std::min<std::size_t >((segment_id + 1 ) * conseq_nelems_sorted,
379- sort_nelems);
361+ std::min ((segment_id + 1 ) * conseq_nelems_sorted, sort_nelems);
380362 for (std::size_t i = beg_id; i < end_id; ++i) {
381363 output_acc[i] = input_acc[i];
382364 }
@@ -444,8 +426,7 @@ sort_over_work_group_contig_impl(sycl::queue &q,
444426 // This assumption permits doing away with using a loop
445427 assert (nelems_wg_sorts % lws == 0 );
446428
447- const std::size_t n_segments =
448- quotient_ceil<std::size_t >(sort_nelems, nelems_wg_sorts);
429+ const std::size_t n_segments = quotient_ceil (sort_nelems, nelems_wg_sorts);
449430
450431 sycl::event base_sort_ev = q.submit ([&](sycl::handler &cgh) {
451432 cgh.depends_on (depends);
@@ -471,8 +452,8 @@ sort_over_work_group_contig_impl(sycl::queue &q,
471452 const std::size_t lid = it.get_local_linear_id ();
472453
473454 const std::size_t segment_start_idx = segment_id * nelems_wg_sorts;
474- const std::size_t segment_end_idx = std::min<std:: size_t >(
475- segment_start_idx + nelems_wg_sorts, sort_nelems);
455+ const std::size_t segment_end_idx =
456+ std::min ( segment_start_idx + nelems_wg_sorts, sort_nelems);
476457 const std::size_t wg_chunk_size =
477458 segment_end_idx - segment_start_idx;
478459
@@ -487,8 +468,7 @@ sort_over_work_group_contig_impl(sycl::queue &q,
487468 }
488469 sycl::group_barrier (it.get_group ());
489470
490- const std::size_t chunk =
491- quotient_ceil<std::size_t >(nelems_wg_sorts, lws);
471+ const std::size_t chunk = quotient_ceil (nelems_wg_sorts, lws);
492472
493473 const std::size_t chunk_start_idx = lid * chunk;
494474 const std::size_t chunk_end_idx =
@@ -620,8 +600,7 @@ merge_sorted_block_contig_impl(sycl::queue &q,
620600 used_depends = true ;
621601 }
622602
623- const std::size_t n_chunks =
624- quotient_ceil<std::size_t >(sort_nelems, chunk_size);
603+ const std::size_t n_chunks = quotient_ceil (sort_nelems, chunk_size);
625604
626605 if (needs_copy) {
627606 sycl::accessor temp_acc{temp_buf, cgh, sycl::write_only,
@@ -835,6 +814,11 @@ sycl::event stable_argsort_axis1_contig_impl(
835814 exec_q, iter_nelems, sort_nelems, res_tp, index_comp, sorted_block_size,
836815 {base_sort_ev});
837816
817+ // no need to map back if iter_nelems == 1
818+ if (iter_nelems == 1u ) {
819+ return merges_ev;
820+ }
821+
838822 using MapBackKernelName = index_map_to_rows_krn<argTy, IndexTy>;
839823 using dpctl::tensor::kernels::sort_utils_detail::map_back_impl;
840824
0 commit comments