@@ -57,6 +57,40 @@ namespace mkl_rng_dev = oneapi::mkl::rng::device;
5757namespace py = pybind11;
5858namespace type_utils = dpctl::tensor::type_utils;
5959
60+ constexpr int num_methods = 2 ; // number of methods of gaussian distribution
61+
62+ // static mkl_rng_dev::gaussian_method get_method(const std::int8_t method) {
63+ // switch (method) {
64+ // case 0: return mkl_rng_dev::gaussian_method::by_default;
65+ // case 1: return mkl_rng_dev::gaussian_method::by_default;
66+ // default:
67+ // throw py::value_error();
68+ // }
69+ // }
70+
71+ template <typename DataT, typename Method>
72+ struct GaussianDistr
73+ {
74+ private:
75+ const DataT mean_;
76+ const DataT stddev_;
77+
78+ public:
79+ using method_type = Method;
80+ using result_type = DataT;
81+ using distr_type = typename mkl_rng_dev::gaussian<DataT, Method>;
82+
83+ GaussianDistr (const DataT mean, const DataT stddev)
84+ : mean_(mean), stddev_(stddev)
85+ {
86+ }
87+
88+ inline auto operator ()(void ) const
89+ {
90+ return distr_type (mean_, stddev_);
91+ }
92+ };
93+
6094typedef sycl::event (*gaussian_impl_fn_ptr_t )(sycl::queue &,
6195 const std::uint32_t ,
6296 const double ,
@@ -65,13 +99,12 @@ typedef sycl::event (*gaussian_impl_fn_ptr_t)(sycl::queue &,
6599 char *,
66100 const std::vector<sycl::event> &);
67101
68- static gaussian_impl_fn_ptr_t gaussian_dispatch_vector [dpctl_td_ns::num_types];
102+ static gaussian_impl_fn_ptr_t gaussian_dispatch_table [dpctl_td_ns::num_types][num_methods ];
69103
70- // template <typename DataT, typename Method = mkl_rng_dev::gaussian_method::by_default>
71- template <typename DataT, unsigned int vec_sz, unsigned int items_per_wi>
104+ template <typename DataT, typename Method, unsigned int vec_sz, unsigned int items_per_wi>
72105class gaussian_kernel ;
73106
74- template <typename DataT, typename Method = mkl_rng_dev::gaussian_method::by_default >
107+ template <typename DataT, typename Method>
75108static sycl::event gaussian_impl (sycl::queue& exec_q,
76109 const std::uint32_t seed,
77110 const double mean_val,
@@ -98,20 +131,23 @@ static sycl::event gaussian_impl(sycl::queue& exec_q,
98131 distr_event = exec_q.submit ([&](sycl::handler &cgh) {
99132 cgh.depends_on (depends);
100133
134+ using GaussianDistrT = GaussianDistr<DataT, Method>;
135+ GaussianDistrT distr (mean, stddev);
136+
101137 if (is_aligned<required_alignment>(out_ptr)) {
102138 constexpr bool enable_sg_load = true ;
103- using KernelName = gaussian_kernel<DataT, vec_sz, items_per_wi>;
139+ using KernelName = gaussian_kernel<DataT, Method, vec_sz, items_per_wi>;
104140
105141 cgh.parallel_for <KernelName>(sycl::nd_range<1 >({global_size}, {local_size}),
106- details::RngContigFunctor<DataT, DataT, Method , DataT, vec_sz, items_per_wi, enable_sg_load>(seed, mean, stddev , out, n));
142+ details::RngContigFunctor<DataT, GaussianDistrT, DataT , DataT, vec_sz, items_per_wi, enable_sg_load>(seed, distr , out, n));
107143 }
108144 else {
109145 constexpr bool disable_sg_load = false ;
110- using InnerKernelName = gaussian_kernel<DataT, vec_sz, items_per_wi>;
146+ using InnerKernelName = gaussian_kernel<DataT, Method, vec_sz, items_per_wi>;
111147 using KernelName = disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
112148
113149 cgh.parallel_for <KernelName>(sycl::nd_range<1 >({global_size}, {local_size}),
114- details::RngContigFunctor<DataT, DataT, Method , DataT, vec_sz, items_per_wi, disable_sg_load>(seed, mean, stddev , out, n));
150+ details::RngContigFunctor<DataT, GaussianDistrT, DataT , DataT, vec_sz, items_per_wi, disable_sg_load>(seed, distr , out, n));
115151 }
116152 });
117153 } catch (oneapi::mkl::exception const &e) {
@@ -129,6 +165,7 @@ static sycl::event gaussian_impl(sycl::queue& exec_q,
129165}
130166
131167std::pair<sycl::event, sycl::event> gaussian (sycl::queue exec_q,
168+ const std::uint8_t method_id,
132169 const std::uint32_t seed,
133170 const double mean,
134171 const double stddev,
@@ -166,10 +203,14 @@ std::pair<sycl::event, sycl::event> gaussian(sycl::queue exec_q,
166203 throw std::runtime_error (" Only population of contiguous array is supported." );
167204 }
168205
206+ if (method_id >= num_methods) {
207+ throw std::runtime_error (" Unknown method=" + std::to_string (method_id) + " for gaussian distribution." );
208+ }
209+
169210 auto array_types = dpctl_td_ns::usm_ndarray_types ();
170211 int res_type_id = array_types.typenum_to_lookup_id (res.get_typenum ());
171212
172- auto gaussian_fn = gaussian_dispatch_vector [res_type_id];
213+ auto gaussian_fn = gaussian_dispatch_table [res_type_id][method_id ];
173214 if (gaussian_fn == nullptr ) {
174215 throw py::value_error (" No gaussian implementation defined for a required type" );
175216 }
@@ -181,36 +222,95 @@ std::pair<sycl::event, sycl::event> gaussian(sycl::queue exec_q,
181222 return std::make_pair (ht_ev, gaussian_ev);
182223}
183224
184- template <typename T>
225+ template <typename funcPtrT,
226+ template <typename fnT, typename D, typename S> typename factory,
227+ int _num_types,
228+ int _num_methods>
229+ // class DispatchTableBuilder : public dpctl_td_ns::DispatchTableBuilder<funcPtrT, factory, _num_types>
230+ class DispatchTableBuilder /* : public dpctl_td_ns::DispatchTableBuilder<funcPtrT, factory, _num_types>*/
231+ {
232+ private:
233+ template <typename dstTy>
234+ const std::vector<funcPtrT> row_per_method () const
235+ {
236+ std::vector<funcPtrT> per_method = {
237+ factory<funcPtrT, dstTy, mkl_rng_dev::gaussian_method::by_default>{}.get (),
238+ factory<funcPtrT, dstTy, mkl_rng_dev::gaussian_method::box_muller2>{}.get (),
239+ };
240+ assert (per_method.size () == _num_methods);
241+ return per_method;
242+ }
243+
244+ public:
245+ DispatchTableBuilder () = default ;
246+ ~DispatchTableBuilder () = default ;
247+
248+ void populate (funcPtrT table[][_num_methods]) const
249+ {
250+ const auto map_by_dst_type = {row_per_method<bool >(),
251+ row_per_method<int8_t >(),
252+ row_per_method<uint8_t >(),
253+ row_per_method<int16_t >(),
254+ row_per_method<uint16_t >(),
255+ row_per_method<int32_t >(),
256+ row_per_method<uint32_t >(),
257+ row_per_method<int64_t >(),
258+ row_per_method<uint64_t >(),
259+ row_per_method<sycl::half>(),
260+ row_per_method<float >(),
261+ row_per_method<double >(),
262+ row_per_method<std::complex <float >>(),
263+ row_per_method<std::complex <double >>()};
264+ assert (map_by_dst_type.size () == _num_types);
265+ int dst_id = 0 ;
266+ for (auto &row : map_by_dst_type) {
267+ int src_id = 0 ;
268+ for (auto &fn_ptr : row) {
269+ table[dst_id][src_id] = fn_ptr;
270+ ++src_id;
271+ }
272+ ++dst_id;
273+ }
274+ }
275+ };
276+
/**
 * @brief Entry for a `std::disjunction` based support lookup.
 *
 * The inherited `value` is true only when `Ty` is the same type as `ArgTy`
 * and `Method` is the same type as `argMethod`. `is_defined` is always true,
 * independent of `value`.
 */
template <typename Ty, typename ArgTy, typename Method, typename argMethod>
struct TypePairDefinedEntry
    : std::integral_constant<bool,
                             std::is_same<Ty, ArgTy>::value &&
                                 std::is_same<Method, argMethod>::value>
{
    static constexpr bool is_defined = true;
};
283+
284+ template <typename T, typename M>
185285struct GaussianTypePairSupportFactory
186286{
187287 static constexpr bool is_defined = std::disjunction<
188- dpctl_td_ns::TypePairDefinedEntry<T, double , T, double >,
189- dpctl_td_ns::TypePairDefinedEntry<T, float , T, float >,
288+ TypePairDefinedEntry<T, double , M, mkl_rng_dev::gaussian_method::by_default>,
289+ TypePairDefinedEntry<T, double , M, mkl_rng_dev::gaussian_method::box_muller2>,
290+ TypePairDefinedEntry<T, float , M, mkl_rng_dev::gaussian_method::by_default>,
291+ TypePairDefinedEntry<T, float , M, mkl_rng_dev::gaussian_method::box_muller2>,
190292 // fall-through
191293 dpctl_td_ns::NotDefinedEntry>::is_defined;
192294};
193295
194- template <typename fnT, typename T>
296+ template <typename fnT, typename T, typename M >
195297struct GaussianContigFactory
196298{
197299 fnT get ()
198300 {
199- if constexpr (GaussianTypePairSupportFactory<T>::is_defined) {
200- return gaussian_impl<T>;
301+ if constexpr (GaussianTypePairSupportFactory<T, M >::is_defined) {
302+ return gaussian_impl<T, M >;
201303 }
202304 else {
203305 return nullptr ;
204306 }
205307 }
206308};
207309
208- void init_gaussian_dispatch_vector (void )
310+ void init_gaussian_dispatch_table (void )
209311{
210- dpctl_td_ns::DispatchVectorBuilder<gaussian_impl_fn_ptr_t , GaussianContigFactory,
211- dpctl_td_ns::num_types>
212- contig;
213- contig.populate_dispatch_vector (gaussian_dispatch_vector);
312+ DispatchTableBuilder<gaussian_impl_fn_ptr_t , GaussianContigFactory, dpctl_td_ns::num_types, num_methods> contig;
313+ contig.populate (gaussian_dispatch_table);
214314}
215315} // namespace device
216316} // namespace rng
0 commit comments