@@ -52,12 +52,15 @@ template <typename T, int Dims>
5252struct CachedData
5353{
5454 static constexpr bool const sync_after_init = true ;
55- using pointer_type = T *;
55+ using Shape = sycl::range<Dims>;
56+ using value_type = T;
57+ using pointer_type = value_type *;
58+ static constexpr auto dims = Dims;
5659
57- using ncT = typename std::remove_const<T >::type;
60+ using ncT = typename std::remove_const<value_type >::type;
5861 using LocalData = sycl::local_accessor<ncT, Dims>;
5962
60- CachedData (T *global_data, sycl::range<Dims> shape, sycl::handler &cgh)
63+ CachedData (T *global_data, Shape shape, sycl::handler &cgh)
6164 {
6265 this ->global_data = global_data;
6366 local_data = LocalData (shape, cgh);
@@ -71,13 +74,13 @@ struct CachedData
7174 template <int _Dims>
7275 void init (const sycl::nd_item<_Dims> &item) const
7376 {
74- int32_t llid = item.get_local_linear_id ();
77+ uint32_t llid = item.get_local_linear_id ();
7578 auto local_ptr = &local_data[0 ];
76- int32_t size = local_data.size ();
79+ uint32_t size = local_data.size ();
7780 auto group = item.get_group ();
78- int32_t local_size = group.get_local_linear_range ();
81+ uint32_t local_size = group.get_local_linear_range ();
7982
80- for (int32_t i = llid; i < size; i += local_size) {
83+ for (uint32_t i = llid; i < size; i += local_size) {
8184 local_ptr[i] = global_data[i];
8285 }
8386 }
@@ -87,17 +90,30 @@ struct CachedData
8790 return local_data.size ();
8891 }
8992
93+ T &operator [](const sycl::id<Dims> &id) const
94+ {
95+ return local_data[id];
96+ }
97+
98+ template <typename = std::enable_if_t <Dims == 1 >>
99+ T &operator [](const size_t id) const
100+ {
101+ return local_data[id];
102+ }
103+
90104private:
91105 LocalData local_data;
92- T *global_data = nullptr ;
106+ value_type *global_data = nullptr ;
93107};
94108
95109template <typename T, int Dims>
96110struct UncachedData
97111{
98112 static constexpr bool const sync_after_init = false ;
99113 using Shape = sycl::range<Dims>;
100- using pointer_type = T *;
114+ using value_type = T;
115+ using pointer_type = value_type *;
116+ static constexpr auto dims = Dims;
101117
102118 UncachedData (T *global_data, const Shape &shape, sycl::handler &)
103119 {
@@ -120,6 +136,17 @@ struct UncachedData
120136 return _shape.size ();
121137 }
122138
139+ T &operator [](const sycl::id<Dims> &id) const
140+ {
141+ return global_data[id];
142+ }
143+
144+ template <typename = std::enable_if_t <Dims == 1 >>
145+ T &operator [](const size_t id) const
146+ {
147+ return global_data[id];
148+ }
149+
123150private:
124151 T *global_data = nullptr ;
125152 Shape _shape;
@@ -191,15 +218,15 @@ struct HistWithLocalCopies
191218 template <int _Dims>
192219 void finalize (const sycl::nd_item<_Dims> &item) const
193220 {
194- int32_t llid = item.get_local_linear_id ();
195- int32_t bins_count = local_hist.get_range ().get (1 );
196- int32_t local_hist_count = local_hist.get_range ().get (0 );
221+ uint32_t llid = item.get_local_linear_id ();
222+ uint32_t bins_count = local_hist.get_range ().get (1 );
223+ uint32_t local_hist_count = local_hist.get_range ().get (0 );
197224 auto group = item.get_group ();
198- int32_t local_size = group.get_local_linear_range ();
225+ uint32_t local_size = group.get_local_linear_range ();
199226
200- for (int32_t i = llid; i < bins_count; i += local_size) {
227+ for (uint32_t i = llid; i < bins_count; i += local_size) {
201228 auto value = local_hist[0 ][i];
202- for (int32_t lhc = 1 ; lhc < local_hist_count; ++lhc) {
229+ for (uint32_t lhc = 1 ; lhc < local_hist_count; ++lhc) {
203230 value += local_hist[lhc][i];
204231 }
205232 if (value != T (0 )) {
@@ -290,9 +317,9 @@ class histogram_kernel;
290317
291318template <typename T, typename HistImpl, typename Edges, typename Weights>
292319void submit_histogram (const T *in,
293- size_t size,
294- size_t dims,
295- uint32_t WorkPI,
320+ const size_t size,
321+ const size_t dims,
322+ const uint32_t WorkPI,
296323 const HistImpl &hist,
297324 const Edges &edges,
298325 const Weights &weights,
0 commit comments