1515# limitations under the License.
1616
1717import contextlib
18+ import itertools
1819import operator
1920
2021import numpy as np
@@ -223,10 +224,13 @@ def print_options(*args, **kwargs):
223224
224225
225226def _nd_corners (arr_in , edge_items ):
226- arr_ndim = arr_in .ndim
227+ _shape = arr_in .shape
228+ max_shape = 2 * edge_items + 1
229+ if max (_shape ) <= max_shape :
230+ return dpt .asnumpy (arr_in )
227231 res_shape = tuple (
228- 2 * edge_items if arr_in . shape [i ] > 2 * edge_items else arr_in . shape [i ]
229- for i in range (arr_ndim )
232+ max_shape if _shape [i ] > max_shape else _shape [i ]
233+ for i in range (arr_in . ndim )
230234 )
231235
232236 arr_out = dpt .empty (
@@ -236,29 +240,27 @@ def _nd_corners(arr_in, edge_items):
236240 sycl_queue = arr_in .sycl_queue ,
237241 )
238242
243+ blocks = []
244+ for i in range (len (_shape )):
245+ if _shape [i ] > max_shape :
246+ blocks .append (
247+ (
248+ np .s_ [:edge_items ],
249+ np .s_ [- edge_items :],
250+ )
251+ )
252+ else :
253+ blocks .append ((np .s_ [:],))
254+
239255 hev_list = []
240- for corner in range (arr_ndim ** 2 ):
241- slices = ()
242- tmp = bin (corner ).replace ("0b" , "" ).zfill (arr_ndim )
243-
244- for dim in reversed (range (arr_ndim )):
245- if arr_in .shape [dim ] < 2 * edge_items :
246- slices = (np .s_ [:],) + slices
247- else :
248- ind = (- 1 ) ** int (tmp [dim ]) * edge_items
249- if ind < 0 :
250- slices = (np .s_ [- edge_items ::],) + slices
251- else :
252- slices = (np .s_ [:edge_items :],) + slices
256+ for slc in itertools .product (* blocks ):
253257 hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
254- src = arr_in [slices ],
255- dst = arr_out [slices ],
256- sycl_queue = arr_in .sycl_queue ,
258+ src = arr_in [slc ], dst = arr_out [slc ], sycl_queue = arr_in .sycl_queue
257259 )
258260 hev_list .append (hev )
259261
260262 dpctl .SyclEvent .wait_for (hev_list )
261- return arr_out
263+ return dpt . asnumpy ( arr_out )
262264
263265
264266def usm_ndarray_str (
@@ -365,8 +367,7 @@ def usm_ndarray_str(
365367 edge_items = options ["edgeitems" ]
366368
367369 if x .size > threshold :
368- # need edge_items + 1 elements for np.array2string to abbreviate
369- data = dpt .asnumpy (_nd_corners (x , edge_items + 1 ))
370+ data = _nd_corners (x , edge_items )
370371 options ["threshold" ] = 0
371372 else :
372373 data = dpt .asnumpy (x )
0 commit comments