2323// THE POSSIBILITY OF SUCH DAMAGE.
2424// *****************************************************************************
2525
26+ #pragma once
27+
2628#include < pybind11/pybind11.h>
2729#include < pybind11/stl.h>
2830#include < sycl/sycl.hpp>
2931
3032#include " dpctl4pybind11.hpp"
31- #include " hamming_kernel.hpp"
3233#include " utils/output_validation.hpp"
3334#include " utils/type_dispatch.hpp"
35+ #include " utils/type_utils.hpp"
3436
3537namespace dpnp ::extensions::window
3638{
3739
3840namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
3941
40- static kernels::hamming_fn_ptr_t hamming_dispatch_table[dpctl_td_ns::num_types];
41-
4242namespace py = pybind11;
4343
44+ typedef sycl::event (*window_fn_ptr_t )(sycl::queue &,
45+ char *,
46+ const std::size_t ,
47+ const std::vector<sycl::event> &);
48+
49+ template <typename T, template <typename > class Functor >
50+ sycl::event window_impl (sycl::queue &q,
51+ char *result,
52+ const std::size_t nelems,
53+ const std::vector<sycl::event> &depends)
54+ {
55+ dpctl::tensor::type_utils::validate_type_for_device<T>(q);
56+
57+ T *res = reinterpret_cast <T *>(result);
58+
59+ sycl::event window_ev = q.submit ([&](sycl::handler &cgh) {
60+ cgh.depends_on (depends);
61+
62+ using WindowKernel = Functor<T>;
63+ cgh.parallel_for <WindowKernel>(sycl::range<1 >(nelems),
64+ WindowKernel (res, nelems));
65+ });
66+
67+ return window_ev;
68+ }
69+
70+ template <typename dispatchT>
4471std::pair<sycl::event, sycl::event>
45- py_hamming (sycl::queue &exec_q,
46- const dpctl::tensor::usm_ndarray &result,
47- const std::vector<sycl::event> &depends)
72+ py_window (sycl::queue &exec_q,
73+ const dpctl::tensor::usm_ndarray &result,
74+ const std::vector<sycl::event> &depends,
75+ const dispatchT &window_dispatch_vector)
4876{
4977 dpctl::tensor::validation::CheckWritable::throw_if_not_writable (result);
5078
@@ -71,52 +99,27 @@ std::pair<sycl::event, sycl::event>
7199 int result_typenum = result.get_typenum ();
72100 auto array_types = dpctl_td_ns::usm_ndarray_types ();
73101 int result_type_id = array_types.typenum_to_lookup_id (result_typenum);
74- auto fn = hamming_dispatch_table [result_type_id];
102+ auto fn = window_dispatch_vector [result_type_id];
75103
76104 if (fn == nullptr ) {
77105 throw std::runtime_error (" Type of given array is not supported" );
78106 }
79107
80108 char *result_typeless_ptr = result.get_data ();
81- sycl::event hamming_ev = fn (exec_q, result_typeless_ptr, nelems, depends);
109+ sycl::event window_ev = fn (exec_q, result_typeless_ptr, nelems, depends);
82110 sycl::event args_ev =
83- dpctl::utils::keep_args_alive (exec_q, {result}, {hamming_ev });
111+ dpctl::utils::keep_args_alive (exec_q, {result}, {window_ev });
84112
85- return std::make_pair (args_ev, hamming_ev );
113+ return std::make_pair (args_ev, window_ev );
86114}
87115
88- template <typename fnT, typename T>
89- struct HammingFactory
116+ template <template < typename fnT, typename T> typename factoryT >
117+ void init_window_dispatch_vectors ( window_fn_ptr_t window_dispatch_vector[])
90118{
91- fnT get ()
92- {
93- if constexpr (std::is_floating_point_v<T>) {
94- return kernels::hamming_impl<T>;
95- }
96- else {
97- return nullptr ;
98- }
99- }
100- };
101-
102- void init_hamming_dispatch_tables (void )
103- {
104- using kernels::hamming_fn_ptr_t ;
105-
106- dpctl_td_ns::DispatchVectorBuilder<hamming_fn_ptr_t , HammingFactory,
119+ dpctl_td_ns::DispatchVectorBuilder<window_fn_ptr_t , factoryT,
107120 dpctl_td_ns::num_types>
108121 contig;
109- contig.populate_dispatch_vector (hamming_dispatch_table);
110-
111- return ;
112- }
113-
114- void init_hamming (py::module_ m)
115- {
116- dpnp::extensions::window::init_hamming_dispatch_tables ();
117-
118- m.def (" _hamming" , &py_hamming, " Call hamming kernel" , py::arg (" sycl_queue" ),
119- py::arg (" result" ), py::arg (" depends" ) = py::list ());
122+ contig.populate_dispatch_vector (window_dispatch_vector);
120123
121124 return ;
122125}
0 commit comments