|
34 | 34 | #include <vector> |
35 | 35 |
|
36 | 36 | #include "simplify_iteration_space.hpp" |
| 37 | +#include "utils/memory_overlap.hpp" |
37 | 38 | #include "utils/offset_utils.hpp" |
38 | 39 | #include "utils/type_dispatch.hpp" |
39 | 40 |
|
@@ -122,23 +123,14 @@ py_unary_ufunc(dpctl::tensor::usm_ndarray src, |
122 | 123 | } |
123 | 124 |
|
124 | 125 | // check memory overlap |
125 | | - const char *src_data = src.get_data(); |
126 | | - char *dst_data = dst.get_data(); |
127 | | - |
128 | | - // check that arrays do not overlap, and concurrent copying is safe. |
129 | | - auto src_offsets = src.get_minmax_offsets(); |
130 | | - int src_elem_size = src.get_elemsize(); |
131 | | - int dst_elem_size = dst.get_elemsize(); |
132 | | - |
133 | | - bool memory_overlap = |
134 | | - ((dst_data - src_data > src_offsets.second * src_elem_size - |
135 | | - dst_offsets.first * dst_elem_size) && |
136 | | - (src_data - dst_data > dst_offsets.second * dst_elem_size - |
137 | | - src_offsets.first * src_elem_size)); |
138 | | - if (memory_overlap) { |
| 126 | + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); |
| 127 | + if (overlap(src, dst)) { |
139 | 128 | throw py::value_error("Arrays index overlapping segments of memory"); |
140 | 129 | } |
141 | 130 |
|
| 131 | + const char *src_data = src.get_data(); |
| 132 | + char *dst_data = dst.get_data(); |
| 133 | + |
142 | 134 | // handle contiguous inputs |
143 | 135 | bool is_src_c_contig = src.is_c_contiguous(); |
144 | 136 | bool is_src_f_contig = src.is_f_contiguous(); |
@@ -378,32 +370,16 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc( |
378 | 370 | } |
379 | 371 | } |
380 | 372 |
|
| 373 | + // check memory overlap |
| 374 | + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); |
| 375 | + if (overlap(src1, dst) || overlap(src2, dst)) { |
| 376 | + throw py::value_error("Arrays index overlapping segments of memory"); |
| 377 | + } |
381 | 378 | // check memory overlap |
382 | 379 | const char *src1_data = src1.get_data(); |
383 | 380 | const char *src2_data = src2.get_data(); |
384 | 381 | char *dst_data = dst.get_data(); |
385 | 382 |
|
386 | | - // check that arrays do not overlap, and concurrent copying is safe. |
387 | | - auto src1_offsets = src1.get_minmax_offsets(); |
388 | | - int src1_elem_size = src1.get_elemsize(); |
389 | | - auto src2_offsets = src2.get_minmax_offsets(); |
390 | | - int src2_elem_size = src2.get_elemsize(); |
391 | | - int dst_elem_size = dst.get_elemsize(); |
392 | | - |
393 | | - bool memory_overlap_src1_dst = |
394 | | - ((dst_data - src1_data > src1_offsets.second * src1_elem_size - |
395 | | - dst_offsets.first * dst_elem_size) && |
396 | | - (src1_data - dst_data > dst_offsets.second * dst_elem_size - |
397 | | - src1_offsets.first * src1_elem_size)); |
398 | | - bool memory_overlap_src2_dst = |
399 | | - ((dst_data - src2_data > src2_offsets.second * src2_elem_size - |
400 | | - dst_offsets.first * dst_elem_size) && |
401 | | - (src2_data - dst_data > dst_offsets.second * dst_elem_size - |
402 | | - src2_offsets.first * src2_elem_size)); |
403 | | - if (memory_overlap_src1_dst || memory_overlap_src2_dst) { |
404 | | - throw py::value_error("Arrays index overlapping segments of memory"); |
405 | | - } |
406 | | - |
407 | 383 | // handle contiguous inputs |
408 | 384 | bool is_src1_c_contig = src1.is_c_contiguous(); |
409 | 385 | bool is_src1_f_contig = src1.is_f_contiguous(); |
|
0 commit comments