@@ -623,7 +623,7 @@ int simplify_iteration_three_strides(const int nd,
623623 auto str3_p = strides3[p];
624624 shape_w.push_back (sh_p);
625625 if (str1_p <= 0 && str2_p <= 0 && str3_p <= 0 &&
626- std::min (std::min ( str1_p, str2_p) , str3_p) < 0 )
626+ std::min ({ str1_p, str2_p, str3_p} ) < 0 )
627627 {
628628 disp1 += str1_p * (sh_p - 1 );
629629 str1_p = -str1_p;
@@ -716,6 +716,198 @@ contract_iter3(vecT shape, vecT strides1, vecT strides2, vecT strides3)
716716 out_strides3, disp3);
717717}
718718
719+ /*
720+ For purposes of iterating over pairs of elements of four arrays
721+ with `shape` and strides `strides1`, `strides2`, `strides3`,
722+ `strides4` given as pointers `simplify_iteration_four_strides(nd,
723+ shape_ptr, strides1_ptr, strides2_ptr, strides3_ptr, strides4_ptr,
724+ disp1, disp2, disp3, disp4)` may modify memory and returns new
725+ length of these arrays.
726+
727+ The new shape and new strides, as well as the offset
728+ `(new_shape, new_strides1, disp1, new_stride2, disp2, new_stride3, disp3,
729+ new_stride4, disp4)` are such that iterating over them will traverse the
730+ same set of tuples of elements, possibly in a different order.
731+ */
732+ template <class ShapeTy , class StridesTy >
733+ int simplify_iteration_four_strides (const int nd,
734+ ShapeTy *shape,
735+ StridesTy *strides1,
736+ StridesTy *strides2,
737+ StridesTy *strides3,
738+ StridesTy *strides4,
739+ StridesTy &disp1,
740+ StridesTy &disp2,
741+ StridesTy &disp3,
742+ StridesTy &disp4)
743+ {
744+ disp1 = std::ptrdiff_t (0 );
745+ disp2 = std::ptrdiff_t (0 );
746+ if (nd < 2 )
747+ return nd;
748+
749+ std::vector<int > pos (nd);
750+ std::iota (pos.begin (), pos.end (), 0 );
751+
752+ std::stable_sort (
753+ pos.begin (), pos.end (),
754+ [&strides1, &strides2, &strides3, &strides4, &shape](int i1, int i2) {
755+ auto abs_str1_i1 =
756+ (strides1[i1] < 0 ) ? -strides1[i1] : strides1[i1];
757+ auto abs_str1_i2 =
758+ (strides1[i2] < 0 ) ? -strides1[i2] : strides1[i2];
759+ auto abs_str2_i1 =
760+ (strides2[i1] < 0 ) ? -strides2[i1] : strides2[i1];
761+ auto abs_str2_i2 =
762+ (strides2[i2] < 0 ) ? -strides2[i2] : strides2[i2];
763+ auto abs_str3_i1 =
764+ (strides3[i1] < 0 ) ? -strides3[i1] : strides3[i1];
765+ auto abs_str3_i2 =
766+ (strides3[i2] < 0 ) ? -strides3[i2] : strides3[i2];
767+ auto abs_str4_i1 =
768+ (strides4[i1] < 0 ) ? -strides4[i1] : strides4[i1];
769+ auto abs_str4_i2 =
770+ (strides4[i2] < 0 ) ? -strides4[i2] : strides4[i2];
771+ return (abs_str1_i1 > abs_str1_i2) ||
772+ ((abs_str1_i1 == abs_str1_i2) &&
773+ ((abs_str2_i1 > abs_str2_i2) ||
774+ ((abs_str2_i1 == abs_str2_i2) &&
775+ ((abs_str3_i1 > abs_str3_i2) ||
776+ ((abs_str3_i1 == abs_str3_i2) &&
777+ ((abs_str4_i1 > abs_str4_i2) ||
778+ ((abs_str4_i1 == abs_str4_i2) &&
779+ (shape[i1] > shape[i2]))))))));
780+ });
781+
782+ std::vector<ShapeTy> shape_w;
783+ std::vector<StridesTy> strides1_w;
784+ std::vector<StridesTy> strides2_w;
785+ std::vector<StridesTy> strides3_w;
786+ std::vector<StridesTy> strides4_w;
787+
788+ bool contractable = true ;
789+ for (int i = 0 ; i < nd; ++i) {
790+ auto p = pos[i];
791+ auto sh_p = shape[p];
792+ auto str1_p = strides1[p];
793+ auto str2_p = strides2[p];
794+ auto str3_p = strides3[p];
795+ auto str4_p = strides4[p];
796+ shape_w.push_back (sh_p);
797+ if (str1_p <= 0 && str2_p <= 0 && str3_p <= 0 && str4_p <= 0 &&
798+ std::min ({str1_p, str2_p, str3_p, str4_p}) < 0 )
799+ {
800+ disp1 += str1_p * (sh_p - 1 );
801+ str1_p = -str1_p;
802+ disp2 += str2_p * (sh_p - 1 );
803+ str2_p = -str2_p;
804+ disp3 += str3_p * (sh_p - 1 );
805+ str3_p = -str3_p;
806+ disp4 += str4_p * (sh_p - 1 );
807+ str4_p = -str4_p;
808+ }
809+ if (str1_p < 0 || str2_p < 0 || str3_p < 0 || str4_p < 0 ) {
810+ contractable = false ;
811+ }
812+ strides1_w.push_back (str1_p);
813+ strides2_w.push_back (str2_p);
814+ strides3_w.push_back (str3_p);
815+ strides4_w.push_back (str4_p);
816+ }
817+ int nd_ = nd;
818+ while (contractable) {
819+ bool changed = false ;
820+ for (int i = 0 ; i + 1 < nd_; ++i) {
821+ StridesTy str1 = strides1_w[i + 1 ];
822+ StridesTy str2 = strides2_w[i + 1 ];
823+ StridesTy str3 = strides3_w[i + 1 ];
824+ StridesTy str4 = strides4_w[i + 1 ];
825+ StridesTy jump1 = strides1_w[i] - (shape_w[i + 1 ] - 1 ) * str1;
826+ StridesTy jump2 = strides2_w[i] - (shape_w[i + 1 ] - 1 ) * str2;
827+ StridesTy jump3 = strides3_w[i] - (shape_w[i + 1 ] - 1 ) * str3;
828+ StridesTy jump4 = strides4_w[i] - (shape_w[i + 1 ] - 1 ) * str4;
829+
830+ if (jump1 == str1 && jump2 == str2 && jump3 == str3 &&
831+ jump4 == str4) {
832+ changed = true ;
833+ shape_w[i] *= shape_w[i + 1 ];
834+ for (int j = i; j < nd_; ++j) {
835+ strides1_w[j] = strides1_w[j + 1 ];
836+ }
837+ for (int j = i; j < nd_; ++j) {
838+ strides2_w[j] = strides2_w[j + 1 ];
839+ }
840+ for (int j = i; j < nd_; ++j) {
841+ strides3_w[j] = strides3_w[j + 1 ];
842+ }
843+ for (int j = i; j < nd_; ++j) {
844+ strides4_w[j] = strides4_w[j + 1 ];
845+ }
846+ for (int j = i + 1 ; j + 1 < nd_; ++j) {
847+ shape_w[j] = shape_w[j + 1 ];
848+ }
849+ --nd_;
850+ break ;
851+ }
852+ }
853+ if (!changed)
854+ break ;
855+ }
856+ for (int i = 0 ; i < nd_; ++i) {
857+ shape[i] = shape_w[i];
858+ }
859+ for (int i = 0 ; i < nd_; ++i) {
860+ strides1[i] = strides1_w[i];
861+ }
862+ for (int i = 0 ; i < nd_; ++i) {
863+ strides2[i] = strides2_w[i];
864+ }
865+ for (int i = 0 ; i < nd_; ++i) {
866+ strides3[i] = strides3_w[i];
867+ }
868+ for (int i = 0 ; i < nd_; ++i) {
869+ strides4[i] = strides4_w[i];
870+ }
871+
872+ return nd_;
873+ }
874+
875+ template <typename T, class Error , typename vecT = std::vector<T>>
876+ std::tuple<vecT, vecT, T, vecT, T, vecT, T, vecT, T>
877+ contract_iter4 (vecT shape,
878+ vecT strides1,
879+ vecT strides2,
880+ vecT strides3,
881+ vecT strides4)
882+ {
883+ const size_t dim = shape.size ();
884+ if (dim != strides1.size () || dim != strides2.size () ||
885+ dim != strides3.size () || dim != strides4.size ())
886+ {
887+ throw Error (" Shape and strides must be of equal size." );
888+ }
889+ vecT out_shape = shape;
890+ vecT out_strides1 = strides1;
891+ vecT out_strides2 = strides2;
892+ vecT out_strides3 = strides3;
893+ vecT out_strides4 = strides4;
894+ T disp1 (0 );
895+ T disp2 (0 );
896+ T disp3 (0 );
897+ T disp4 (0 );
898+
899+ int nd = simplify_iteration_four_strides (
900+ dim, out_shape.data (), out_strides1.data (), out_strides2.data (),
901+ out_strides3.data (), out_strides4.data (), disp1, disp2, disp3, disp4);
902+ out_shape.resize (nd);
903+ out_strides1.resize (nd);
904+ out_strides2.resize (nd);
905+ out_strides3.resize (nd);
906+ out_strides4.resize (nd);
907+ return std::make_tuple (out_shape, out_strides1, disp1, out_strides2, disp2,
908+ out_strides3, disp3, out_strides4, disp4);
909+ }
910+
719911} // namespace strides
720912} // namespace tensor
721913} // namespace dpctl
0 commit comments