2323// ===---------------------------------------------------------------------===//
2424
2525#pragma once
26+ #include < cstddef>
2627#include < cstdint>
2728#include < limits>
2829#include < sycl/sycl.hpp>
@@ -55,7 +56,7 @@ struct MaskedExtractStridedFunctor
5556 MaskedExtractStridedFunctor (const dataT *src_data_p,
5657 const indT *cumsum_data_p,
5758 dataT *dst_data_p,
58- size_t masked_iter_size,
59+ std:: size_t masked_iter_size,
5960 const OrthogIndexerT &orthog_src_dst_indexer_,
6061 const MaskedSrcIndexerT &masked_src_indexer_,
6162 const MaskedDstIndexerT &masked_dst_indexer_,
@@ -81,7 +82,7 @@ struct MaskedExtractStridedFunctor
8182
8283 const std::size_t max_offset = masked_nelems + 1 ;
8384 for (std::uint32_t i = l_i; i < lacc.size (); i += lws) {
84- const size_t offset = masked_block_start + i;
85+ const std:: size_t offset = masked_block_start + i;
8586 lacc[i] = (offset == 0 ) ? indT (0 )
8687 : (offset < max_offset) ? cumsum[offset - 1 ]
8788 : cumsum[masked_nelems - 1 ] + 1 ;
@@ -99,9 +100,10 @@ struct MaskedExtractStridedFunctor
99100 if (mask_set && (masked_i < masked_nelems)) {
100101 const auto &orthog_offsets = orthog_src_dst_indexer (orthog_i);
101102
102- const size_t total_src_offset = masked_src_indexer (masked_i) +
103- orthog_offsets.get_first_offset ();
104- const size_t total_dst_offset =
103+ const std::size_t total_src_offset =
104+ masked_src_indexer (masked_i) +
105+ orthog_offsets.get_first_offset ();
106+ const std::size_t total_dst_offset =
105107 masked_dst_indexer (current_running_count - 1 ) +
106108 orthog_offsets.get_second_offset ();
107109
@@ -113,7 +115,7 @@ struct MaskedExtractStridedFunctor
113115 const dataT *src = nullptr ;
114116 const indT *cumsum = nullptr ;
115117 dataT *dst = nullptr ;
116- const size_t masked_nelems = 0 ;
118+ const std:: size_t masked_nelems = 0 ;
117119 // has nd, shape, src_strides, dst_strides for
118120 // dimensions that ARE NOT masked
119121 const OrthogIndexerT orthog_src_dst_indexer;
@@ -136,7 +138,7 @@ struct MaskedPlaceStridedFunctor
136138 MaskedPlaceStridedFunctor (dataT *dst_data_p,
137139 const indT *cumsum_data_p,
138140 const dataT *rhs_data_p,
139- size_t masked_iter_size,
141+ std:: size_t masked_iter_size,
140142 const OrthogIndexerT &orthog_dst_rhs_indexer_,
141143 const MaskedDstIndexerT &masked_dst_indexer_,
142144 const MaskedRhsIndexerT &masked_rhs_indexer_,
@@ -157,12 +159,12 @@ struct MaskedPlaceStridedFunctor
157159 const std::uint32_t l_i = ndit.get_local_id (1 );
158160 const std::uint32_t lws = ndit.get_local_range (1 );
159161
160- const size_t masked_i = ndit.get_global_id (1 );
161- const size_t masked_block_start = masked_i - l_i;
162+ const std:: size_t masked_i = ndit.get_global_id (1 );
163+ const std:: size_t masked_block_start = masked_i - l_i;
162164
163165 const std::size_t max_offset = masked_nelems + 1 ;
164166 for (std::uint32_t i = l_i; i < lacc.size (); i += lws) {
165- const size_t offset = masked_block_start + i;
167+ const std:: size_t offset = masked_block_start + i;
166168 lacc[i] = (offset == 0 ) ? indT (0 )
167169 : (offset < max_offset) ? cumsum[offset - 1 ]
168170 : cumsum[masked_nelems - 1 ] + 1 ;
@@ -180,9 +182,10 @@ struct MaskedPlaceStridedFunctor
180182 if (mask_set && (masked_i < masked_nelems)) {
181183 const auto &orthog_offsets = orthog_dst_rhs_indexer (orthog_i);
182184
183- const size_t total_dst_offset = masked_dst_indexer (masked_i) +
184- orthog_offsets.get_first_offset ();
185- const size_t total_rhs_offset =
185+ const std::size_t total_dst_offset =
186+ masked_dst_indexer (masked_i) +
187+ orthog_offsets.get_first_offset ();
188+ const std::size_t total_rhs_offset =
186189 masked_rhs_indexer (current_running_count - 1 ) +
187190 orthog_offsets.get_second_offset ();
188191
@@ -194,7 +197,7 @@ struct MaskedPlaceStridedFunctor
194197 dataT *dst = nullptr ;
195198 const indT *cumsum = nullptr ;
196199 const dataT *rhs = nullptr ;
197- const size_t masked_nelems = 0 ;
200+ const std:: size_t masked_nelems = 0 ;
198201 // has nd, shape, dst_strides, rhs_strides for
199202 // dimensions that ARE NOT masked
200203 const OrthogIndexerT orthog_dst_rhs_indexer;
@@ -450,8 +453,8 @@ sycl::event masked_extract_some_slices_strided_impl(
450453
451454 const std::size_t lws = get_lws (masked_extent);
452455
453- const size_t n_groups = ((masked_extent + lws - 1 ) / lws);
454- const size_t orthog_extent = static_cast <size_t >(orthog_nelems);
456+ const std:: size_t n_groups = ((masked_extent + lws - 1 ) / lws);
457+ const std:: size_t orthog_extent = static_cast <std:: size_t >(orthog_nelems);
455458
456459 sycl::range<2 > gRange {orthog_extent, n_groups * lws};
457460 sycl::range<2 > lRange{1 , lws};
@@ -809,7 +812,7 @@ sycl::event non_zero_indexes_impl(sycl::queue &exec_q,
809812 const std::size_t masked_block_start = group_i * lws;
810813
811814 for (std::uint32_t i = l_i; i < lacc.size (); i += lws) {
812- const size_t offset = masked_block_start + i;
815+ const std:: size_t offset = masked_block_start + i;
813816 lacc[i] = (offset == 0 ) ? indT1 (0 )
814817 : (offset - 1 < masked_extent)
815818 ? cumsum_data[offset - 1 ]
0 commit comments