File tree Expand file tree Collapse file tree 1 file changed +9
-5
lines changed
dpctl/tensor/libtensor/include/utils Expand file tree Collapse file tree 1 file changed +9
-5
lines changed Original file line number Diff line number Diff line change @@ -261,10 +261,13 @@ T custom_inclusive_scan_over_group(GroupT &&wg,
261261 sycl::group_barrier (wg, sycl::memory_scope::work_group);
262262 }
263263
264- if (sgr_id == 0 && lane_id < n_aggregates ) {
264+ if (sgr_id == 0 ) {
265265 const std::uint32_t offset =
266266 (large_wg) ? n_aggregates - max_sgSize : 0u ;
267- T __scan_val = (offset + lane_id > 0 )
267+ const bool in_range = (lane_id < n_aggregates);
268+ const bool in_bounds = in_range && (lane_id > 0 || large_wg);
269+
270+ T __scan_val = (in_bounds)
268271 ? local_mem_acc[(offset + lane_id) * max_sgSize - 1 ]
269272 : identity;
270273 for (std::uint32_t step = 1 ; step < sgSize; step *= 2 ) {
@@ -273,12 +276,13 @@ T custom_inclusive_scan_over_group(GroupT &&wg,
273276 (advanced_lane ? lane_id - step : lane_id);
274277 const T modifier =
275278 sycl::select_from_group (sg, __scan_val, src_lane_id);
276- if (advanced_lane) {
279+ if (advanced_lane && in_range ) {
277280 __scan_val = op (__scan_val, modifier);
278281 }
279282 }
280- sycl::group_barrier (sg);
281- local_mem_acc[(offset + lane_id) * max_sgSize - 1 ] = __scan_val;
283+ if (in_bounds) {
284+ local_mem_acc[(offset + lane_id) * max_sgSize - 1 ] = __scan_val;
285+ }
282286 }
283287 sycl::group_barrier (wg, sycl::memory_scope::work_group);
284288
You can’t perform that action at this time.
0 commit comments