@@ -65,7 +65,15 @@ import ctypes
6565from .enum_types import backend_type
6666
6767from cpython cimport pycapsule
68- from cpython.buffer cimport PyObject_CheckBuffer
68+ from cpython.buffer cimport (
69+ Py_buffer,
70+ PyBUF_ANY_CONTIGUOUS,
71+ PyBUF_SIMPLE,
72+ PyBUF_WRITABLE,
73+ PyBuffer_Release,
74+ PyObject_CheckBuffer,
75+ PyObject_GetBuffer,
76+ )
6977from cpython.ref cimport Py_DECREF, Py_INCREF, PyObject
7078from libc.stdlib cimport free, malloc
7179
@@ -338,14 +346,20 @@ cdef DPCTLSyclEventRef _memcpy_impl(
338346 cdef void * c_dst_ptr = NULL
339347 cdef void * c_src_ptr = NULL
340348 cdef DPCTLSyclEventRef ERef = NULL
341- cdef const unsigned char [::1 ] src_host_buf = None
342- cdef unsigned char [::1 ] dst_host_buf = None
349+ cdef Py_buffer src_buf_view
350+ cdef Py_buffer dst_buf_view
351+ cdef bint src_is_buf = False
352+ cdef bint dst_is_buf = False
353+ cdef int ret_code = 0
343354
344355 if isinstance (src, _Memory):
345356 c_src_ptr = < void * > (< _Memory> src).get_data_ptr()
346357 elif _is_buffer(src):
347- src_host_buf = src
348- c_src_ptr = < void * > & src_host_buf[0 ]
358+ ret_code = PyObject_GetBuffer(src, & src_buf_view, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS)
359+ if ret_code != 0 : # pragma: no cover
360+ raise RuntimeError (" Could not access buffer" )
361+ c_src_ptr = src_buf_view.buf
362+ src_is_buf = True
349363 else :
350364 raise TypeError (
351365 " Parameter `src` should have either type "
@@ -356,8 +370,13 @@ cdef DPCTLSyclEventRef _memcpy_impl(
356370 if isinstance (dst, _Memory):
357371 c_dst_ptr = < void * > (< _Memory> dst).get_data_ptr()
358372 elif _is_buffer(dst):
359- dst_host_buf = dst
360- c_dst_ptr = < void * > & dst_host_buf[0 ]
373+ ret_code = PyObject_GetBuffer(dst, & dst_buf_view, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS | PyBUF_WRITABLE)
374+ if ret_code != 0 : # pragma: no cover
375+ if src_is_buf:
376+ PyBuffer_Release(& src_buf_view)
377+ raise RuntimeError (" Could not access buffer" )
378+ c_dst_ptr = dst_buf_view.buf
379+ dst_is_buf = True
361380 else :
362381 raise TypeError (
363382 " Parameter `dst` should have either type "
@@ -376,6 +395,12 @@ cdef DPCTLSyclEventRef _memcpy_impl(
376395 dep_events,
377396 dep_events_count
378397 )
398+
399+ if src_is_buf:
400+ PyBuffer_Release(& src_buf_view)
401+ if dst_is_buf:
402+ PyBuffer_Release(& dst_buf_view)
403+
379404 return ERef
380405
381406
0 commit comments