@@ -91,7 +91,26 @@ struct GaussianDistr
9191 }
9292};
9393
94- typedef sycl::event (*gaussian_impl_fn_ptr_t )(sycl::queue &,
/**
 * @brief Callable that builds an engine of type `MklEngineT` from the state
 *        of another engine object.
 *
 * Stores a non-owning pointer to an engine object. Each call to operator()
 * reads that object's `seed_` and `offset_` members at call time and returns
 * a freshly constructed `MklEngineT(seed_, offset_)`.
 *
 * @tparam EngineBase  Type of the pointed-to engine object. It must expose
 *                     `seed_` and `offset_` members accessible from this
 *                     struct.
 * @tparam MklEngineT  Engine type to construct. It must be constructible
 *                     from `(seed_, offset_)`.
 */
template <typename EngineBase, typename MklEngineT>
struct EngineDistr
{
private:
    // Non-owning. The pointee must stay alive for every operator() call,
    // and a null pointer is not checked for.
    // NOTE(review): this object is handed to RngContigFunctor for a
    // parallel_for launch. If operator() ends up being evaluated in device
    // code, this pointer would have to be device-accessible - TODO confirm.
    EngineBase *engine_;

public:
    using engine_type = MklEngineT;

    /**
     * @param engine Engine object whose `seed_` / `offset_` are read later.
     *
     * Marked explicit so that a raw pointer is never converted into an
     * EngineDistr implicitly.
     */
    explicit EngineDistr(EngineBase *engine) : engine_(engine) {}

    /**
     * @return A new `MklEngineT` built from the current `seed_` and
     *         `offset_` of the referenced engine object.
     */
    engine_type operator()() const
    {
        return engine_type(engine_->seed_, engine_->offset_);
    }
};
112+
113+ typedef sycl::event (*gaussian_impl_fn_ptr_t )(EngineBase *engine,
95114 const std::uint32_t ,
96115 const double ,
97116 const double ,
@@ -101,25 +120,26 @@ typedef sycl::event (*gaussian_impl_fn_ptr_t)(sycl::queue &,
101120
// Two-dimensional table of Gaussian implementation function pointers, sized
// dpctl_td_ns::num_types x num_methods. Having static storage duration, every
// entry is null until something assigns it.
// NOTE(review): presumably populated from GaussianContigFactory and indexed
// as [type id][method id] - confirm against the initialisation code.
102121static gaussian_impl_fn_ptr_t gaussian_dispatch_table[dpctl_td_ns::num_types][num_methods];
103122
// Kernel-name tag: only forward-declared here. In the visible code it is
// used solely as the template argument that names the kernel passed to
// cgh.parallel_for (directly, or wrapped in
// disabled_sg_loadstore_wrapper_krn).
104- template <typename DataT, typename Method, unsigned int vec_sz , unsigned int items_per_wi>
123+ template <typename EngineT, typename DataT, typename Method , unsigned int items_per_wi>
105124class gaussian_kernel ;
106125
107- template <typename DataT, typename Method>
108- static sycl::event gaussian_impl (sycl::queue& exec_q ,
126+ template <typename EngineT, typename DataT, typename Method>
127+ static sycl::event gaussian_impl (EngineBase *engine ,
109128 const std::uint32_t seed,
110129 const double mean_val,
111130 const double stddev_val,
112131 const std::uint64_t n,
113132 char *out_ptr,
114133 const std::vector<sycl::event> &depends)
115134{
135+ auto exec_q = engine->get_queue ();
116136 type_utils::validate_type_for_device<DataT>(exec_q);
117137
118138 DataT *out = reinterpret_cast <DataT *>(out_ptr);
119139 DataT mean = static_cast <DataT>(mean_val);
120140 DataT stddev = static_cast <DataT>(stddev_val);
121141
122- constexpr std::size_t vec_sz = 8 ;
142+ constexpr std::size_t vec_sz = EngineT::vec_size ;
123143 constexpr std::size_t items_per_wi = 4 ;
124144 constexpr std::size_t local_size = 256 ;
125145 const std::size_t wg_items = local_size * vec_sz * items_per_wi;
@@ -131,23 +151,28 @@ static sycl::event gaussian_impl(sycl::queue& exec_q,
131151 distr_event = exec_q.submit ([&](sycl::handler &cgh) {
132152 cgh.depends_on (depends);
133153
154+ using EngineDistrT = EngineDistr<MRG32k3a, EngineT>;
155+ EngineDistrT eng (static_cast <MRG32k3a*>(engine));
156+
157+ // EngineT engine = EngineT(seed, 0);
158+
134159 using GaussianDistrT = GaussianDistr<DataT, Method>;
135160 GaussianDistrT distr (mean, stddev);
136161
137162 if (is_aligned<required_alignment>(out_ptr)) {
138163 constexpr bool enable_sg_load = true ;
139- using KernelName = gaussian_kernel<DataT, Method, vec_sz , items_per_wi>;
164+ using KernelName = gaussian_kernel<EngineT, DataT, Method , items_per_wi>;
140165
141166 cgh.parallel_for <KernelName>(sycl::nd_range<1 >({global_size}, {local_size}),
142- details::RngContigFunctor<DataT, GaussianDistrT, DataT, DataT, vec_sz, items_per_wi, enable_sg_load>(seed , distr, out, n));
167+ details::RngContigFunctor<EngineDistrT, DataT, GaussianDistrT, items_per_wi, enable_sg_load>(eng , distr, out, n));
143168 }
144169 else {
145170 constexpr bool disable_sg_load = false ;
146- using InnerKernelName = gaussian_kernel<DataT, Method, vec_sz , items_per_wi>;
171+ using InnerKernelName = gaussian_kernel<EngineT, DataT, Method , items_per_wi>;
147172 using KernelName = disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
148173
149174 cgh.parallel_for <KernelName>(sycl::nd_range<1 >({global_size}, {local_size}),
150- details::RngContigFunctor<DataT, GaussianDistrT, DataT, DataT, vec_sz, items_per_wi, disable_sg_load>(seed , distr, out, n));
175+ details::RngContigFunctor<EngineDistrT, DataT, GaussianDistrT, items_per_wi, disable_sg_load>(eng , distr, out, n));
151176 }
152177 });
153178 } catch (oneapi::mkl::exception const &e) {
@@ -164,7 +189,7 @@ static sycl::event gaussian_impl(sycl::queue& exec_q,
164189 return distr_event;
165190}
166191
167- std::pair<sycl::event, sycl::event> gaussian (sycl::queue exec_q ,
192+ std::pair<sycl::event, sycl::event> gaussian (EngineBase *engine ,
168193 const std::uint8_t method_id,
169194 const std::uint32_t seed,
170195 const double mean,
@@ -173,6 +198,9 @@ std::pair<sycl::event, sycl::event> gaussian(sycl::queue exec_q,
173198 dpctl::tensor::usm_ndarray res,
174199 const std::vector<sycl::event> &depends)
175200{
201+ std::cout << engine->print () << std::endl;
202+ auto exec_q = engine->get_queue ();
203+
176204 const int res_nd = res.get_ndim ();
177205 const py::ssize_t *res_shape = res.get_shape_raw ();
178206
@@ -216,7 +244,7 @@ std::pair<sycl::event, sycl::event> gaussian(sycl::queue exec_q,
216244 }
217245
218246 char *res_data = res.get_data ();
219- sycl::event gaussian_ev = gaussian_fn (exec_q , seed, mean, stddev, n, res_data, depends);
247+ sycl::event gaussian_ev = gaussian_fn (engine , seed, mean, stddev, n, res_data, depends);
220248
221249 sycl::event ht_ev = dpctl::utils::keep_args_alive (exec_q, {res}, {gaussian_ev});
222250 return std::make_pair (ht_ev, gaussian_ev);
@@ -299,7 +327,7 @@ struct GaussianContigFactory
299327 fnT get ()
300328 {
301329 if constexpr (GaussianTypePairSupportFactory<T, M>::is_defined) {
302- return gaussian_impl<T, M>;
330+ return gaussian_impl<mkl_rng_dev::mrg32k3a< 8 >, T, M>;
303331 }
304332 else {
305333 return nullptr ;
0 commit comments