@@ -58,47 +58,97 @@ template <typename T> struct boolean_predicate
     }
 };

-template <typename inpT, typename outT, typename PredicateT, size_t wg_dim>
+template <typename inpT,
+          typename outT,
+          typename PredicateT,
+          std::uint8_t wg_dim = 2>
 struct all_reduce_wg_contig
 {
-    outT operator()(sycl::group<wg_dim> &wg,
+    void operator()(sycl::nd_item<wg_dim> &ndit,
+                    outT *out,
+                    size_t &out_idx,
                     const inpT *start,
                     const inpT *end) const
     {
         PredicateT pred{};
-        return static_cast<outT>(sycl::joint_all_of(wg, start, end, pred));
+        auto wg = ndit.get_group();
+        outT red_val_over_wg =
+            static_cast<outT>(sycl::joint_all_of(wg, start, end, pred));
+
+        if (wg.leader()) {
+            sycl::atomic_ref<outT, sycl::memory_order::relaxed,
+                             sycl::memory_scope::device,
+                             sycl::access::address_space::global_space>
+                res_ref(out[out_idx]);
+            res_ref.fetch_and(red_val_over_wg);
+        }
     }
 };

-template <typename inpT, typename outT, typename PredicateT, size_t wg_dim>
+template <typename inpT,
+          typename outT,
+          typename PredicateT,
+          std::uint8_t wg_dim = 2>
 struct any_reduce_wg_contig
 {
-    outT operator()(sycl::group<wg_dim> &wg,
+    void operator()(sycl::nd_item<wg_dim> &ndit,
+                    outT *out,
+                    size_t &out_idx,
                     const inpT *start,
                     const inpT *end) const
     {
         PredicateT pred{};
-        return static_cast<outT>(sycl::joint_any_of(wg, start, end, pred));
+        auto wg = ndit.get_group();
+        outT red_val_over_wg =
+            static_cast<outT>(sycl::joint_any_of(wg, start, end, pred));
+
+        if (wg.leader()) {
+            sycl::atomic_ref<outT, sycl::memory_order::relaxed,
+                             sycl::memory_scope::device,
+                             sycl::access::address_space::global_space>
+                res_ref(out[out_idx]);
+            res_ref.fetch_or(red_val_over_wg);
+        }
     }
 };

-template <typename T, typename PredicateT, size_t wg_dim>
-struct all_reduce_wg_strided
+template <typename T, std::uint8_t wg_dim = 2> struct all_reduce_wg_strided
 {
-    T operator()(sycl::group<wg_dim> &wg, const T &local_val) const
+    void operator()(sycl::nd_item<wg_dim> &ndit,
+                    T *out,
+                    const size_t &out_idx,
+                    const T &local_val) const
     {
-        PredicateT pred{};
-        return static_cast<T>(sycl::all_of_group(wg, local_val, pred));
+        auto wg = ndit.get_group();
+        T red_val_over_wg = static_cast<T>(sycl::all_of_group(wg, local_val));
+
+        if (wg.leader()) {
+            sycl::atomic_ref<T, sycl::memory_order::relaxed,
+                             sycl::memory_scope::device,
+                             sycl::access::address_space::global_space>
+                res_ref(out[out_idx]);
+            res_ref.fetch_and(red_val_over_wg);
+        }
     }
 };

-template <typename T, typename PredicateT, size_t wg_dim>
-struct any_reduce_wg_strided
+template <typename T, std::uint8_t wg_dim = 2> struct any_reduce_wg_strided
 {
-    T operator()(sycl::group<wg_dim> &wg, const T &local_val) const
+    void operator()(sycl::nd_item<wg_dim> &ndit,
+                    T *out,
+                    const size_t &out_idx,
+                    const T &local_val) const
     {
-        PredicateT pred{};
-        return static_cast<T>(sycl::any_of_group(wg, local_val, pred));
+        auto wg = ndit.get_group();
+        T red_val_over_wg = static_cast<T>(sycl::any_of_group(wg, local_val));
+
+        if (wg.leader()) {
+            sycl::atomic_ref<T, sycl::memory_order::relaxed,
+                             sycl::memory_scope::device,
+                             sycl::access::address_space::global_space>
+                res_ref(out[out_idx]);
+            res_ref.fetch_or(red_val_over_wg);
+        }
     }
 };

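For context on the hunk above: the reworked contiguous functors fold a whole work-group's chunk with a group collective (sycl::joint_all_of / sycl::joint_any_of), and the group leader then merges the 0/1 result into the output with a single relaxed atomic fetch_and / fetch_or. The sketch below is not part of this patch; it is a minimal, self-contained illustration of that pattern, and the names and sizes (data, res, n, wg) are assumptions made up for the example.

```cpp
// Illustrative sketch only: work-group "all" reduction over a contiguous
// chunk, merged into a device-global result via an atomic AND.
#include <sycl/sycl.hpp>
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    sycl::queue q;
    constexpr size_t n = 1024; // total elements (assumed multiple of wg)
    constexpr size_t wg = 128; // work-group size (assumption)

    std::vector<std::int32_t> host(n, 1); // every element non-zero -> "all" == 1
    std::int32_t *data = sycl::malloc_device<std::int32_t>(n, q);
    std::int32_t *res = sycl::malloc_device<std::int32_t>(1, q);
    q.memcpy(data, host.data(), n * sizeof(std::int32_t)).wait();
    q.fill(res, std::int32_t{1}, 1).wait(); // identity of logical AND

    q.parallel_for(
         sycl::nd_range<1>{sycl::range<1>{n}, sycl::range<1>{wg}},
         [=](sycl::nd_item<1> it) {
             auto g = it.get_group();
             const std::int32_t *start = data + g.get_group_id(0) * wg;
             const std::int32_t *end = start + wg;
             // group-wide reduction over this group's contiguous chunk
             std::int32_t wg_val = static_cast<std::int32_t>(sycl::joint_all_of(
                 g, start, end, [](std::int32_t v) { return v != 0; }));
             if (g.leader()) {
                 sycl::atomic_ref<std::int32_t, sycl::memory_order::relaxed,
                                  sycl::memory_scope::device,
                                  sycl::access::address_space::global_space>
                     res_ref(res[0]);
                 res_ref.fetch_and(wg_val); // merge this group's result
             }
         })
        .wait();

    std::int32_t out;
    q.memcpy(&out, res, sizeof(out)).wait();
    std::cout << "all: " << out << "\n"; // expected: 1
    sycl::free(data, q);
    sycl::free(res, q);
}
```

Because each per-group value is a 0/1 integer, bitwise fetch_and / fetch_or realizes logical AND / OR directly, which is what allows the patch to drop the compare-exchange loops removed further down.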
@@ -137,8 +187,10 @@ struct SequentialBooleanReduction
     {

         auto inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]);
-        const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset();
-        const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset();
+        const size_t &inp_iter_offset =
+            inp_out_iter_offsets_.get_first_offset();
+        const size_t &out_iter_offset =
+            inp_out_iter_offsets_.get_second_offset();

         outT red_val(identity_);
         for (size_t m = 0; m < reduction_max_gid_; ++m) {
@@ -156,26 +208,24 @@ struct SequentialBooleanReduction
     }
 };

-template <typename argT, typename outT, typename ReductionOp, typename GroupOp>
+template <typename argT, typename outT, typename GroupOp>
 struct ContigBooleanReduction
 {
 private:
     const argT *inp_ = nullptr;
     outT *out_ = nullptr;
-    ReductionOp reduction_op_;
     GroupOp group_op_;
     size_t reduction_max_gid_ = 0;
     size_t reductions_per_wi = 16;

 public:
     ContigBooleanReduction(const argT *inp,
                            outT *res,
-                           ReductionOp reduction_op,
                            GroupOp group_op,
                            size_t reduction_size,
                            size_t reduction_size_per_wi)
-        : inp_(inp), out_(res), reduction_op_(reduction_op),
-          group_op_(group_op), reduction_max_gid_(reduction_size),
+        : inp_(inp), out_(res), group_op_(group_op),
+          reduction_max_gid_(reduction_size),
           reductions_per_wi(reduction_size_per_wi)
     {
     }
@@ -185,30 +235,15 @@ struct ContigBooleanReduction

         size_t reduction_id = it.get_group(0);
         size_t reduction_batch_id = it.get_group(1);
-
-        auto work_group = it.get_group();
         size_t wg_size = it.get_local_range(1);

         size_t base = reduction_id * reduction_max_gid_;
         size_t start = base + reduction_batch_id * wg_size * reductions_per_wi;
         size_t end = std::min((start + (reductions_per_wi * wg_size)),
                               base + reduction_max_gid_);
-
-        // reduction to the work group level is performed
-        // inside group_op
-        outT red_val_over_wg = group_op_(work_group, inp_ + start, inp_ + end);
-
-        if (work_group.leader()) {
-            sycl::atomic_ref<outT, sycl::memory_order::relaxed,
-                             sycl::memory_scope::device,
-                             sycl::access::address_space::global_space>
-                res_ref(out_[reduction_id]);
-            outT read_val = res_ref.load();
-            outT new_val{};
-            do {
-                new_val = reduction_op_(read_val, red_val_over_wg);
-            } while (!res_ref.compare_exchange_strong(read_val, new_val));
-        }
+        // reduction and atomic operations are performed
+        // in group_op_
+        group_op_(it, out_, reduction_id, inp_ + start, inp_ + end);
     }
 };

@@ -223,7 +258,7 @@ typedef sycl::event (*boolean_reduction_contig_impl_fn_ptr)(
     py::ssize_t,
     const std::vector<sycl::event> &);

-template <typename T1, typename T2, typename T3, typename T4>
+template <typename T1, typename T2, typename T3>
 class boolean_reduction_contig_krn;

 template <typename T1, typename T2, typename T3, typename T4, typename T5>
@@ -298,7 +333,7 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
         red_ev = exec_q.submit([&](sycl::handler &cgh) {
             cgh.depends_on(init_ev);

-            constexpr size_t group_dim = 2;
+            constexpr std::uint8_t group_dim = 2;

             constexpr size_t preferred_reductions_per_wi = 4;
             size_t reductions_per_wi =
@@ -314,11 +349,11 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
                 sycl::range<group_dim>{iter_nelems, reduction_groups * wg};
             auto lws = sycl::range<group_dim>{1, wg};

-            cgh.parallel_for<class boolean_reduction_contig_krn<
-                argTy, resTy, RedOpT, GroupOpT>>(
+            cgh.parallel_for<
+                class boolean_reduction_contig_krn<argTy, resTy, GroupOpT>>(
                 sycl::nd_range<group_dim>(gws, lws),
-                ContigBooleanReduction<argTy, resTy, RedOpT, GroupOpT>(
-                    arg_tp, res_tp, RedOpT(), GroupOpT(), reduction_nelems,
+                ContigBooleanReduction<argTy, resTy, GroupOpT>(
+                    arg_tp, res_tp, GroupOpT(), reduction_nelems,
                     reductions_per_wi));
         });
     }
@@ -332,7 +367,7 @@ template <typename fnT, typename srcTy> struct AllContigFactory
         using resTy = std::int32_t;
         using RedOpT = sycl::logical_and<resTy>;
         using GroupOpT =
-            all_reduce_wg_contig<srcTy, resTy, boolean_predicate<srcTy>, 2>;
+            all_reduce_wg_contig<srcTy, resTy, boolean_predicate<srcTy>>;

         return dpctl::tensor::kernels::boolean_reduction_contig_impl<
             srcTy, resTy, RedOpT, GroupOpT>;
@@ -346,7 +381,7 @@ template <typename fnT, typename srcTy> struct AnyContigFactory
         using resTy = std::int32_t;
         using RedOpT = sycl::logical_or<resTy>;
         using GroupOpT =
-            any_reduce_wg_contig<srcTy, resTy, boolean_predicate<srcTy>, 2>;
+            any_reduce_wg_contig<srcTy, resTy, boolean_predicate<srcTy>>;

         return dpctl::tensor::kernels::boolean_reduction_contig_impl<
             srcTy, resTy, RedOpT, GroupOpT>;
@@ -400,8 +435,10 @@ struct StridedBooleanReduction
         size_t wg_size = it.get_local_range(1);

         auto inp_out_iter_offsets_ = inp_out_iter_indexer_(reduction_id);
-        const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset();
-        const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset();
+        const size_t &inp_iter_offset =
+            inp_out_iter_offsets_.get_first_offset();
+        const size_t &out_iter_offset =
+            inp_out_iter_offsets_.get_second_offset();

         outT local_red_val(identity_);
         size_t arg_reduce_gid0 =
@@ -416,28 +453,15 @@ struct StridedBooleanReduction

                 // must convert to boolean first to handle nans
                 using dpctl::tensor::type_utils::convert_impl;
-                outT val = convert_impl<bool, argT>(inp_[inp_offset]);
+                bool val = convert_impl<bool, argT>(inp_[inp_offset]);

-                local_red_val = reduction_op_(local_red_val, val);
+                local_red_val =
+                    reduction_op_(local_red_val, static_cast<outT>(val));
             }
         }
-
-        // reduction to the work group level is performed
-        // inside group_op
-        auto work_group = it.get_group();
-        outT red_val_over_wg = group_op_(work_group, local_red_val);
-
-        if (work_group.leader()) {
-            sycl::atomic_ref<outT, sycl::memory_order::relaxed,
-                             sycl::memory_scope::device,
-                             sycl::access::address_space::global_space>
-                res_ref(out_[out_iter_offset]);
-            outT read_val = res_ref.load();
-            outT new_val{};
-            do {
-                new_val = reduction_op_(read_val, red_val_over_wg);
-            } while (!res_ref.compare_exchange_strong(read_val, new_val));
-        }
+        // reduction and atomic operations are performed
+        // in group_op_
+        group_op_(it, out_, out_iter_offset, local_red_val);
     }
 };

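The strided path just above follows the same shape, except each work item first folds its strided elements into one local value and the group then reduces those scalars with sycl::all_of_group / sycl::any_of_group before the leader's atomic update. Below is a standalone sketch of that per-item variant; it is illustrative only (the buffer names, sizes, and 1-D launch are assumptions, and the strided indexing machinery is omitted).

```cpp
// Illustrative sketch only: per-work-item values reduced with all_of_group,
// then folded into a device-global result by the group leader.
#include <sycl/sycl.hpp>
#include <cstdint>
#include <iostream>

int main()
{
    sycl::queue q;
    constexpr size_t n = 256; // total elements (assumption)
    constexpr size_t wg = 64; // work-group size (assumption)

    std::int32_t *data = sycl::malloc_shared<std::int32_t>(n, q);
    std::int32_t *res = sycl::malloc_shared<std::int32_t>(1, q);
    for (size_t i = 0; i < n; ++i)
        data[i] = 1; // every element truthy
    res[0] = 1;      // identity of logical AND

    q.parallel_for(
         sycl::nd_range<1>{sycl::range<1>{n}, sycl::range<1>{wg}},
         [=](sycl::nd_item<1> it) {
             // each work item contributes one already-reduced local value
             bool local_val = data[it.get_global_id(0)] != 0;
             auto g = it.get_group();
             std::int32_t wg_val =
                 static_cast<std::int32_t>(sycl::all_of_group(g, local_val));
             if (g.leader()) {
                 sycl::atomic_ref<std::int32_t, sycl::memory_order::relaxed,
                                  sycl::memory_scope::device,
                                  sycl::access::address_space::global_space>
                     res_ref(res[0]);
                 res_ref.fetch_and(wg_val);
             }
         })
        .wait();

    std::cout << "all: " << res[0] << "\n"; // expected: 1
    sycl::free(data, q);
    sycl::free(res, q);
}
```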
@@ -541,7 +565,7 @@ boolean_reduction_strided_impl(sycl::queue exec_q,
         red_ev = exec_q.submit([&](sycl::handler &cgh) {
             cgh.depends_on(res_init_ev);

-            constexpr size_t group_dim = 2;
+            constexpr std::uint8_t group_dim = 2;

             using InputOutputIterIndexerT =
                 dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
@@ -589,8 +613,7 @@ template <typename fnT, typename srcTy> struct AllStridedFactory
     {
         using resTy = std::int32_t;
         using RedOpT = sycl::logical_and<resTy>;
-        using GroupOpT =
-            all_reduce_wg_strided<resTy, boolean_predicate<srcTy>, 2>;
+        using GroupOpT = all_reduce_wg_strided<resTy>;

         return dpctl::tensor::kernels::boolean_reduction_strided_impl<
             srcTy, resTy, RedOpT, GroupOpT>;
@@ -603,8 +626,7 @@ template <typename fnT, typename srcTy> struct AnyStridedFactory
     {
         using resTy = std::int32_t;
         using RedOpT = sycl::logical_or<resTy>;
-        using GroupOpT =
-            any_reduce_wg_strided<resTy, boolean_predicate<srcTy>, 2>;
+        using GroupOpT = any_reduce_wg_strided<resTy>;

         return dpctl::tensor::kernels::boolean_reduction_strided_impl<
             srcTy, resTy, RedOpT, GroupOpT>;