@@ -184,32 +184,10 @@ sycl::event add_contig_impl(sycl::queue exec_q,
                             py::ssize_t res_offset,
                             const std::vector<sycl::event> &depends = {})
 {
-    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on(depends);
-
-        size_t lws = 64;
-        constexpr unsigned int vec_sz = 4;
-        constexpr unsigned int n_vecs = 2;
-        const size_t n_groups =
-            ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
-        const auto gws_range = sycl::range<1>(n_groups * lws);
-        const auto lws_range = sycl::range<1>(lws);
-
-        using resTy = typename AddOutputType<argTy1, argTy2>::value_type;
-
-        const argTy1 *arg1_tp =
-            reinterpret_cast<const argTy1 *>(arg1_p) + arg1_offset;
-        const argTy2 *arg2_tp =
-            reinterpret_cast<const argTy2 *>(arg2_p) + arg2_offset;
-        resTy *res_tp = reinterpret_cast<resTy *>(res_p) + res_offset;
-
-        cgh.parallel_for<
-            add_contig_kernel<argTy1, argTy2, resTy, vec_sz, n_vecs>>(
-            sycl::nd_range<1>(gws_range, lws_range),
-            AddContigFunctor<argTy1, argTy2, resTy, vec_sz, n_vecs>(
-                arg1_tp, arg2_tp, res_tp, nelems));
-    });
-    return comp_ev;
+    return elementwise_common::binary_contig_impl<
+        argTy1, argTy2, AddOutputType, AddContigFunctor, add_contig_kernel>(
+        exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p,
+        res_offset, depends);
 }
 
 template <typename fnT, typename T1, typename T2> struct AddContigFactory
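
A note for readers: elementwise_common::binary_contig_impl is defined in the shared elementwise-functions common header and its body is not part of this diff. As a reading aid, here is a minimal sketch of such a helper, assuming it is a direct templatization of the code deleted above; the template-parameter layout and names are illustrative, not the actual dpctl declaration.

// Sketch only: assumes the includes already present in this header
// (sycl, pybind11) and the vec_sz/n_vecs values used by the deleted code.
namespace elementwise_common
{

template <typename argTy1,
          typename argTy2,
          template <typename T1, typename T2> class OutputType,
          template <typename T1,
                    typename T2,
                    typename T3,
                    unsigned int vs,
                    unsigned int nv>
          class BinaryContigFunctorT,
          template <typename T1,
                    typename T2,
                    typename T3,
                    unsigned int vs,
                    unsigned int nv>
          class kernel_name,
          unsigned int vec_sz = 4,
          unsigned int n_vecs = 2>
sycl::event binary_contig_impl(sycl::queue exec_q,
                               size_t nelems,
                               const char *arg1_p,
                               py::ssize_t arg1_offset,
                               const char *arg2_p,
                               py::ssize_t arg2_offset,
                               char *res_p,
                               py::ssize_t res_offset,
                               const std::vector<sycl::event> &depends = {})
{
    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

        // each work-group of lws items processes lws * n_vecs * vec_sz
        // elements, so round the group count up accordingly
        size_t lws = 64;
        const size_t n_groups =
            (nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz);

        using resTy = typename OutputType<argTy1, argTy2>::value_type;

        const argTy1 *arg1_tp =
            reinterpret_cast<const argTy1 *>(arg1_p) + arg1_offset;
        const argTy2 *arg2_tp =
            reinterpret_cast<const argTy2 *>(arg2_p) + arg2_offset;
        resTy *res_tp = reinterpret_cast<resTy *>(res_p) + res_offset;

        cgh.parallel_for<kernel_name<argTy1, argTy2, resTy, vec_sz, n_vecs>>(
            sycl::nd_range<1>(sycl::range<1>(n_groups * lws),
                              sycl::range<1>(lws)),
            BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz, n_vecs>(
                arg1_tp, arg2_tp, res_tp, nelems));
    });
    return comp_ev;
}

} // namespace elementwise_common

Under that assumption, each binary operation now supplies only its output-type map (AddOutputType), its functor (AddContigFunctor), and its kernel-name tag (add_contig_kernel), while the work-group sizing and pointer arithmetic are maintained in a single shared launch path.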
@@ -256,28 +234,11 @@ sycl::event add_strided_impl(sycl::queue exec_q,
                              const std::vector<sycl::event> &depends,
                              const std::vector<sycl::event> &additional_depends)
 {
-    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on(depends);
-        cgh.depends_on(additional_depends);
-
-        using resTy = typename AddOutputType<argTy1, argTy2>::value_type;
-
-        using IndexerT =
-            typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer;
-
-        IndexerT indexer{nd, arg1_offset, arg2_offset, res_offset,
-                         shape_and_strides};
-
-        const argTy1 *arg1_tp = reinterpret_cast<const argTy1 *>(arg1_p);
-        const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
-        resTy *res_tp = reinterpret_cast<resTy *>(res_p);
-
-        cgh.parallel_for<
-            add_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
-            {nelems}, AddStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
-                          arg1_tp, arg2_tp, res_tp, indexer));
-    });
-    return comp_ev;
+    return elementwise_common::binary_strided_impl<
+        argTy1, argTy2, AddOutputType, AddStridedFunctor,
+        add_strided_strided_kernel>(
+        exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
+        arg2_offset, res_p, res_offset, depends, additional_depends);
 }
 
 template <typename fnT, typename T1, typename T2> struct AddStridedFactory
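
Likewise, binary_strided_impl presumably lifts the op-specific pieces of the deleted strided body into template parameters while keeping the ThreeOffsets_StridedIndexer logic intact. A hedged sketch under that assumption (names illustrative):

// Sketch only, not the actual dpctl declaration.
namespace elementwise_common
{

template <typename argTy1,
          typename argTy2,
          template <typename T1, typename T2> class OutputType,
          template <typename T1, typename T2, typename T3, typename IndT>
          class BinaryStridedFunctorT,
          template <typename T1, typename T2, typename T3, typename IndT>
          class kernel_name>
sycl::event binary_strided_impl(
    sycl::queue exec_q,
    size_t nelems,
    int nd,
    const py::ssize_t *shape_and_strides,
    const char *arg1_p,
    py::ssize_t arg1_offset,
    const char *arg2_p,
    py::ssize_t arg2_offset,
    char *res_p,
    py::ssize_t res_offset,
    const std::vector<sycl::event> &depends,
    const std::vector<sycl::event> &additional_depends)
{
    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);
        cgh.depends_on(additional_depends);

        using resTy = typename OutputType<argTy1, argTy2>::value_type;
        using IndexerT =
            typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer;

        // maps a flat element index to offsets into the three strided arrays
        IndexerT indexer{nd, arg1_offset, arg2_offset, res_offset,
                         shape_and_strides};

        const argTy1 *arg1_tp = reinterpret_cast<const argTy1 *>(arg1_p);
        const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
        resTy *res_tp = reinterpret_cast<resTy *>(res_p);

        cgh.parallel_for<kernel_name<argTy1, argTy2, resTy, IndexerT>>(
            {nelems}, BinaryStridedFunctorT<argTy1, argTy2, resTy, IndexerT>(
                          arg1_tp, arg2_tp, res_tp, indexer));
    });
    return comp_ev;
}

} // namespace elementwise_common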
@@ -322,62 +283,11 @@ sycl::event add_contig_matrix_contig_row_broadcast_impl(
     py::ssize_t res_offset,
     const std::vector<sycl::event> &depends = {})
 {
-    const argT1 *mat = reinterpret_cast<const argT1 *>(mat_p) + mat_offset;
-    const argT2 *vec = reinterpret_cast<const argT2 *>(vec_p) + vec_offset;
-    resT *res = reinterpret_cast<resT *>(res_p) + res_offset;
-
-    const auto &dev = exec_q.get_device();
-    const auto &sg_sizes = dev.get_info<sycl::info::device::sub_group_sizes>();
-    // Get device-specific kernel info max_sub_group_size
-    size_t max_sgSize =
-        *(std::max_element(std::begin(sg_sizes), std::end(sg_sizes)));
-
-    size_t n1_padded = n1 + max_sgSize;
-    argT2 *padded_vec = sycl::malloc_device<argT2>(n1_padded, exec_q);
-
-    if (padded_vec == nullptr) {
-        throw std::runtime_error("Could not allocate memory on the device");
-    }
-    sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on(depends); // ensure vec contains actual data
-        cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) {
-            auto i = id[0];
-            padded_vec[i] = vec[i % n1];
-        });
-    });
-
-    // sub-group spans work-items [I, I + sgSize)
-    // base = ndit.get_global_linear_id() - sg.get_local_id()[0]
-    // Generically, sg.load(&mat[base]) may load arrays from
-    // different rows of mat. The start corresponds to row (base / n0)
-    // We read sg.load(&padded_vec[(base / n0)]). The vector is padded to
-    // ensure that reads are accessible
-
-    size_t lws = 64;
-
-    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on(make_padded_vec_ev);
-
-        auto lwsRange = sycl::range<1>(lws);
-        size_t n_elems = n0 * n1;
-        size_t n_groups = (n_elems + lws - 1) / lws;
-        auto gwsRange = sycl::range<1>(n_groups * lws);
-
-        cgh.parallel_for<
-            class add_matrix_row_broadcast_sg_krn<argT1, argT2, resT>>(
-            sycl::nd_range<1>(gwsRange, lwsRange),
-            AddContigMatrixContigRowBroadcastingFunctor<argT1, argT2, resT>(
-                mat, padded_vec, res, n_elems, n1));
-    });
-
-    sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on(comp_ev);
-        sycl::context ctx = exec_q.get_context();
-        cgh.host_task([ctx, padded_vec]() { sycl::free(padded_vec, ctx); });
-    });
-    host_tasks.push_back(tmp_cleanup_ev);
-
-    return comp_ev;
+    return elementwise_common::binary_contig_matrix_contig_row_broadcast_impl<
+        argT1, argT2, resT, AddContigMatrixContigRowBroadcastingFunctor,
+        add_matrix_row_broadcast_sg_krn>(exec_q, host_tasks, n0, n1, mat_p,
+                                         mat_offset, vec_p, vec_offset, res_p,
+                                         res_offset, depends);
 }
 
 template <typename fnT, typename T1, typename T2>
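
The row-broadcast helper consolidates the most logic: padding the contiguous row on the device so that sub-group loads remain in bounds, launching the broadcast kernel, and scheduling a host task that frees the temporary allocation. A sketch assuming it simply templatizes the body deleted above (names illustrative, not the actual dpctl declaration):

// Sketch only; mirrors the deleted body with the op-specific functor and
// kernel name lifted into template parameters.
namespace elementwise_common
{

template <typename argT1,
          typename argT2,
          typename resT,
          template <typename T1, typename T2, typename T3>
          class BroadcastFunctorT,
          template <typename T1, typename T2, typename T3> class kernel_name>
sycl::event binary_contig_matrix_contig_row_broadcast_impl(
    sycl::queue exec_q,
    std::vector<sycl::event> &host_tasks,
    size_t n0,
    size_t n1,
    const char *mat_p,
    py::ssize_t mat_offset,
    const char *vec_p,
    py::ssize_t vec_offset,
    char *res_p,
    py::ssize_t res_offset,
    const std::vector<sycl::event> &depends = {})
{
    const argT1 *mat = reinterpret_cast<const argT1 *>(mat_p) + mat_offset;
    const argT2 *vec = reinterpret_cast<const argT2 *>(vec_p) + vec_offset;
    resT *res = reinterpret_cast<resT *>(res_p) + res_offset;

    // pad the row by the maximal sub-group size so that sub-group loads
    // beginning anywhere in [0, n1) stay inside the allocation
    const auto &dev = exec_q.get_device();
    const auto &sg_sizes = dev.get_info<sycl::info::device::sub_group_sizes>();
    size_t max_sgSize =
        *(std::max_element(std::begin(sg_sizes), std::end(sg_sizes)));

    size_t n1_padded = n1 + max_sgSize;
    argT2 *padded_vec = sycl::malloc_device<argT2>(n1_padded, exec_q);
    if (padded_vec == nullptr) {
        throw std::runtime_error("Could not allocate memory on the device");
    }

    sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends); // ensure vec contains actual data
        cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) {
            padded_vec[id[0]] = vec[id[0] % n1];
        });
    });

    size_t lws = 64;
    size_t n_elems = n0 * n1;
    size_t n_groups = (n_elems + lws - 1) / lws;

    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(make_padded_vec_ev);
        cgh.parallel_for<kernel_name<argT1, argT2, resT>>(
            sycl::nd_range<1>(sycl::range<1>(n_groups * lws),
                              sycl::range<1>(lws)),
            BroadcastFunctorT<argT1, argT2, resT>(mat, padded_vec, res,
                                                  n_elems, n1));
    });

    // free the temporary once the kernel is done; the caller keeps the
    // cleanup host task alive via host_tasks
    sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(comp_ev);
        sycl::context ctx = exec_q.get_context();
        cgh.host_task([ctx, padded_vec]() { sycl::free(padded_vec, ctx); });
    });
    host_tasks.push_back(tmp_cleanup_ev);

    return comp_ev;
}

} // namespace elementwise_common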