@@ -145,7 +145,7 @@ template <typename T> class stack_strided_t
145145
146146namespace su_ns = dpctl::tensor::sycl_utils;
147147
148- using nwiT = std::uint16_t ;
148+ using nwiT = std::uint32_t ;
149149
150150template <typename inputT,
151151 typename outputT,
@@ -156,7 +156,18 @@ template <typename inputT,
156156 typename TransformerT,
157157 typename ScanOpT,
158158 bool include_initial>
159- class inclusive_scan_iter_local_scan_krn ;
159+ class inclusive_scan_iter_local_scan_blocked_krn ;
160+
161+ template <typename inputT,
162+ typename outputT,
163+ nwiT n_wi,
164+ typename IterIndexerT,
165+ typename InpIndexerT,
166+ typename OutIndexerT,
167+ typename TransformerT,
168+ typename ScanOpT,
169+ bool include_initial>
170+ class inclusive_scan_iter_local_scan_striped_krn ;
160171
161172template <typename inputT,
162173 typename outputT,
@@ -177,22 +188,22 @@ template <typename inputT,
177188 typename ScanOpT,
178189 bool include_initial = false >
179190sycl::event
180- inclusive_scan_base_step (sycl::queue &exec_q,
181- const std::size_t wg_size,
182- const std::size_t iter_nelems,
183- const std::size_t acc_nelems,
184- const inputT *input,
185- outputT *output,
186- const std::size_t s0,
187- const std::size_t s1,
188- const IterIndexerT &iter_indexer,
189- const InpIndexerT &inp_indexer,
190- const OutIndexerT &out_indexer,
191- TransformerT transformer,
192- const ScanOpT &scan_op,
193- outputT identity,
194- std::size_t &acc_groups,
195- const std::vector<sycl::event> &depends = {})
191+ inclusive_scan_base_step_blocked (sycl::queue &exec_q,
192+ const std::uint32_t wg_size,
193+ const std::size_t iter_nelems,
194+ const std::size_t acc_nelems,
195+ const inputT *input,
196+ outputT *output,
197+ const std::size_t s0,
198+ const std::size_t s1,
199+ const IterIndexerT &iter_indexer,
200+ const InpIndexerT &inp_indexer,
201+ const OutIndexerT &out_indexer,
202+ TransformerT transformer,
203+ const ScanOpT &scan_op,
204+ outputT identity,
205+ std::size_t &acc_groups,
206+ const std::vector<sycl::event> &depends = {})
196207{
197208 acc_groups = ceiling_quotient<std::size_t >(acc_nelems, n_wi * wg_size);
198209
@@ -208,7 +219,7 @@ inclusive_scan_base_step(sycl::queue &exec_q,
208219
209220 slmT slm_iscan_tmp (lws, cgh);
210221
211- using KernelName = inclusive_scan_iter_local_scan_krn <
222+ using KernelName = inclusive_scan_iter_local_scan_blocked_krn <
212223 inputT, outputT, n_wi, IterIndexerT, InpIndexerT, OutIndexerT,
213224 TransformerT, ScanOpT, include_initial>;
214225
@@ -218,6 +229,7 @@ inclusive_scan_base_step(sycl::queue &exec_q,
218229 const std::size_t gid = it.get_global_id (0 );
219230 const std::size_t lid = it.get_local_id (0 );
220231
232+ const std::uint32_t wg_size = it.get_local_range (0 );
221233 const std::size_t reduce_chunks = acc_groups * wg_size;
222234 const std::size_t iter_gid = gid / reduce_chunks;
223235 const std::size_t chunk_gid = gid - (iter_gid * reduce_chunks);
@@ -296,6 +308,248 @@ inclusive_scan_base_step(sycl::queue &exec_q,
296308 return inc_scan_phase1_ev;
297309}
298310
311+ template <typename inputT,
312+ typename outputT,
313+ nwiT n_wi,
314+ typename IterIndexerT,
315+ typename InpIndexerT,
316+ typename OutIndexerT,
317+ typename TransformerT,
318+ typename ScanOpT,
319+ bool include_initial = false >
320+ sycl::event
321+ inclusive_scan_base_step_striped (sycl::queue &exec_q,
322+ const std::uint32_t wg_size,
323+ const std::size_t iter_nelems,
324+ const std::size_t acc_nelems,
325+ const inputT *input,
326+ outputT *output,
327+ const std::size_t s0,
328+ const std::size_t s1,
329+ const IterIndexerT &iter_indexer,
330+ const InpIndexerT &inp_indexer,
331+ const OutIndexerT &out_indexer,
332+ TransformerT transformer,
333+ const ScanOpT &scan_op,
334+ outputT identity,
335+ std::size_t &acc_groups,
336+ const std::vector<sycl::event> &depends = {})
337+ {
338+ const std::uint32_t reduce_nelems_per_wg = n_wi * wg_size;
339+ acc_groups =
340+ ceiling_quotient<std::size_t >(acc_nelems, reduce_nelems_per_wg);
341+
342+ sycl::event inc_scan_phase1_ev = exec_q.submit ([&](sycl::handler &cgh) {
343+ cgh.depends_on (depends);
344+
345+ using slmT = sycl::local_accessor<outputT, 1 >;
346+
347+ const auto &gRange = sycl::range<1 >{iter_nelems * acc_groups * wg_size};
348+ const auto &lRange = sycl::range<1 >{wg_size};
349+
350+ const auto &ndRange = sycl::nd_range<1 >{gRange , lRange};
351+
352+ slmT slm_iscan_tmp (reduce_nelems_per_wg, cgh);
353+
354+ using KernelName = inclusive_scan_iter_local_scan_striped_krn<
355+ inputT, outputT, n_wi, IterIndexerT, InpIndexerT, OutIndexerT,
356+ TransformerT, ScanOpT, include_initial>;
357+
358+ cgh.parallel_for <KernelName>(ndRange, [=, slm_iscan_tmp =
359+ std::move (slm_iscan_tmp)](
360+ sycl::nd_item<1 > it) {
361+ const std::uint32_t lid = it.get_local_linear_id ();
362+ const std::uint32_t wg_size = it.get_local_range (0 );
363+
364+ const auto &sg = it.get_sub_group ();
365+ const std::uint32_t sgSize = sg.get_max_local_range ()[0 ];
366+ const std::size_t sgroup_id = sg.get_group_id ()[0 ];
367+ const std::uint32_t lane_id = sg.get_local_id ()[0 ];
368+
369+ const std::size_t flat_group_id = it.get_group (0 );
370+ const std::size_t iter_gid = flat_group_id / acc_groups;
371+ const std::size_t acc_group_id =
372+ flat_group_id - (iter_gid * acc_groups);
373+
374+ const auto &iter_offsets = iter_indexer (iter_gid);
375+ const auto &inp_iter_offset = iter_offsets.get_first_offset ();
376+ const auto &out_iter_offset = iter_offsets.get_second_offset ();
377+
378+ std::array<outputT, n_wi> local_iscan{};
379+
380+ const std::size_t inp_id0 = acc_group_id * n_wi * wg_size +
381+ sgroup_id * n_wi * sgSize + lane_id;
382+
383+ #pragma unroll
384+ for (nwiT m_wi = 0 ; m_wi < n_wi; ++m_wi) {
385+ const std::size_t inp_id = inp_id0 + m_wi * sgSize;
386+ if constexpr (!include_initial) {
387+ local_iscan[m_wi] =
388+ (inp_id < acc_nelems)
389+ ? transformer (input[inp_iter_offset +
390+ inp_indexer (s0 + s1 * inp_id)])
391+ : identity;
392+ }
393+ else {
394+ // shift input to the left by a single element relative to
395+ // output
396+ local_iscan[m_wi] =
397+ (inp_id < acc_nelems && inp_id > 0 )
398+ ? transformer (
399+ input[inp_iter_offset +
400+ inp_indexer ((s0 + s1 * inp_id) - 1 )])
401+ : identity;
402+ }
403+ }
404+
405+ // change layout from striped to blocked
406+ {
407+ {
408+ const std::uint32_t local_offset0 = lid * n_wi;
409+ #pragma unroll
410+ for (std::uint32_t i = 0 ; i < n_wi; ++i) {
411+ slm_iscan_tmp[local_offset0 + i] = local_iscan[i];
412+ }
413+
414+ it.barrier (sycl::access::fence_space::local_space);
415+ }
416+
417+ {
418+ const std::uint32_t block_offset =
419+ sgroup_id * sgSize * n_wi;
420+ const std::uint32_t disp0 = lane_id * n_wi;
421+ #pragma unroll
422+ for (nwiT i = 0 ; i < n_wi; ++i) {
423+ const std::uint32_t disp = disp0 + i;
424+
425+ // disp == lane_id1 + i1 * sgSize;
426+ const std::uint32_t i1 = disp / sgSize;
427+ const std::uint32_t lane_id1 = disp - i1 * sgSize;
428+
429+ const std::uint32_t disp_exchanged =
430+ (lane_id1 * n_wi + i1);
431+
432+ local_iscan[i] =
433+ slm_iscan_tmp[block_offset + disp_exchanged];
434+ }
435+
436+ it.barrier (sycl::access::fence_space::local_space);
437+ }
438+ }
439+
440+ #pragma unroll
441+ for (nwiT m_wi = 1 ; m_wi < n_wi; ++m_wi) {
442+ local_iscan[m_wi] =
443+ scan_op (local_iscan[m_wi], local_iscan[m_wi - 1 ]);
444+ }
445+ // local_iscan is now result of
446+ // inclusive scan of locally stored inputs
447+
448+ outputT wg_iscan_val;
449+ if constexpr (can_use_inclusive_scan_over_group<ScanOpT,
450+ outputT>::value)
451+ {
452+ wg_iscan_val = sycl::inclusive_scan_over_group (
453+ it.get_group (), local_iscan.back (), scan_op, identity);
454+ }
455+ else {
456+ wg_iscan_val = su_ns::custom_inclusive_scan_over_group (
457+ it.get_group (), slm_iscan_tmp, local_iscan.back (), scan_op);
458+ // ensure all finished reading from SLM, to avoid race condition
459+ // with subsequent writes into SLM
460+ it.barrier (sycl::access::fence_space::local_space);
461+ }
462+
463+ slm_iscan_tmp[(lid + 1 ) % wg_size] = wg_iscan_val;
464+ it.barrier (sycl::access::fence_space::local_space);
465+ const outputT modifier = (lid == 0 ) ? identity : slm_iscan_tmp[lid];
466+
467+ #pragma unroll
468+ for (nwiT m_wi = 0 ; m_wi < n_wi; ++m_wi) {
469+ local_iscan[m_wi] = scan_op (local_iscan[m_wi], modifier);
470+ }
471+
472+ it.barrier (sycl::access::fence_space::local_space);
473+
474+ // convert back to blocked layout
475+ {
476+ {
477+ const std::uint32_t local_offset0 = lid * n_wi;
478+ #pragma unroll
479+ for (nwiT m_wi = 0 ; m_wi < n_wi; ++m_wi) {
480+ slm_iscan_tmp[local_offset0 + m_wi] = local_iscan[m_wi];
481+ }
482+
483+ it.barrier (sycl::access::fence_space::local_space);
484+ }
485+ }
486+
487+ {
488+ const std::uint32_t block_offset =
489+ sgroup_id * sgSize * n_wi + lane_id;
490+ #pragma unroll
491+ for (nwiT m_wi = 0 ; m_wi < n_wi; ++m_wi) {
492+ const std::uint32_t m_wi_scaled = m_wi * sgSize;
493+ const std::size_t out_id = inp_id0 + m_wi_scaled;
494+ if (out_id < acc_nelems) {
495+ output[out_iter_offset + out_indexer (out_id)] =
496+ slm_iscan_tmp[block_offset + m_wi_scaled];
497+ }
498+ }
499+ }
500+ });
501+ });
502+
503+ return inc_scan_phase1_ev;
504+ }
505+
506+ template <typename inputT,
507+ typename outputT,
508+ nwiT n_wi,
509+ typename IterIndexerT,
510+ typename InpIndexerT,
511+ typename OutIndexerT,
512+ typename TransformerT,
513+ typename ScanOpT,
514+ bool include_initial = false >
515+ sycl::event
516+ inclusive_scan_base_step (sycl::queue &exec_q,
517+ const std::uint32_t wg_size,
518+ const std::size_t iter_nelems,
519+ const std::size_t acc_nelems,
520+ const inputT *input,
521+ outputT *output,
522+ const std::size_t s0,
523+ const std::size_t s1,
524+ const IterIndexerT &iter_indexer,
525+ const InpIndexerT &inp_indexer,
526+ const OutIndexerT &out_indexer,
527+ TransformerT transformer,
528+ const ScanOpT &scan_op,
529+ outputT identity,
530+ std::size_t &acc_groups,
531+ const std::vector<sycl::event> &depends = {})
532+ {
533+ // For small stride use striped load/store.
534+ // Threshold value chosen experimentally.
535+ if (s1 <= 16 ) {
536+ return inclusive_scan_base_step_striped<
537+ inputT, outputT, n_wi, IterIndexerT, InpIndexerT, OutIndexerT,
538+ TransformerT, ScanOpT, include_initial>(
539+ exec_q, wg_size, iter_nelems, acc_nelems, input, output, s0, s1,
540+ iter_indexer, inp_indexer, out_indexer, transformer, scan_op,
541+ identity, acc_groups, depends);
542+ }
543+ else {
544+ return inclusive_scan_base_step_blocked<
545+ inputT, outputT, n_wi, IterIndexerT, InpIndexerT, OutIndexerT,
546+ TransformerT, ScanOpT, include_initial>(
547+ exec_q, wg_size, iter_nelems, acc_nelems, input, output, s0, s1,
548+ iter_indexer, inp_indexer, out_indexer, transformer, scan_op,
549+ identity, acc_groups, depends);
550+ }
551+ }
552+
299553template <typename inputT,
300554 typename outputT,
301555 nwiT n_wi,
0 commit comments