@@ -281,6 +281,37 @@ __cached_notinplace_@DftiCompute_MODE@_@MKL_IN_TYPE@_@MKL_OUT_TYPE@(
281281}
282282/**end repeat**/
283283
284+ inline npy_intp
285+ compute_distance(npy_intp *x_strides, npy_intp *x_shape, npy_intp x_itemsize, int x_rank, int i1, int i2) {
286+ npy_intp st1, st2;
287+ npy_intp sh1 = x_shape[i1], sh2 = x_shape[i2];
288+ npy_intp min_s;
289+ if (sh1 > 1 && sh2 > 1) {
290+ st1 = x_strides[i1];
291+ st2 = x_strides[i2];
292+ min_s = (st1 > st2) ? st2 : st1;
293+
294+ return min_s;
295+ } else {
296+ int i;
297+ npy_intp max_s;
298+ max_s = x_itemsize;
299+ for(i=0; i < x_rank; i++) {
300+ if (x_shape[i] > 1) {
301+ if (max_s < x_strides[i]) max_s = x_strides[i];
302+ }
303+ }
304+ min_s = max_s;
305+ for(i=i1; i <= i2; i++) {
306+ if (x_shape[i] > 1) {
307+ if (min_s > x_strides[i]) min_s = x_strides[i];
308+ }
309+ }
310+ }
311+
312+ return min_s;
313+ }
314+
284315static NPY_INLINE int
285316compute_strides_and_distances(
286317 PyArrayObject *x,
@@ -315,11 +346,9 @@ compute_strides_and_distances(
315346 npy_intp char_dist = 0;
316347 *num_fft_transfs = _to_mkl_long (x_size / x_shape[axis]);
317348 if (axis == 0) {
318- npy_intp s1 = x_strides[1], s2 = x_strides[x_rank-1];
319- char_dist = (s1 > s2) ? s2 : s1;
349+ char_dist = compute_distance(x_strides, x_shape, x_itemsize, x_rank, 1, x_rank-1);
320350 } else {
321- npy_intp s1 = x_strides[0], s2 = x_strides[x_rank-2];
322- char_dist = (s1 > s2) ? s2 : s1;
351+ char_dist = compute_distance(x_strides, x_shape, x_itemsize, x_rank, 0, x_rank-2);
323352 }
324353
325354 *vec_dist = _to_mkl_long (char_dist / x_itemsize);
@@ -375,17 +404,11 @@ compute_strides_and_distances_inout(
375404 npy_intp char_dist_in = 0, char_dist_out = 0;
376405 *num_fft_transfs = _to_mkl_long (x_size / x_shape[axis]);
377406 if (axis == 0) {
378- npy_intp s1 = x_strides[1], s2 = x_strides[x_rank-1];
379- char_dist_in = (s1 > s2) ? s2 : s1;
380-
381- s1 = y_strides[1]; s2 = y_strides[x_rank-1];
382- char_dist_out = (s1 > s2) ? s2 : s1;
407+ char_dist_in = compute_distance(x_strides, x_shape, x_itemsize, x_rank, 1, x_rank-1);
408+ char_dist_out = compute_distance(y_strides, y_shape, y_itemsize, x_rank, 1, x_rank-1);
383409 } else {
384- npy_intp s1 = x_strides[0], s2 = x_strides[x_rank-2];
385- char_dist_in = (s1 > s2) ? s2 : s1;
386-
387- s1 = y_strides[0]; s2 = y_strides[x_rank-2];
388- char_dist_out = (s1 > s2) ? s2 : s1;
410+ char_dist_in = compute_distance(x_strides, x_shape, x_itemsize, x_rank, 0, x_rank-2);
411+ char_dist_out = compute_distance(y_strides, y_shape, y_itemsize, x_rank, 0, x_rank-2);
389412 }
390413 *vec_dist_in = _to_mkl_long (char_dist_in / x_itemsize);
391414 *vec_dist_out = _to_mkl_long (char_dist_out / y_itemsize);
0 commit comments