@@ -167,32 +167,10 @@ sycl::event not_equal_contig_impl(sycl::queue exec_q,
167167 py::ssize_t res_offset,
168168 const std::vector<sycl::event> &depends = {})
169169{
170- sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
171- cgh.depends_on (depends);
172-
173- size_t lws = 64 ;
174- constexpr unsigned int vec_sz = 4 ;
175- constexpr unsigned int n_vecs = 2 ;
176- const size_t n_groups =
177- ((nelems + lws * n_vecs * vec_sz - 1 ) / (lws * n_vecs * vec_sz));
178- const auto gws_range = sycl::range<1 >(n_groups * lws);
179- const auto lws_range = sycl::range<1 >(lws);
180-
181- using resTy = typename NotEqualOutputType<argTy1, argTy2>::value_type;
182-
183- const argTy1 *arg1_tp =
184- reinterpret_cast <const argTy1 *>(arg1_p) + arg1_offset;
185- const argTy2 *arg2_tp =
186- reinterpret_cast <const argTy2 *>(arg2_p) + arg2_offset;
187- resTy *res_tp = reinterpret_cast <resTy *>(res_p) + res_offset;
188-
189- cgh.parallel_for <
190- not_equal_contig_kernel<argTy1, argTy2, resTy, vec_sz, n_vecs>>(
191- sycl::nd_range<1 >(gws_range, lws_range),
192- NotEqualContigFunctor<argTy1, argTy2, resTy, vec_sz, n_vecs>(
193- arg1_tp, arg2_tp, res_tp, nelems));
194- });
195- return comp_ev;
170+ return elementwise_common::binary_contig_impl<
171+ argTy1, argTy2, NotEqualOutputType, NotEqualContigFunctor,
172+ not_equal_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p,
173+ arg2_offset, res_p, res_offset, depends);
196174}
197175
198176template <typename fnT, typename T1, typename T2> struct NotEqualContigFactory
@@ -215,7 +193,7 @@ template <typename fnT, typename T1, typename T2> struct NotEqualContigFactory
215193
216194template <typename fnT, typename T1, typename T2> struct NotEqualTypeMapFactory
217195{
218- /* ! @brief get typeid for output type of operator()= =(x, y), always bool */
196+ /*! @brief get typeid for output type of operator!=(x, y), always bool */
219197 std::enable_if_t <std::is_same<fnT, int >::value, int > get ()
220198 {
221199 using rT = typename NotEqualOutputType<T1, T2>::value_type;
@@ -241,28 +219,11 @@ not_equal_strided_impl(sycl::queue exec_q,
241219 const std::vector<sycl::event> &depends,
242220 const std::vector<sycl::event> &additional_depends)
243221{
244- sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
245- cgh.depends_on (depends);
246- cgh.depends_on (additional_depends);
247-
248- using resTy = typename NotEqualOutputType<argTy1, argTy2>::value_type;
249-
250- using IndexerT =
251- typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer;
252-
253- IndexerT indexer{nd, arg1_offset, arg2_offset, res_offset,
254- shape_and_strides};
255-
256- const argTy1 *arg1_tp = reinterpret_cast <const argTy1 *>(arg1_p);
257- const argTy2 *arg2_tp = reinterpret_cast <const argTy2 *>(arg2_p);
258- resTy *res_tp = reinterpret_cast <resTy *>(res_p);
259-
260- cgh.parallel_for <
261- not_equal_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
262- {nelems}, NotEqualStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
263- arg1_tp, arg2_tp, res_tp, indexer));
264- });
265- return comp_ev;
222+ return elementwise_common::binary_strided_impl<
223+ argTy1, argTy2, NotEqualOutputType, NotEqualStridedFunctor,
224+ not_equal_strided_strided_kernel>(
225+ exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
226+ arg2_offset, res_p, res_offset, depends, additional_depends);
266227}
267228
268229template <typename fnT, typename T1, typename T2> struct NotEqualStridedFactory
0 commit comments