3131#include < oneapi/mkl/rng/device.hpp>
3232
3333// dpctl tensor headers
34- #include " kernels/alignment.hpp"
35- #include " utils/offset_utils.hpp"
34+ // #include "utils/offset_utils.hpp"
3635
3736namespace dpnp
3837{
@@ -48,9 +47,6 @@ namespace details
4847{
4948namespace py = pybind11;
5049
51- using dpctl::tensor::kernels::alignment_utils::is_aligned;
52- using dpctl::tensor::kernels::alignment_utils::required_alignment;
53-
5450namespace mkl_rng_dev = oneapi::mkl::rng::device;
5551
5652/* ! @brief Functor for unary function evaluation on contiguous array */
@@ -67,7 +63,7 @@ struct RngContigFunctor
6763 const std::uint32_t seed_;
6864 const DataT mean_;
6965 const DataT stddev_;
70- ResT *res_ = nullptr ;
66+ ResT * const res_ = nullptr ;
7167 const size_t nelems_;
7268
7369public:
@@ -84,10 +80,10 @@ struct RngContigFunctor
8480 const std::uint8_t sg_size = sg.get_local_range ()[0 ];
8581 const std::uint8_t max_sg_size = sg.get_max_local_range ()[0 ];
8682
87- auto engine = mkl_rng_dev::mrg32k3a<vec_sz>(seed_, nelems_ * global_id);
83+ auto engine = mkl_rng_dev::mrg32k3a<vec_sz>(seed_, nelems_ * global_id); // offset is questionable...
8884 mkl_rng_dev::gaussian<DataT, Method> distr (mean_, stddev_);
8985
90- if (enable_sg_load) {
86+ if constexpr (enable_sg_load) {
9187 const size_t base = items_per_wi * vec_sz * (nd_it.get_group (0 ) * nd_it.get_local_range (0 ) + sg.get_group_id ()[0 ] * max_sg_size);
9288
9389 if ((sg_size == max_sg_size) && (base + items_per_wi * vec_sz * sg_size < nelems_)) {
@@ -118,38 +114,38 @@ struct RngContigFunctor
118114 }
119115};
120116
121- template <typename DataT,
122- typename ResT = DataT,
123- typename Method = mkl_rng_dev::gaussian_method::by_default,
124- typename IndexerT = ResT,
125- typename UnaryOpT = ResT>
126- struct RngStridedFunctor
127- {
128- private:
129- const std::uint32_t seed_;
130- const double mean_;
131- const double stddev_;
132- ResT *res_ = nullptr ;
133- IndexerT out_indexer_;
134-
135- public:
136- RngStridedFunctor (const std::uint32_t seed, const double mean, const double stddev, ResT *res_p, IndexerT out_indexer)
137- : seed_(seed), mean_(mean), stddev_(stddev), res_(res_p), out_indexer_(out_indexer)
138- {
139- }
140-
141- void operator ()(sycl::id<1 > wid) const
142- {
143- const auto res_offset = out_indexer_ (wid.get (0 ));
144-
145- // UnaryOpT op{};
146-
147- auto engine = mkl_rng_dev::mrg32k3a (seed_);
148- mkl_rng_dev::gaussian<DataT, Method> distr (mean_, stddev_);
149-
150- res_[res_offset] = mkl_rng_dev::generate (distr, engine);
151- }
152- };
117+ // template <typename DataT,
118+ // typename ResT = DataT,
119+ // typename Method = mkl_rng_dev::gaussian_method::by_default,
120+ // typename IndexerT = ResT,
121+ // typename UnaryOpT = ResT>
122+ // struct RngStridedFunctor
123+ // {
124+ // private:
125+ // const std::uint32_t seed_;
126+ // const double mean_;
127+ // const double stddev_;
128+ // ResT *res_ = nullptr;
129+ // IndexerT out_indexer_;
130+
131+ // public:
132+ // RngStridedFunctor(const std::uint32_t seed, const double mean, const double stddev, ResT *res_p, IndexerT out_indexer)
133+ // : seed_(seed), mean_(mean), stddev_(stddev), res_(res_p), out_indexer_(out_indexer)
134+ // {
135+ // }
136+
137+ // void operator()(sycl::id<1> wid) const
138+ // {
139+ // const auto res_offset = out_indexer_(wid.get(0));
140+
141+ // // UnaryOpT op{};
142+
143+ // auto engine = mkl_rng_dev::mrg32k3a(seed_);
144+ // mkl_rng_dev::gaussian<DataT, Method> distr(mean_, stddev_);
145+
146+ // res_[res_offset] = mkl_rng_dev::generate(distr, engine);
147+ // }
148+ // };
153149} // namespace details
154150} // namespace device
155151} // namespace rng
0 commit comments