3333// dpctl tensor headers
3434#include " kernels/alignment.hpp"
3535
36- using dpctl::tensor::kernels::alignment_utils::disabled_sg_loadstore_wrapper_krn;
37- using dpctl::tensor::kernels::alignment_utils::is_aligned;
38- using dpctl::tensor::kernels::alignment_utils::required_alignment;
39-
4036#include " common_impl.hpp"
4137#include " gaussian.hpp"
4238
4339#include " engine/engine_base.hpp"
4440#include " engine/engine_builder.hpp"
4541
46- // #include "dpnp_utils.hpp"
42+ #include " dispatch/matrix.hpp"
43+ #include " dispatch/table_builder.hpp"
44+
4745
4846namespace dpnp
4947{
@@ -55,26 +53,31 @@ namespace rng
5553{
5654namespace device
5755{
56+ namespace dpctl_krn_ns = dpctl::tensor::kernels::alignment_utils;
5857namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
5958namespace mkl_rng_dev = oneapi::mkl::rng::device;
6059namespace py = pybind11;
6160namespace type_utils = dpctl::tensor::type_utils;
6261
62+ using dpctl_krn_ns::disabled_sg_loadstore_wrapper_krn;
63+ using dpctl_krn_ns::is_aligned;
64+ using dpctl_krn_ns::required_alignment;
65+
6366constexpr int no_of_methods = 2 ; // number of methods of gaussian distribution
6467
6568template <typename DataT, typename Method>
66- struct GaussianDistr
69+ struct DistributorBuilder
6770{
6871private:
6972 const DataT mean_;
7073 const DataT stddev_;
7174
7275public:
73- using method_type = Method;
7476 using result_type = DataT;
77+ using method_type = Method;
7578 using distr_type = typename mkl_rng_dev::gaussian<DataT, Method>;
7679
77- GaussianDistr (const DataT mean, const DataT stddev)
80+ DistributorBuilder (const DataT mean, const DataT stddev)
7881 : mean_(mean), stddev_(stddev)
7982 {
8083 }
@@ -128,23 +131,23 @@ static sycl::event gaussian_impl(engine::EngineBase *engine,
128131 EngineBuilderT eng_builder (engine);
129132 eng_builder.print (); // TODO: remove
130133
131- using GaussianDistrT = GaussianDistr <DataT, Method>;
132- GaussianDistrT distr (mean, stddev);
134+ using DistributorBuilderT = DistributorBuilder <DataT, Method>;
135+ DistributorBuilderT dist_builder (mean, stddev);
133136
134137 if (is_aligned<required_alignment>(out_ptr)) {
135138 constexpr bool enable_sg_load = true ;
136139 using KernelName = gaussian_kernel<EngineT, DataT, Method, items_per_wi>;
137140
138141 cgh.parallel_for <KernelName>(sycl::nd_range<1 >({global_size}, {local_size}),
139- details::RngContigFunctor<EngineBuilderT, DataT, GaussianDistrT, items_per_wi, enable_sg_load>(eng_builder, distr , out, n));
142+ details::RngContigFunctor<EngineBuilderT, DistributorBuilderT, items_per_wi, enable_sg_load>(eng_builder, dist_builder , out, n));
140143 }
141144 else {
142145 constexpr bool disable_sg_load = false ;
143146 using InnerKernelName = gaussian_kernel<EngineT, DataT, Method, items_per_wi>;
144147 using KernelName = disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
145148
146149 cgh.parallel_for <KernelName>(sycl::nd_range<1 >({global_size}, {local_size}),
147- details::RngContigFunctor<EngineBuilderT, DataT, GaussianDistrT, items_per_wi, disable_sg_load>(eng_builder, distr , out, n));
150+ details::RngContigFunctor<EngineBuilderT, DistributorBuilderT, items_per_wi, disable_sg_load>(eng_builder, dist_builder , out, n));
148151 }
149152 });
150153 } catch (oneapi::mkl::exception const &e) {
@@ -225,97 +228,12 @@ std::pair<sycl::event, sycl::event> gaussian(engine::EngineBase *engine,
225228 return std::make_pair (ht_ev, gaussian_ev);
226229}
227230
228- template <typename funcPtrT,
229- template <typename fnT, typename E, typename T, typename M> typename factory,
230- int _no_of_engines,
231- int _no_of_types,
232- int _no_of_methods>
233- class Dispatch3DTableBuilder
234- {
235- private:
236- template <typename E, typename T>
237- const std::vector<funcPtrT> row_per_method () const
238- {
239- std::vector<funcPtrT> per_method = {
240- factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::by_default>{}.get (),
241- factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::box_muller2>{}.get (),
242- };
243- assert (per_method.size () == _no_of_methods);
244- return per_method;
245- }
246-
247- template <typename E>
248- auto table_per_type_and_method () const
249- {
250- std::vector<std::vector<funcPtrT>>
251- table_by_type = {row_per_method<E, bool >(),
252- row_per_method<E, int8_t >(),
253- row_per_method<E, uint8_t >(),
254- row_per_method<E, int16_t >(),
255- row_per_method<E, uint16_t >(),
256- row_per_method<E, int32_t >(),
257- row_per_method<E, uint32_t >(),
258- row_per_method<E, int64_t >(),
259- row_per_method<E, uint64_t >(),
260- row_per_method<E, sycl::half>(),
261- row_per_method<E, float >(),
262- row_per_method<E, double >(),
263- row_per_method<E, std::complex <float >>(),
264- row_per_method<E, std::complex <double >>()};
265- assert (table_by_type.size () == _no_of_types);
266- return table_by_type;
267- }
268-
269- public:
270- Dispatch3DTableBuilder () = default ;
271- ~Dispatch3DTableBuilder () = default ;
272-
273- void populate (funcPtrT table[][_no_of_types][_no_of_methods]) const
274- {
275- const auto map_by_engine = {table_per_type_and_method<mkl_rng_dev::mrg32k3a<8 >>()};
276- assert (map_by_engine.size () == _no_of_engines);
277-
278- std::uint16_t engine_id = 0 ;
279- for (auto &table_by_type : map_by_engine) {
280- std::uint16_t type_id = 0 ;
281- for (auto &row_by_method : table_by_type) {
282- std::uint16_t method_id = 0 ;
283- for (auto &fn_ptr : row_by_method) {
284- table[engine_id][type_id][method_id] = fn_ptr;
285- ++method_id;
286- }
287- ++type_id;
288- }
289- ++engine_id;
290- }
291- }
292- };
293-
294- template <typename Ty, typename ArgTy, typename Method, typename argMethod>
295- struct TypePairDefinedEntry : std::bool_constant<std::is_same_v<Ty, ArgTy> &&
296- std::is_same_v<Method, argMethod>>
297- {
298- static constexpr bool is_defined = true ;
299- };
300-
301- template <typename T, typename M>
302- struct GaussianTypePairSupportFactory
303- {
304- static constexpr bool is_defined = std::disjunction<
305- TypePairDefinedEntry<T, double , M, mkl_rng_dev::gaussian_method::by_default>,
306- TypePairDefinedEntry<T, double , M, mkl_rng_dev::gaussian_method::box_muller2>,
307- TypePairDefinedEntry<T, float , M, mkl_rng_dev::gaussian_method::by_default>,
308- TypePairDefinedEntry<T, float , M, mkl_rng_dev::gaussian_method::box_muller2>,
309- // fall-through
310- dpctl_td_ns::NotDefinedEntry>::is_defined;
311- };
312-
313231template <typename fnT, typename E, typename T, typename M>
314232struct GaussianContigFactory
315233{
316234 fnT get ()
317235 {
318- if constexpr (GaussianTypePairSupportFactory<T, M>::is_defined) {
236+ if constexpr (dispatch:: GaussianTypePairSupportFactory<T, M>::is_defined) {
319237 return gaussian_impl<E, T, M>;
320238 }
321239 else {
@@ -326,7 +244,7 @@ struct GaussianContigFactory
326244
327245void init_gaussian_dispatch_table (void )
328246{
329- Dispatch3DTableBuilder<gaussian_impl_fn_ptr_t , GaussianContigFactory, engine::no_of_engines, dpctl_td_ns::num_types, no_of_methods> contig;
247+ dispatch:: Dispatch3DTableBuilder<gaussian_impl_fn_ptr_t , GaussianContigFactory, engine::no_of_engines, dpctl_td_ns::num_types, no_of_methods> contig;
330248 contig.populate (gaussian_dispatch_table);
331249}
332250} // namespace device
0 commit comments