22//
33// Data Parallel Control (dpctl)
44//
5- // Copyright 2020-2022 Intel Corporation
5+ // Copyright 2020-2023 Intel Corporation
66//
77// Licensed under the Apache License, Version 2.0 (the "License");
88// you may not use this file except in compliance with the License.
2929// /
3030// ===----------------------------------------------------------------------===//
3131
32+ #pragma once
3233#include " Python.h"
3334#include " syclinterface/dpctl_data_types.h"
35+ #include " syclinterface/dpctl_sycl_type_casters.hpp"
3436#include < CL/sycl.hpp>
3537
36- int async_dec_ref (DPCTLSyclQueueRef QRef,
37- PyObject **obj_array,
38- size_t obj_array_size,
39- DPCTLSyclEventRef *ERefs,
40- size_t nERefs)
38+ DPCTLSyclEventRef async_dec_ref (DPCTLSyclQueueRef QRef,
39+ PyObject **obj_array,
40+ size_t obj_array_size,
41+ DPCTLSyclEventRef *depERefs,
42+ size_t nDepERefs,
43+ int *status)
4144{
45+ using dpctl::syclinterface::unwrap;
46+ using dpctl::syclinterface::wrap;
4247
43- sycl::queue *q = reinterpret_cast <sycl::queue * >(QRef);
48+ sycl::queue *q = unwrap <sycl::queue>(QRef);
4449
45- std::vector<PyObject *> obj_vec;
46- obj_vec.reserve (obj_array_size);
47- for (size_t obj_id = 0 ; obj_id < obj_array_size; ++obj_id) {
48- obj_vec.push_back (obj_array[obj_id]);
49- }
50+ std::vector<PyObject *> obj_vec (obj_array, obj_array + obj_array_size);
5051
5152 try {
52- q->submit ([&](sycl::handler &cgh) {
53- for (size_t ev_id = 0 ; ev_id < nERefs; ++ev_id) {
54- cgh.depends_on (
55- *(reinterpret_cast <sycl::event *>(ERefs[ev_id])));
53+ sycl::event ht_ev = q->submit ([&](sycl::handler &cgh) {
54+ for (size_t ev_id = 0 ; ev_id < nDepERefs; ++ev_id) {
55+ cgh.depends_on (*(unwrap<sycl::event>(depERefs[ev_id])));
5656 }
5757 cgh.host_task ([obj_array_size, obj_vec]() {
5858 // if the main thread has not finilized the interpreter yet
@@ -66,9 +66,21 @@ int async_dec_ref(DPCTLSyclQueueRef QRef,
6666 }
6767 });
6868 });
69+
70+ constexpr int result_ok = 0 ;
71+
72+ *status = result_ok;
73+ auto e_ptr = new sycl::event (ht_ev);
74+ return wrap<sycl::event>(e_ptr);
6975 } catch (const std::exception &e) {
70- return 1 ;
76+ constexpr int result_std_exception = 1 ;
77+
78+ *status = result_std_exception;
79+ return nullptr ;
7180 }
7281
73- return 0 ;
82+ constexpr int result_other_abnormal = 2 ;
83+
84+ *status = result_other_abnormal;
85+ return nullptr ;
7486}
0 commit comments