@@ -71,12 +71,17 @@ void simplify_iteration_space_1(int &nd,
7171 nd = contracted_nd;
7272 }
7373 else if (nd == 1 ) {
74+ offset = 0 ;
7475 // Populate vectors
7576 simplified_shape.reserve (nd);
7677 simplified_shape.push_back (shape[0 ]);
7778
7879 simplified_strides.reserve (nd);
79- simplified_strides.push_back (strides[0 ]);
80+ simplified_strides.push_back ((strides[0 ] >= 0 ) ? strides[0 ]
81+ : -strides[0 ]);
82+ if ((strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
83+ offset += (shape[0 ] - 1 ) * strides[0 ];
84+ }
8085
8186 assert (simplified_shape.size () == static_cast <size_t >(nd));
8287 assert (simplified_strides.size () == static_cast <size_t >(nd));
@@ -128,17 +133,27 @@ void simplify_iteration_space(int &nd,
128133 nd = contracted_nd;
129134 }
130135 else if (nd == 1 ) {
136+ src_offset = 0 ;
137+ dst_offset = 0 ;
131138 // Populate vectors
132139 simplified_shape.reserve (nd);
133140 simplified_shape.push_back (shape[0 ]);
134141 assert (simplified_shape.size () == static_cast <size_t >(nd));
135142
136143 simplified_src_strides.reserve (nd);
137- simplified_src_strides.push_back (src_strides[0 ]);
144+ simplified_src_strides.push_back (
145+ (src_strides[0 ] >= 0 ) ? src_strides[0 ] : -src_strides[0 ]);
146+ if ((src_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
147+ src_offset += (shape[0 ] - 1 ) * src_strides[0 ];
148+ }
138149 assert (simplified_src_strides.size () == static_cast <size_t >(nd));
139150
140151 simplified_dst_strides.reserve (nd);
141- simplified_dst_strides.push_back (dst_strides[0 ]);
152+ simplified_dst_strides.push_back (
153+ (dst_strides[0 ] >= 0 ) ? dst_strides[0 ] : -dst_strides[0 ]);
154+ if ((dst_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
155+ dst_offset += (shape[0 ] - 1 ) * dst_strides[0 ];
156+ }
142157 assert (simplified_dst_strides.size () == static_cast <size_t >(nd));
143158 }
144159}
@@ -202,21 +217,36 @@ void simplify_iteration_space_3(
202217 nd = contracted_nd;
203218 }
204219 else if (nd == 1 ) {
220+ src1_offset = 0 ;
221+ src2_offset = 0 ;
222+ dst_offset = 0 ;
205223 // Populate vectors
206224 simplified_shape.reserve (nd);
207225 simplified_shape.push_back (shape[0 ]);
208226 assert (simplified_shape.size () == static_cast <size_t >(nd));
209227
210228 simplified_src1_strides.reserve (nd);
211- simplified_src1_strides.push_back (src1_strides[0 ]);
229+ simplified_src1_strides.push_back (
230+ (src1_strides[0 ] >= 0 ) ? src1_strides[0 ] : -src1_strides[0 ]);
231+ if ((src1_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
232+ src1_offset += src1_strides[0 ] * (shape[0 ] - 1 );
233+ }
212234 assert (simplified_src1_strides.size () == static_cast <size_t >(nd));
213235
214236 simplified_src2_strides.reserve (nd);
215- simplified_src2_strides.push_back (src2_strides[0 ]);
237+ simplified_src2_strides.push_back (
238+ (src2_strides[0 ] >= 0 ) ? src2_strides[0 ] : -src2_strides[0 ]);
239+ if ((src2_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
240+ src2_offset += src2_strides[0 ] * (shape[0 ] - 1 );
241+ }
216242 assert (simplified_src2_strides.size () == static_cast <size_t >(nd));
217243
218244 simplified_dst_strides.reserve (nd);
219- simplified_dst_strides.push_back (dst_strides[0 ]);
245+ simplified_dst_strides.push_back (
246+ (dst_strides[0 ] >= 0 ) ? dst_strides[0 ] : -dst_strides[0 ]);
247+ if ((dst_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
248+ dst_offset += dst_strides[0 ] * (shape[0 ] - 1 );
249+ }
220250 assert (simplified_dst_strides.size () == static_cast <size_t >(nd));
221251 }
222252}
@@ -293,29 +323,129 @@ void simplify_iteration_space_4(
293323 nd = contracted_nd;
294324 }
295325 else if (nd == 1 ) {
326+ src1_offset = 0 ;
327+ src2_offset = 0 ;
328+ src3_offset = 0 ;
329+ dst_offset = 0 ;
296330 // Populate vectors
297331 simplified_shape.reserve (nd);
298332 simplified_shape.push_back (shape[0 ]);
299333 assert (simplified_shape.size () == static_cast <size_t >(nd));
300334
301335 simplified_src1_strides.reserve (nd);
302- simplified_src1_strides.push_back (src1_strides[0 ]);
336+ simplified_src1_strides.push_back (
337+ (src1_strides[0 ] >= 0 ) ? src1_strides[0 ] : -src1_strides[0 ]);
338+ if ((src1_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
339+ src1_offset += src1_strides[0 ] * (shape[0 ] - 1 );
340+ }
303341 assert (simplified_src1_strides.size () == static_cast <size_t >(nd));
304342
305343 simplified_src2_strides.reserve (nd);
306- simplified_src2_strides.push_back (src2_strides[0 ]);
344+ simplified_src2_strides.push_back (
345+ (src2_strides[0 ] >= 0 ) ? src2_strides[0 ] : -src2_strides[0 ]);
346+ if ((src2_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
347+ src2_offset += src2_strides[0 ] * (shape[0 ] - 1 );
348+ }
307349 assert (simplified_src2_strides.size () == static_cast <size_t >(nd));
308350
309351 simplified_src3_strides.reserve (nd);
310- simplified_src3_strides.push_back (src3_strides[0 ]);
352+ simplified_src3_strides.push_back (
353+ (src3_strides[0 ] >= 0 ) ? src3_strides[0 ] : -src3_strides[0 ]);
354+ if ((src3_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
355+ src3_offset += src3_strides[0 ] * (shape[0 ] - 1 );
356+ }
311357 assert (simplified_src3_strides.size () == static_cast <size_t >(nd));
312358
313359 simplified_dst_strides.reserve (nd);
314- simplified_dst_strides.push_back (dst_strides[0 ]);
360+ simplified_dst_strides.push_back (
361+ (dst_strides[0 ] >= 0 ) ? dst_strides[0 ] : -dst_strides[0 ]);
362+ if ((dst_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
363+ dst_offset += dst_strides[0 ] * (shape[0 ] - 1 );
364+ }
315365 assert (simplified_dst_strides.size () == static_cast <size_t >(nd));
316366 }
317367}
318368
369+ py::ssize_t _ravel_multi_index_c (std::vector<py::ssize_t > const &mi,
370+ std::vector<py::ssize_t > const &shape)
371+ {
372+ size_t nd = shape.size ();
373+ if (nd != mi.size ()) {
374+ throw py::value_error (
375+ " Multi-index and shape vectors must have the same length." );
376+ }
377+
378+ py::ssize_t flat_index = 0 ;
379+ py::ssize_t s = 1 ;
380+ for (size_t i = 0 ; i < nd; ++i) {
381+ flat_index += mi.at (nd - 1 - i) * s;
382+ s *= shape.at (nd - 1 - i);
383+ }
384+
385+ return flat_index;
386+ }
387+
388+ py::ssize_t _ravel_multi_index_f (std::vector<py::ssize_t > const &mi,
389+ std::vector<py::ssize_t > const &shape)
390+ {
391+ size_t nd = shape.size ();
392+ if (nd != mi.size ()) {
393+ throw py::value_error (
394+ " Multi-index and shape vectors must have the same length." );
395+ }
396+
397+ py::ssize_t flat_index = 0 ;
398+ py::ssize_t s = 1 ;
399+ for (size_t i = 0 ; i < nd; ++i) {
400+ flat_index += mi.at (i) * s;
401+ s *= shape.at (i);
402+ }
403+
404+ return flat_index;
405+ }
406+
407+ std::vector<py::ssize_t > _unravel_index_c (py::ssize_t flat_index,
408+ std::vector<py::ssize_t > const &shape)
409+ {
410+ size_t nd = shape.size ();
411+ std::vector<py::ssize_t > mi;
412+ mi.resize (nd);
413+
414+ py::ssize_t i_ = flat_index;
415+ for (size_t dim = 0 ; dim + 1 < nd; ++dim) {
416+ const py::ssize_t si = shape[nd - 1 - dim];
417+ const py::ssize_t q = i_ / si;
418+ const py::ssize_t r = (i_ - q * si);
419+ mi[nd - 1 - dim] = r;
420+ i_ = q;
421+ }
422+ if (nd) {
423+ mi[0 ] = i_;
424+ }
425+ return mi;
426+ }
427+
428+ std::vector<py::ssize_t > _unravel_index_f (py::ssize_t flat_index,
429+ std::vector<py::ssize_t > const &shape)
430+ {
431+ size_t nd = shape.size ();
432+ std::vector<py::ssize_t > mi;
433+ mi.resize (nd);
434+
435+ py::ssize_t i_ = flat_index;
436+ for (size_t dim = 0 ; dim + 1 < nd; ++dim) {
437+ const py::ssize_t si = shape[dim];
438+ const py::ssize_t q = i_ / si;
439+ const py::ssize_t r = (i_ - q * si);
440+ mi[dim] = r;
441+ i_ = q;
442+ }
443+ if (nd) {
444+ mi[nd - 1 ] = i_;
445+ }
446+ return mi;
447+ }
448+
319449} // namespace py_internal
320450} // namespace tensor
321451} // namespace dpctl
0 commit comments