@@ -141,19 +141,22 @@ void simplify_iteration_space(int &nd,
141141 assert (simplified_shape.size () == static_cast <size_t >(nd));
142142
143143 simplified_src_strides.reserve (nd);
144- simplified_src_strides.push_back (
145- (src_strides[0 ] >= 0 ) ? src_strides[0 ] : -src_strides[0 ]);
146- if ((src_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
147- src_offset += (shape[0 ] - 1 ) * src_strides[0 ];
148- }
149- assert (simplified_src_strides.size () == static_cast <size_t >(nd));
150-
151144 simplified_dst_strides.reserve (nd);
152- simplified_dst_strides.push_back (
153- (dst_strides[0 ] >= 0 ) ? dst_strides[0 ] : -dst_strides[0 ]);
154- if ((dst_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
155- dst_offset += (shape[0 ] - 1 ) * dst_strides[0 ];
145+
146+ if (src_strides[0 ] < 0 && dst_strides[0 ] < 0 ) {
147+ simplified_src_strides.push_back (-src_strides[0 ]);
148+ simplified_dst_strides.push_back (-dst_strides[0 ]);
149+ if (shape[0 ] > 1 ) {
150+ src_offset += (shape[0 ] - 1 ) * src_strides[0 ];
151+ dst_offset += (shape[0 ] - 1 ) * dst_strides[0 ];
152+ }
153+ }
154+ else {
155+ simplified_src_strides.push_back (src_strides[0 ]);
156+ simplified_dst_strides.push_back (dst_strides[0 ]);
156157 }
158+
159+ assert (simplified_src_strides.size () == static_cast <size_t >(nd));
157160 assert (simplified_dst_strides.size () == static_cast <size_t >(nd));
158161 }
159162}
@@ -226,27 +229,28 @@ void simplify_iteration_space_3(
226229 assert (simplified_shape.size () == static_cast <size_t >(nd));
227230
228231 simplified_src1_strides.reserve (nd);
229- simplified_src1_strides.push_back (
230- (src1_strides[0 ] >= 0 ) ? src1_strides[0 ] : -src1_strides[0 ]);
231- if ((src1_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
232- src1_offset += src1_strides[0 ] * (shape[0 ] - 1 );
233- }
234- assert (simplified_src1_strides.size () == static_cast <size_t >(nd));
235-
236232 simplified_src2_strides.reserve (nd);
237- simplified_src2_strides.push_back (
238- (src2_strides[0 ] >= 0 ) ? src2_strides[0 ] : -src2_strides[0 ]);
239- if ((src2_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
240- src2_offset += src2_strides[0 ] * (shape[0 ] - 1 );
241- }
242- assert (simplified_src2_strides.size () == static_cast <size_t >(nd));
243-
244233 simplified_dst_strides.reserve (nd);
245- simplified_dst_strides.push_back (
246- (dst_strides[0 ] >= 0 ) ? dst_strides[0 ] : -dst_strides[0 ]);
247- if ((dst_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
248- dst_offset += dst_strides[0 ] * (shape[0 ] - 1 );
234+
235+ if ((src1_strides[0 ] < 0 ) && (src2_strides[0 ] < 0 ) &&
236+ (dst_strides[0 ] < 0 )) {
237+ simplified_src1_strides.push_back (-src1_strides[0 ]);
238+ simplified_src2_strides.push_back (-src2_strides[0 ]);
239+ simplified_dst_strides.push_back (-dst_strides[0 ]);
240+ if (shape[0 ] > 1 ) {
241+ src1_offset += src1_strides[0 ] * (shape[0 ] - 1 );
242+ src2_offset += src2_strides[0 ] * (shape[0 ] - 1 );
243+ dst_offset += dst_strides[0 ] * (shape[0 ] - 1 );
244+ }
245+ }
246+ else {
247+ simplified_src1_strides.push_back (src1_strides[0 ]);
248+ simplified_src2_strides.push_back (src2_strides[0 ]);
249+ simplified_dst_strides.push_back (dst_strides[0 ]);
249250 }
251+
252+ assert (simplified_src1_strides.size () == static_cast <size_t >(nd));
253+ assert (simplified_src2_strides.size () == static_cast <size_t >(nd));
250254 assert (simplified_dst_strides.size () == static_cast <size_t >(nd));
251255 }
252256}
@@ -333,35 +337,34 @@ void simplify_iteration_space_4(
333337 assert (simplified_shape.size () == static_cast <size_t >(nd));
334338
335339 simplified_src1_strides.reserve (nd);
336- simplified_src1_strides.push_back (
337- (src1_strides[0 ] >= 0 ) ? src1_strides[0 ] : -src1_strides[0 ]);
338- if ((src1_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
339- src1_offset += src1_strides[0 ] * (shape[0 ] - 1 );
340- }
341- assert (simplified_src1_strides.size () == static_cast <size_t >(nd));
342-
343340 simplified_src2_strides.reserve (nd);
344- simplified_src2_strides.push_back (
345- (src2_strides[0 ] >= 0 ) ? src2_strides[0 ] : -src2_strides[0 ]);
346- if ((src2_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
347- src2_offset += src2_strides[0 ] * (shape[0 ] - 1 );
348- }
349- assert (simplified_src2_strides.size () == static_cast <size_t >(nd));
350-
351341 simplified_src3_strides.reserve (nd);
352- simplified_src3_strides.push_back (
353- (src3_strides[0 ] >= 0 ) ? src3_strides[0 ] : -src3_strides[0 ]);
354- if ((src3_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
355- src3_offset += src3_strides[0 ] * (shape[0 ] - 1 );
356- }
357- assert (simplified_src3_strides.size () == static_cast <size_t >(nd));
358-
359342 simplified_dst_strides.reserve (nd);
360- simplified_dst_strides.push_back (
361- (dst_strides[0 ] >= 0 ) ? dst_strides[0 ] : -dst_strides[0 ]);
362- if ((dst_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
363- dst_offset += dst_strides[0 ] * (shape[0 ] - 1 );
343+
344+ if ((src1_strides[0 ] < 0 ) && (src2_strides[0 ] < 0 ) &&
345+ (src3_strides[0 ] < 0 ) && (dst_strides[0 ] < 0 ))
346+ {
347+ simplified_src1_strides.push_back (-src1_strides[0 ]);
348+ simplified_src2_strides.push_back (-src2_strides[0 ]);
349+ simplified_src3_strides.push_back (-src3_strides[0 ]);
350+ simplified_dst_strides.push_back (-dst_strides[0 ]);
351+ if (shape[0 ] > 1 ) {
352+ src1_offset += src1_strides[0 ] * (shape[0 ] - 1 );
353+ src2_offset += src2_strides[0 ] * (shape[0 ] - 1 );
354+ src3_offset += src3_strides[0 ] * (shape[0 ] - 1 );
355+ dst_offset += dst_strides[0 ] * (shape[0 ] - 1 );
356+ }
357+ }
358+ else {
359+ simplified_src1_strides.push_back (src1_strides[0 ]);
360+ simplified_src2_strides.push_back (src2_strides[0 ]);
361+ simplified_src3_strides.push_back (src3_strides[0 ]);
362+ simplified_dst_strides.push_back (dst_strides[0 ]);
364363 }
364+
365+ assert (simplified_src1_strides.size () == static_cast <size_t >(nd));
366+ assert (simplified_src2_strides.size () == static_cast <size_t >(nd));
367+ assert (simplified_src3_strides.size () == static_cast <size_t >(nd));
365368 assert (simplified_dst_strides.size () == static_cast <size_t >(nd));
366369 }
367370}
0 commit comments