@@ -149,6 +149,19 @@ cdef bint _is_host_cpu(object dl_device):
149149 return (dl_type == DLDeviceType.kDLCPU) and (dl_id == 0 )
150150
151151
152+ cdef void _validate_and_use_stream(object stream, c_dpctl.SyclQueue self_queue) except * :
153+ if (stream is None or stream == self_queue):
154+ pass
155+ else :
156+ if not isinstance (stream, dpctl.SyclQueue):
157+ raise TypeError (
158+ " stream argument type was expected to be dpctl.SyclQueue,"
159+ f" got {type(stream)} instead"
160+ )
161+ ev = self_queue.submit_barrier()
162+ stream.submit_barrier(dependent_events = [ev])
163+
164+
152165cdef class usm_ndarray:
153166 """ usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
154167 offset=0, order="C", buffer_ctor_kwargs=dict(), \
@@ -1025,16 +1038,7 @@ cdef class usm_ndarray:
10251038 cdef c_dpmem._Memory arr_buf
10261039 d = Device.create_device(target_device)
10271040
1028- if (stream is None or stream == self .sycl_queue):
1029- pass
1030- else :
1031- if not isinstance (stream, dpctl.SyclQueue):
1032- raise TypeError (
1033- " stream argument type was expected to be dpctl.SyclQueue,"
1034- f" got {type(stream)} instead"
1035- )
1036- ev = self .sycl_queue.submit_barrier()
1037- stream.submit_barrier(dependent_events = [ev])
1041+ _validate_and_use_stream(stream, self .sycl_queue)
10381042
10391043 if (d.sycl_context == self .sycl_context):
10401044 arr_buf = < c_dpmem._Memory> self .usm_data
@@ -1207,17 +1211,7 @@ cdef class usm_ndarray:
12071211 # legacy path for DLManagedTensor
12081212 # copy kwarg ignored because copy flag can't be set
12091213 _caps = c_dlpack.to_dlpack_capsule(self )
1210- if (stream is None or stream == self .sycl_queue):
1211- pass
1212- else :
1213- if not isinstance (stream, dpctl.SyclQueue):
1214- raise TypeError (
1215- " stream keyword argument type is expected to "
1216- " be dpctl.SyclQueue, "
1217- f" got {type(stream)} instead"
1218- )
1219- ev = self .sycl_queue.submit_barrier()
1220- stream.submit_barrier(dependent_events = [ev])
1214+ _validate_and_use_stream(stream, self .sycl_queue)
12211215 return _caps
12221216 else :
12231217 if not isinstance (max_version, tuple ) or len (max_version) != 2 :
@@ -1259,12 +1253,7 @@ cdef class usm_ndarray:
12591253 copy = False
12601254 # TODO: strategy for handling stream on different device from dl_device
12611255 if copy:
1262- if (stream is None or type (stream) is not dpctl.SyclQueue or
1263- stream == self .sycl_queue):
1264- pass
1265- else :
1266- ev = self .sycl_queue.submit_barrier()
1267- stream.submit_barrier(dependent_events = [ev])
1256+ _validate_and_use_stream(stream, self .sycl_queue)
12681257 nbytes = self .usm_data.nbytes
12691258 copy_buffer = type (self .usm_data)(
12701259 nbytes, queue = self .sycl_queue
@@ -1281,22 +1270,12 @@ cdef class usm_ndarray:
12811270 _caps = c_dlpack.to_dlpack_versioned_capsule(_copied_arr, copy)
12821271 else :
12831272 _caps = c_dlpack.to_dlpack_versioned_capsule(self , copy)
1284- if (stream is None or type (stream) is not dpctl.SyclQueue or
1285- stream == self .sycl_queue):
1286- pass
1287- else :
1288- ev = self .sycl_queue.submit_barrier()
1289- stream.submit_barrier(dependent_events = [ev])
1273+ _validate_and_use_stream(stream, self .sycl_queue)
12901274 return _caps
12911275 else :
12921276 # legacy path for DLManagedTensor
12931277 _caps = c_dlpack.to_dlpack_capsule(self )
1294- if (stream is None or type (stream) is not dpctl.SyclQueue or
1295- stream == self .sycl_queue):
1296- pass
1297- else :
1298- ev = self .sycl_queue.submit_barrier()
1299- stream.submit_barrier(dependent_events = [ev])
1278+ _validate_and_use_stream(stream, self .sycl_queue)
13001279 return _caps
13011280
13021281 def __dlpack_device__ (self ):
0 commit comments