@@ -149,6 +149,19 @@ cdef bint _is_host_cpu(object dl_device):
149149 return (dl_type == DLDeviceType.kDLCPU) and (dl_id == 0 )
150150
151151
152+ cdef void _validate_and_use_stream(object stream, c_dpctl.SyclQueue self_queue) except * :
153+ if (stream is None or stream == self_queue):
154+ pass
155+ else :
156+ if not isinstance (stream, dpctl.SyclQueue):
157+ raise TypeError (
158+ " stream argument type was expected to be dpctl.SyclQueue,"
159+ f" got {type(stream)} instead"
160+ )
161+ ev = self_queue.submit_barrier()
162+ stream.submit_barrier(dependent_events = [ev])
163+
164+
152165cdef class usm_ndarray:
153166 """ usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
154167 offset=0, order="C", buffer_ctor_kwargs=dict(), \
@@ -1025,12 +1038,7 @@ cdef class usm_ndarray:
10251038 cdef c_dpmem._Memory arr_buf
10261039 d = Device.create_device(target_device)
10271040
1028- if (stream is None or not isinstance (stream, dpctl.SyclQueue) or
1029- stream == self .sycl_queue):
1030- pass
1031- else :
1032- ev = self .sycl_queue.submit_barrier()
1033- stream.submit_barrier(dependent_events = [ev])
1041+ _validate_and_use_stream(stream, self .sycl_queue)
10341042
10351043 if (d.sycl_context == self .sycl_context):
10361044 arr_buf = < c_dpmem._Memory> self .usm_data
@@ -1203,12 +1211,7 @@ cdef class usm_ndarray:
12031211 # legacy path for DLManagedTensor
12041212 # copy kwarg ignored because copy flag can't be set
12051213 _caps = c_dlpack.to_dlpack_capsule(self )
1206- if (stream is None or type (stream) is not dpctl.SyclQueue or
1207- stream == self .sycl_queue):
1208- pass
1209- else :
1210- ev = self .sycl_queue.submit_barrier()
1211- stream.submit_barrier(dependent_events = [ev])
1214+ _validate_and_use_stream(stream, self .sycl_queue)
12121215 return _caps
12131216 else :
12141217 if not isinstance (max_version, tuple ) or len (max_version) != 2 :
@@ -1250,12 +1253,7 @@ cdef class usm_ndarray:
12501253 copy = False
12511254 # TODO: strategy for handling stream on different device from dl_device
12521255 if copy:
1253- if (stream is None or type (stream) is not dpctl.SyclQueue or
1254- stream == self .sycl_queue):
1255- pass
1256- else :
1257- ev = self .sycl_queue.submit_barrier()
1258- stream.submit_barrier(dependent_events = [ev])
1256+ _validate_and_use_stream(stream, self .sycl_queue)
12591257 nbytes = self .usm_data.nbytes
12601258 copy_buffer = type (self .usm_data)(
12611259 nbytes, queue = self .sycl_queue
@@ -1272,22 +1270,12 @@ cdef class usm_ndarray:
12721270 _caps = c_dlpack.to_dlpack_versioned_capsule(_copied_arr, copy)
12731271 else :
12741272 _caps = c_dlpack.to_dlpack_versioned_capsule(self , copy)
1275- if (stream is None or type (stream) is not dpctl.SyclQueue or
1276- stream == self .sycl_queue):
1277- pass
1278- else :
1279- ev = self .sycl_queue.submit_barrier()
1280- stream.submit_barrier(dependent_events = [ev])
1273+ _validate_and_use_stream(stream, self .sycl_queue)
12811274 return _caps
12821275 else :
12831276 # legacy path for DLManagedTensor
12841277 _caps = c_dlpack.to_dlpack_capsule(self )
1285- if (stream is None or type (stream) is not dpctl.SyclQueue or
1286- stream == self .sycl_queue):
1287- pass
1288- else :
1289- ev = self .sycl_queue.submit_barrier()
1290- stream.submit_barrier(dependent_events = [ev])
1278+ _validate_and_use_stream(stream, self .sycl_queue)
12911279 return _caps
12921280
12931281 def __dlpack_device__ (self ):
@@ -1555,17 +1543,17 @@ cdef class usm_ndarray:
15551543 def __array__ (self , dtype = None , /, *, copy = None ):
15561544 """ NumPy's array protocol method to disallow implicit conversion.
15571545
1558- Without this definition, `numpy.asarray(usm_ar)` converts
1559- usm_ndarray instance into NumPy array with data type `object`
1560- and every element being 0d usm_ndarray.
1546+ Without this definition, `numpy.asarray(usm_ar)` converts
1547+ usm_ndarray instance into NumPy array with data type `object`
1548+ and every element being 0d usm_ndarray.
15611549
15621550 https://github.com/IntelPython/dpctl/pull/1384#issuecomment-1707212972
1563- """
1551+ """
15641552 raise TypeError (
15651553 " Implicit conversion to a NumPy array is not allowed. "
1566- " Use `dpctl.tensor.asnumpy` to copy data from this "
1567- " `dpctl.tensor.usm_ndarray` instance to NumPy array"
1568- )
1554+ " Use `dpctl.tensor.asnumpy` to copy data from this "
1555+ " `dpctl.tensor.usm_ndarray` instance to NumPy array"
1556+ )
15691557
15701558
15711559cdef usm_ndarray _real_view(usm_ndarray ary):
0 commit comments