@@ -58,55 +58,53 @@ struct RngContigFunctor
5858 EngineBuilderT engine_;
5959 DistributorBuilderT distr_;
6060 DataT * const res_ = nullptr ;
61- const size_t nelems_;
61+ const std:: size_t nelems_;
6262
6363public:
64- RngContigFunctor (EngineBuilderT& engine, DistributorBuilderT& distr, DataT *res, const size_t n_elems)
64+ RngContigFunctor (EngineBuilderT& engine, DistributorBuilderT& distr, DataT *res, const std:: size_t n_elems)
6565 : engine_(engine), distr_(distr), res_(res), nelems_(n_elems)
6666 {
6767 }
6868
6969 void operator ()(sycl::nd_item<1 > nd_it) const
7070 {
71- // auto global_id = nd_it.get_global_id();
72-
7371 auto sg = nd_it.get_sub_group ();
7472 const std::uint8_t sg_size = sg.get_local_range ()[0 ];
7573 const std::uint8_t max_sg_size = sg.get_max_local_range ()[0 ];
7674
7775 using EngineT = typename EngineBuilderT::EngineType;
78- // EngineT engine = engine_(nelems_ * global_id); // offset is questionable...
79- EngineT engine = engine_ ();
80-
8176 using DistrT = typename DistributorBuilderT::distr_type;
82- DistrT distr = distr_ ();
8377
8478 constexpr std::size_t vec_sz = EngineT::vec_size;
79+ constexpr std::size_t vi_per_wi = vec_sz * items_per_wi;
80+
81+ EngineT engine = engine_ (nd_it.get_global_id () * vi_per_wi);
82+ DistrT distr = distr_ ();
8583
8684 if constexpr (enable_sg_load) {
87- const size_t base = items_per_wi * vec_sz * (nd_it.get_group (0 ) * nd_it.get_local_range (0 ) + sg.get_group_id ()[0 ] * max_sg_size);
85+ const std:: size_t base = vi_per_wi * (nd_it.get_group (0 ) * nd_it.get_local_range (0 ) + sg.get_group_id ()[0 ] * max_sg_size);
8886
89- if ((sg_size == max_sg_size) && (base + items_per_wi * vec_sz * sg_size < nelems_)) {
87+ if ((sg_size == max_sg_size) && (base + vi_per_wi * sg_size < nelems_)) {
9088#pragma unroll
91- for (std::uint16_t it = 0 ; it < items_per_wi * vec_sz ; it += vec_sz) {
92- size_t offset = base + static_cast <size_t >(it) * static_cast <size_t >(sg_size);
89+ for (std::uint16_t it = 0 ; it < vi_per_wi ; it += vec_sz) {
90+ std:: size_t offset = base + static_cast <std:: size_t >(it) * static_cast <std:: size_t >(sg_size);
9391 auto out_multi_ptr = sycl::address_space_cast<sycl::access::address_space::global_space, sycl::access::decorated::yes>(&res_[offset]);
9492
9593 sycl::vec<DataT, vec_sz> rng_val_vec = mkl_rng_dev::generate<DistrT, EngineT>(distr, engine);
9694 sg.store <vec_sz>(out_multi_ptr, rng_val_vec);
9795 }
9896 }
9997 else {
100- for (size_t offset = base + sg.get_local_id ()[0 ]; offset < nelems_; offset += sg_size) {
98+ for (std:: size_t offset = base + sg.get_local_id ()[0 ]; offset < nelems_; offset += sg_size) {
10199 res_[offset] = mkl_rng_dev::generate_single<DistrT, EngineT>(distr, engine);
102100 }
103101 }
104102 }
105103 else {
106- size_t base = nd_it.get_global_linear_id ();
104+ std:: size_t base = nd_it.get_global_linear_id ();
107105
108- base = (base / sg_size) * sg_size * items_per_wi * vec_sz + (base % sg_size);
109- for (size_t offset = base; offset < std::min (nelems_, base + sg_size * (items_per_wi * vec_sz) ); offset += sg_size)
106+ base = (base / sg_size) * sg_size * vi_per_wi + (base % sg_size);
107+ for (std:: size_t offset = base; offset < std::min (nelems_, base + sg_size * vi_per_wi ); offset += sg_size)
110108 {
111109 res_[offset] = mkl_rng_dev::generate_single<DistrT, EngineT>(distr, engine);
112110 }
0 commit comments