@@ -125,54 +125,56 @@ sycl::event inclusive_scan_rec(sycl::queue &exec_q,
125125 auto lws = sycl::range<1 >(wg_size);
126126 auto gws = sycl::range<1 >(n_groups * wg_size);
127127
128+ auto ndRange = sycl::nd_range<1 >(gws, lws);
129+
128130 slmT slm_iscan_tmp (lws, cgh);
129131
130- cgh.parallel_for <class inclusive_scan_rec_local_scan_krn <
131- inputT, outputT, n_wi, IndexerT, decltype (transformer)>>(
132- sycl::nd_range<1 >(gws, lws), [=, slm_iscan_tmp = std::move (slm_iscan_tmp)](sycl::nd_item<1 > it)
133- {
134- auto chunk_gid = it.get_global_id (0 );
135- auto lid = it.get_local_id (0 );
132+ using KernelName = inclusive_scan_rec_local_scan_krn<
133+ inputT, outputT, n_wi, IndexerT, decltype (transformer)>;
134+
135+ cgh.parallel_for <KernelName>(ndRange, [=, slm_iscan_tmp = std::move (
136+ slm_iscan_tmp)](
137+ sycl::nd_item<1 > it) {
138+ auto chunk_gid = it.get_global_id (0 );
139+ auto lid = it.get_local_id (0 );
136140
137- std::array<size_t , n_wi> local_isum;
141+ std::array<size_t , n_wi> local_isum;
138142
139- size_t i = chunk_gid * n_wi;
140- for (size_t m_wi = 0 ; m_wi < n_wi; ++m_wi) {
141- constexpr outputT out_zero (0 );
143+ size_t i = chunk_gid * n_wi;
144+ for (size_t m_wi = 0 ; m_wi < n_wi; ++m_wi) {
145+ constexpr outputT out_zero (0 );
142146
143- local_isum[m_wi] =
144- (i + m_wi < n_elems)
145- ? transformer (input[indexer (s0 + s1 * (i + m_wi))])
146- : out_zero;
147- }
147+ local_isum[m_wi] =
148+ (i + m_wi < n_elems)
149+ ? transformer (input[indexer (s0 + s1 * (i + m_wi))])
150+ : out_zero;
151+ }
148152
149- // local_isum is now result of
150- // inclusive scan of locally stored mask indicators
151153#pragma unroll
152- for (size_t m_wi = 1 ; m_wi < n_wi; ++m_wi) {
153- local_isum[m_wi] += local_isum[m_wi - 1 ];
154- }
154+ for (size_t m_wi = 1 ; m_wi < n_wi; ++m_wi) {
155+ local_isum[m_wi] += local_isum[m_wi - 1 ];
156+ }
157+ // local_isum is now result of
158+ // inclusive scan of locally stored inputs
155159
156- size_t wg_iscan_val =
157- sycl::inclusive_scan_over_group (it.get_group (),
158- local_isum.back (),
159- sycl::plus<size_t >(),
160- size_t (0 ));
160+ size_t wg_iscan_val = sycl::inclusive_scan_over_group (
161+ it.get_group (), local_isum.back (), sycl::plus<size_t >(),
162+ size_t (0 ));
161163
162- slm_iscan_tmp[(lid + 1 ) % wg_size] = wg_iscan_val;
163- it.barrier (sycl::access::fence_space::local_space);
164- size_t addand = (lid == 0 ) ? 0 : slm_iscan_tmp[lid];
165- it.barrier (sycl::access::fence_space::local_space);
164+ slm_iscan_tmp[(lid + 1 ) % wg_size] = wg_iscan_val;
165+ it.barrier (sycl::access::fence_space::local_space);
166+ size_t addand = (lid == 0 ) ? 0 : slm_iscan_tmp[lid];
166167
167168#pragma unroll
168- for (size_t m_wi = 0 ; m_wi < n_wi; ++m_wi) {
169- local_isum[m_wi] += addand;
170- }
171-
172- for (size_t m_wi = 0 ; m_wi < n_wi && i + m_wi < n_elems; ++m_wi) {
173- output[i + m_wi] = local_isum[m_wi];
174- }
175- });
169+ for (size_t m_wi = 0 ; m_wi < n_wi; ++m_wi) {
170+ local_isum[m_wi] += addand;
171+ }
172+
173+ for (size_t m_wi = 0 ; m_wi < n_wi && i + m_wi < n_elems; ++m_wi)
174+ {
175+ output[i + m_wi] = local_isum[m_wi];
176+ }
177+ });
176178 });
177179
178180 sycl::event out_event = inc_scan_phase1_ev;
0 commit comments