@@ -623,7 +623,7 @@ int simplify_iteration_three_strides(const int nd,
623623 auto str3_p = strides3[p];
624624 shape_w.push_back (sh_p);
625625 if (str1_p <= 0 && str2_p <= 0 && str3_p <= 0 &&
626- std::min (std::min ( str1_p, str2_p) , str3_p) < 0 )
626+ std::min ({ str1_p, str2_p, str3_p} ) < 0 )
627627 {
628628 disp1 += str1_p * (sh_p - 1 );
629629 str1_p = -str1_p;
@@ -716,6 +716,162 @@ contract_iter3(vecT shape, vecT strides1, vecT strides2, vecT strides3)
716716 out_strides3, disp3);
717717}
718718
/*
    For purposes of iterating over quadruples of elements of four arrays
    that share `shape` and have strides `strides1`, `strides2`, `strides3`,
    `strides4` (all given as pointers to arrays of length `nd`),
    `simplify_iteration_four_strides(nd, shape_ptr, strides1_ptr,
    strides2_ptr, strides3_ptr, strides4_ptr, disp1, disp2, disp3, disp4)`
    may modify the memory behind these pointers and returns the new length
    of these arrays.

    The new shape, new strides and offsets
    `(new_shape, new_strides1, disp1, new_strides2, disp2, new_strides3,
    disp3, new_strides4, disp4)` are such that iterating over them will
    traverse the same set of tuples of elements, possibly in a different
    order.

    All four displacements are always reset to zero on entry, including on
    the early return taken when `nd < 2` (in which case shape and strides
    are left untouched). Only the first `nd_` entries (the returned value)
    of `shape` and of each strides array are rewritten.
*/
template <class ShapeTy, class StridesTy>
int simplify_iteration_four_strides(const int nd,
                                    ShapeTy *shape,
                                    StridesTy *strides1,
                                    StridesTy *strides2,
                                    StridesTy *strides3,
                                    StridesTy *strides4,
                                    StridesTy &disp1,
                                    StridesTy &disp2,
                                    StridesTy &disp3,
                                    StridesTy &disp4)
{
    // Every displacement is accumulated with += below, so each one must
    // start from zero regardless of what the caller passed in.
    disp1 = std::ptrdiff_t(0);
    disp2 = std::ptrdiff_t(0);
    disp3 = std::ptrdiff_t(0);
    disp4 = std::ptrdiff_t(0);
    if (nd < 2)
        return nd;

    // Permutation of axes ordered by decreasing |stride|, compared
    // lexicographically across the four arrays; remaining ties are broken
    // by decreasing extent.
    std::vector<int> pos(nd);
    std::iota(pos.begin(), pos.end(), 0);

    const auto abs_of = [](StridesTy v) { return (v < 0) ? -v : v; };

    std::stable_sort(pos.begin(), pos.end(), [&](int i1, int i2) {
        const auto a1 = abs_of(strides1[i1]);
        const auto b1 = abs_of(strides1[i2]);
        if (a1 != b1)
            return a1 > b1;

        const auto a2 = abs_of(strides2[i1]);
        const auto b2 = abs_of(strides2[i2]);
        if (a2 != b2)
            return a2 > b2;

        const auto a3 = abs_of(strides3[i1]);
        const auto b3 = abs_of(strides3[i2]);
        if (a3 != b3)
            return a3 > b3;

        const auto a4 = abs_of(strides4[i1]);
        const auto b4 = abs_of(strides4[i2]);
        if (a4 != b4)
            return a4 > b4;

        return shape[i1] > shape[i2];
    });

    std::vector<ShapeTy> shape_w;
    std::vector<StridesTy> strides1_w;
    std::vector<StridesTy> strides2_w;
    std::vector<StridesTy> strides3_w;
    std::vector<StridesTy> strides4_w;
    shape_w.reserve(nd);
    strides1_w.reserve(nd);
    strides2_w.reserve(nd);
    strides3_w.reserve(nd);
    strides4_w.reserve(nd);

    bool contractable = true;
    for (int i = 0; i < nd; ++i) {
        const auto p = pos[i];
        const auto sh_p = shape[p];
        auto str1_p = strides1[p];
        auto str2_p = strides2[p];
        auto str3_p = strides3[p];
        auto str4_p = strides4[p];
        shape_w.push_back(sh_p);

        // An axis on which no stride is positive and at least one is
        // negative is reversed: move each start offset to the far end of
        // the axis and negate the strides.
        if (str1_p <= 0 && str2_p <= 0 && str3_p <= 0 && str4_p <= 0 &&
            std::min({str1_p, str2_p, str3_p, str4_p}) < 0)
        {
            disp1 += str1_p * (sh_p - 1);
            str1_p = -str1_p;
            disp2 += str2_p * (sh_p - 1);
            str2_p = -str2_p;
            disp3 += str3_p * (sh_p - 1);
            str3_p = -str3_p;
            disp4 += str4_p * (sh_p - 1);
            str4_p = -str4_p;
        }
        // Axes are only merged when every stride is non-negative.
        if (str1_p < 0 || str2_p < 0 || str3_p < 0 || str4_p < 0) {
            contractable = false;
        }
        strides1_w.push_back(str1_p);
        strides2_w.push_back(str2_p);
        strides3_w.push_back(str3_p);
        strides4_w.push_back(str4_p);
    }

    int nd_ = nd;
    while (contractable) {
        bool changed = false;
        for (int i = 0; i + 1 < nd_; ++i) {
            const StridesTy str1 = strides1_w[i + 1];
            const StridesTy str2 = strides2_w[i + 1];
            const StridesTy str3 = strides3_w[i + 1];
            const StridesTy str4 = strides4_w[i + 1];
            const StridesTy jump1 =
                strides1_w[i] - (shape_w[i + 1] - 1) * str1;
            const StridesTy jump2 =
                strides2_w[i] - (shape_w[i + 1] - 1) * str2;
            const StridesTy jump3 =
                strides3_w[i] - (shape_w[i + 1] - 1) * str3;
            const StridesTy jump4 =
                strides4_w[i] - (shape_w[i + 1] - 1) * str4;

            // Axes i and i + 1 can be fused when, in all four arrays,
            // stepping off the end of axis i + 1 onto the next index of
            // axis i advances by exactly the stride of axis i + 1.
            if (jump1 == str1 && jump2 == str2 && jump3 == str3 &&
                jump4 == str4)
            {
                changed = true;
                shape_w[i] *= shape_w[i + 1];
                // The fused axis keeps the stride of axis i + 1, so drop
                // the stride of axis i and the extent of axis i + 1.
                // erase() shifts the tail down without ever reading past
                // the end of the vector.
                strides1_w.erase(strides1_w.begin() + i);
                strides2_w.erase(strides2_w.begin() + i);
                strides3_w.erase(strides3_w.begin() + i);
                strides4_w.erase(strides4_w.begin() + i);
                shape_w.erase(shape_w.begin() + (i + 1));
                --nd_;
                break;
            }
        }
        if (!changed)
            break;
    }

    // Publish the simplified description back into the caller's arrays.
    for (int i = 0; i < nd_; ++i) {
        shape[i] = shape_w[i];
        strides1[i] = strides1_w[i];
        strides2[i] = strides2_w[i];
        strides3[i] = strides3_w[i];
        strides4[i] = strides4_w[i];
    }

    return nd_;
}
874+
719875} // namespace strides
720876} // namespace tensor
721877} // namespace dpctl
0 commit comments