 //
 // Data Parallel Control (dpctl)
 //
-// Copyright 2020-2022 Intel Corporation
+// Copyright 2020-2023 Intel Corporation
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -57,31 +57,24 @@ class WhereContigFunctor
 {
 private:
     size_t nelems = 0;
-    const char *x1_cp = nullptr;
-    const char *x2_cp = nullptr;
-    char *dst_cp = nullptr;
-    const char *cond_cp = nullptr;
+    const condT *cond_p = nullptr;
+    const T *x1_p = nullptr;
+    const T *x2_p = nullptr;
+    T *dst_p = nullptr;

 public:
     WhereContigFunctor(size_t nelems_,
-                       const char *cond_data_p,
-                       const char *x1_data_p,
-                       const char *x2_data_p,
-                       char *dst_data_p)
-        : nelems(nelems_), x1_cp(x1_data_p), x2_cp(x2_data_p),
-          dst_cp(dst_data_p), cond_cp(cond_data_p)
+                       const condT *cond_p_,
+                       const T *x1_p_,
+                       const T *x2_p_,
+                       T *dst_p_)
+        : nelems(nelems_), cond_p(cond_p_), x1_p(x1_p_), x2_p(x2_p_),
+          dst_p(dst_p_)
     {
     }

     void operator()(sycl::nd_item<1> ndit) const
     {
-        const T *x1_data = reinterpret_cast<const T *>(x1_cp);
-        const T *x2_data = reinterpret_cast<const T *>(x2_cp);
-        T *dst_data = reinterpret_cast<T *>(dst_cp);
-        const condT *cond_data = reinterpret_cast<const condT *>(cond_cp);
-
-        using dpctl::tensor::type_utils::convert_impl;
-
         using dpctl::tensor::type_utils::is_complex;
         if constexpr (is_complex<condT>::value || is_complex<T>::value) {
             std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
@@ -92,8 +85,9 @@ class WhereContigFunctor
                  offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
                  offset += sgSize)
             {
-                bool check = convert_impl<bool, condT>(cond_data[offset]);
-                dst_data[offset] = check ? x1_data[offset] : x2_data[offset];
+                using dpctl::tensor::type_utils::convert_impl;
+                bool check = convert_impl<bool, condT>(cond_p[offset]);
+                dst_p[offset] = check ? x1_p[offset] : x2_p[offset];
             }
         }
         else {
@@ -115,7 +109,6 @@ class WhereContigFunctor
             using cond_ptrT =
                 sycl::multi_ptr<const condT,
                                 sycl::access::address_space::global_space>;
-
             sycl::vec<T, vec_sz> dst_vec;
             sycl::vec<T, vec_sz> x1_vec;
             sycl::vec<T, vec_sz> x2_vec;
@@ -124,23 +117,20 @@ class WhereContigFunctor
 #pragma unroll
                 for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
                     auto idx = base + it * sgSize;
-                    x1_vec = sg.load<vec_sz>(x_ptrT(&x1_data[idx]));
-                    x2_vec = sg.load<vec_sz>(x_ptrT(&x2_data[idx]));
-                    cond_vec = sg.load<vec_sz>(cond_ptrT(&cond_data[idx]));
-
+                    x1_vec = sg.load<vec_sz>(x_ptrT(&x1_p[idx]));
+                    x2_vec = sg.load<vec_sz>(x_ptrT(&x2_p[idx]));
+                    cond_vec = sg.load<vec_sz>(cond_ptrT(&cond_p[idx]));
 #pragma unroll
                     for (std::uint8_t k = 0; k < vec_sz; ++k) {
-                        bool check = convert_impl<bool, condT>(cond_vec[k]);
-                        dst_vec[k] = check ? x1_vec[k] : x2_vec[k];
+                        dst_vec[k] = cond_vec[k] ? x1_vec[k] : x2_vec[k];
                     }
-                    sg.store<vec_sz>(dst_ptrT(&dst_data[idx]), dst_vec);
+                    sg.store<vec_sz>(dst_ptrT(&dst_p[idx]), dst_vec);
                 }
             }
             else {
                 for (size_t k = base + sg.get_local_id()[0]; k < nelems;
                      k += sgSize) {
-                    bool check = convert_impl<bool, condT>(cond_data[k]);
-                    dst_data[k] = check ? x1_data[k] : x2_data[k];
+                    dst_p[k] = cond_p[k] ? x1_p[k] : x2_p[k];
                 }
             }
         }
@@ -159,12 +149,17 @@ typedef sycl::event (*where_contig_impl_fn_ptr_t)(
 template <typename T, typename condT>
 sycl::event where_contig_impl(sycl::queue q,
                               size_t nelems,
-                              const char *cond_p,
-                              const char *x1_p,
-                              const char *x2_p,
-                              char *dst_p,
+                              const char *cond_cp,
+                              const char *x1_cp,
+                              const char *x2_cp,
+                              char *dst_cp,
                               const std::vector<sycl::event> &depends)
 {
+    const condT *cond_tp = reinterpret_cast<const condT *>(cond_cp);
+    const T *x1_tp = reinterpret_cast<const T *>(x1_cp);
+    const T *x2_tp = reinterpret_cast<const T *>(x2_cp);
+    T *dst_tp = reinterpret_cast<T *>(dst_cp);
+
     sycl::event where_ev = q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);

@@ -178,8 +173,8 @@ sycl::event where_contig_impl(sycl::queue q,

         cgh.parallel_for<where_contig_kernel<T, condT, vec_sz, n_vecs>>(
             sycl::nd_range<1>(gws_range, lws_range),
-            WhereContigFunctor<T, condT, vec_sz, n_vecs>(nelems, cond_p, x1_p,
-                                                         x2_p, dst_p));
+            WhereContigFunctor<T, condT, vec_sz, n_vecs>(nelems, cond_tp, x1_tp,
+                                                         x2_tp, dst_tp));
     });

     return where_ev;
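
The hunk above moves the reinterpret_cast calls out of the device functor and into the type-erased host entry point, so the functor now holds typed pointers and the casts happen exactly once per launch. Below is a minimal stand-alone sketch of that pattern, not part of the commit: the names SelectFunctor and select_contig_impl are invented for illustration, and a serial loop stands in for the SYCL kernel submission.

// Stand-alone illustration (plain C++); names are hypothetical, not dpctl's.
#include <cstddef>
#include <iostream>

template <typename T, typename condT> struct SelectFunctor
{
    // The functor only ever sees typed pointers; no casting on the device side.
    const condT *cond_p;
    const T *x1_p;
    const T *x2_p;
    T *dst_p;

    void operator()(std::size_t i) const
    {
        dst_p[i] = cond_p[i] ? x1_p[i] : x2_p[i];
    }
};

// Type-erased entry point: reinterpret_cast is confined to this boundary,
// mirroring where_contig_impl, which takes const char * / char * arguments.
template <typename T, typename condT>
void select_contig_impl(std::size_t nelems,
                        const char *cond_cp,
                        const char *x1_cp,
                        const char *x2_cp,
                        char *dst_cp)
{
    const condT *cond_tp = reinterpret_cast<const condT *>(cond_cp);
    const T *x1_tp = reinterpret_cast<const T *>(x1_cp);
    const T *x2_tp = reinterpret_cast<const T *>(x2_cp);
    T *dst_tp = reinterpret_cast<T *>(dst_cp);

    SelectFunctor<T, condT> fn{cond_tp, x1_tp, x2_tp, dst_tp};
    for (std::size_t i = 0; i < nelems; ++i) {
        fn(i); // in the real code this is the body of the parallel_for
    }
}

int main()
{
    bool cond[] = {true, false, true};
    float x1[] = {1.0f, 2.0f, 3.0f};
    float x2[] = {10.0f, 20.0f, 30.0f};
    float dst[3] = {};

    select_contig_impl<float, bool>(3, reinterpret_cast<const char *>(cond),
                                    reinterpret_cast<const char *>(x1),
                                    reinterpret_cast<const char *>(x2),
                                    reinterpret_cast<char *>(dst));

    for (float v : dst) {
        std::cout << v << ' '; // expected output: 1 20 3
    }
    std::cout << '\n';
}

Keeping the casts at the char * boundary means the per-work-item code never handles untyped storage, which is the shape both the contiguous and strided functors take after this commit.
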
@@ -189,39 +184,34 @@ template <typename T, typename condT, typename IndexerT>
 class WhereStridedFunctor
 {
 private:
-    const char *x1_cp = nullptr;
-    const char *x2_cp = nullptr;
-    char *dst_cp = nullptr;
-    const char *cond_cp = nullptr;
+    const T *x1_p = nullptr;
+    const T *x2_p = nullptr;
+    T *dst_p = nullptr;
+    const condT *cond_p = nullptr;
     IndexerT indexer;

 public:
-    WhereStridedFunctor(const char *cond_data_p,
-                        const char *x1_data_p,
-                        const char *x2_data_p,
-                        char *dst_data_p,
+    WhereStridedFunctor(const condT *cond_p_,
+                        const T *x1_p_,
+                        const T *x2_p_,
+                        T *dst_p_,
                         IndexerT indexer_)
-        : x1_cp(x1_data_p), x2_cp(x2_data_p), dst_cp(dst_data_p),
-          cond_cp(cond_data_p), indexer(indexer_)
+        : x1_p(x1_p_), x2_p(x2_p_), dst_p(dst_p_), cond_p(cond_p_),
+          indexer(indexer_)
     {
     }

     void operator()(sycl::id<1> id) const
     {
-        const T *x1_data = reinterpret_cast<const T *>(x1_cp);
-        const T *x2_data = reinterpret_cast<const T *>(x2_cp);
-        T *dst_data = reinterpret_cast<T *>(dst_cp);
-        const condT *cond_data = reinterpret_cast<const condT *>(cond_cp);
-
         size_t gid = id[0];
         auto offsets = indexer(static_cast<py::ssize_t>(gid));

         using dpctl::tensor::type_utils::convert_impl;
         bool check =
-            convert_impl<bool, condT>(cond_data[offsets.get_first_offset()]);
+            convert_impl<bool, condT>(cond_p[offsets.get_first_offset()]);

-        dst_data[gid] = check ? x1_data[offsets.get_second_offset()]
-                              : x2_data[offsets.get_third_offset()];
+        dst_p[gid] = check ? x1_p[offsets.get_second_offset()]
+                           : x2_p[offsets.get_third_offset()];
     }
 };

@@ -243,16 +233,21 @@ template <typename T, typename condT>
 sycl::event where_strided_impl(sycl::queue q,
                                size_t nelems,
                                int nd,
-                               const char *cond_p,
-                               const char *x1_p,
-                               const char *x2_p,
-                               char *dst_p,
+                               const char *cond_cp,
+                               const char *x1_cp,
+                               const char *x2_cp,
+                               char *dst_cp,
                                const py::ssize_t *shape_strides,
                                py::ssize_t x1_offset,
                                py::ssize_t x2_offset,
                                py::ssize_t cond_offset,
                                const std::vector<sycl::event> &depends)
 {
+    const condT *cond_tp = reinterpret_cast<const condT *>(cond_cp);
+    const T *x1_tp = reinterpret_cast<const T *>(x1_cp);
+    const T *x2_tp = reinterpret_cast<const T *>(x2_cp);
+    T *dst_tp = reinterpret_cast<T *>(dst_cp);
+
     sycl::event where_ev = q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);

@@ -263,7 +258,7 @@ sycl::event where_strided_impl(sycl::queue q,
             where_strided_kernel<T, condT, ThreeOffsets_StridedIndexer>>(
             sycl::range<1>(nelems),
             WhereStridedFunctor<T, condT, ThreeOffsets_StridedIndexer>(
-                cond_p, x1_p, x2_p, dst_p, indexer));
+                cond_tp, x1_tp, x2_tp, dst_tp, indexer));
     });

     return where_ev;
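
Note that convert_impl<bool, condT> is kept in the complex branch of WhereContigFunctor and in WhereStridedFunctor, while the non-complex contiguous paths now rely on the condition value converting to bool directly inside the ternary expression. std::complex has no implicit conversion to bool, so an explicit test is required wherever condT may be complex. The snippet below is a rough stand-alone approximation of such a helper, not dpctl's actual convert_impl (which lives in dpctl::tensor::type_utils and may differ in detail).

// Hypothetical stand-in for a bool-conversion helper covering complex inputs.
#include <complex>
#include <iostream>
#include <type_traits>

template <typename T> struct is_complex_v2 : std::false_type
{
};
template <typename T> struct is_complex_v2<std::complex<T>> : std::true_type
{
};

template <typename dstT, typename srcT> dstT convert_value(const srcT &v)
{
    if constexpr (std::is_same_v<dstT, bool> && is_complex_v2<srcT>::value) {
        // a complex value is treated as true when either component is non-zero
        return v.real() != 0 || v.imag() != 0;
    }
    else {
        return static_cast<dstT>(v);
    }
}

int main()
{
    std::cout << std::boolalpha
              << convert_value<bool>(std::complex<float>{0.0f, 2.0f}) // true
              << ' ' << convert_value<bool>(0) << '\n';               // false
}
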