1616#include " mlir/Dialect/Arith/IR/Arith.h"
1717#include " mlir/Dialect/Func/IR/FuncOps.h"
1818#include " mlir/Dialect/Linalg/IR/Linalg.h"
19+ #include " mlir/Dialect/Math/IR/Math.h"
1920#include " mlir/Dialect/MemRef/IR/MemRef.h"
2021#include " mlir/Dialect/SCF/IR/SCF.h"
2122#include " mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -38,12 +39,13 @@ static constexpr const char kCompareEqFuncNamePrefix[] = "_sparse_compare_eq_";
3839static constexpr const char kPartitionFuncNamePrefix [] = " _sparse_partition_" ;
3940static constexpr const char kBinarySearchFuncNamePrefix [] =
4041 " _sparse_binary_search_" ;
41- static constexpr const char kSortNonstableFuncNamePrefix [] =
42- " _sparse_sort_nonstable_ " ;
42+ static constexpr const char kHybridQuickSortFuncNamePrefix [] =
43+ " _sparse_hybrid_qsort_ " ;
4344static constexpr const char kSortStableFuncNamePrefix [] =
4445 " _sparse_sort_stable_" ;
4546static constexpr const char kShiftDownFuncNamePrefix [] = " _sparse_shift_down_" ;
4647static constexpr const char kHeapSortFuncNamePrefix [] = " _sparse_heap_sort_" ;
48+ static constexpr const char kQuickSortFuncNamePrefix [] = " _sparse_qsort_" ;
4749
4850using FuncGeneratorType = function_ref<void (
4951 OpBuilder &, ModuleOp, func::FuncOp, uint64_t , uint64_t , bool , uint32_t )>;
@@ -916,41 +918,19 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
916918 builder.create <func::ReturnOp>(loc);
917919}
918920
919- // / Creates a function to perform quick sort on the value in the range of
920- // / index [lo, hi).
921- //
922- // The generate IR corresponds to this C like algorithm:
923- // void quickSort(lo, hi, data) {
924- // if (lo < hi) {
925- // p = partition(low, high, data);
926- // quickSort(lo, p, data);
927- // quickSort(p + 1, hi, data);
928- // }
929- // }
930- static void createSortNonstableFunc (OpBuilder &builder, ModuleOp module ,
931- func::FuncOp func, uint64_t nx, uint64_t ny,
932- bool isCoo, uint32_t nTrailingP) {
933- (void )nTrailingP;
934- OpBuilder::InsertionGuard insertionGuard (builder);
935- Block *entryBlock = func.addEntryBlock ();
936- builder.setInsertionPointToStart (entryBlock);
937-
921+ static void createQuickSort (OpBuilder &builder, ModuleOp module ,
922+ func::FuncOp func, ValueRange args, uint64_t nx,
923+ uint64_t ny, bool isCoo, uint32_t nTrailingP) {
938924 MLIRContext *context = module .getContext ();
939925 Location loc = func.getLoc ();
940- ValueRange args = entryBlock->getArguments ();
941926 Value lo = args[loIdx];
942927 Value hi = args[hiIdx];
943- Value cond =
944- builder.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lo, hi);
945- scf::IfOp ifOp = builder.create <scf::IfOp>(loc, cond, /* else=*/ false );
946-
947- // The if-stmt true branch.
948- builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
949928 FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc (
950929 builder, func, {IndexType::get (context)}, kPartitionFuncNamePrefix , nx,
951- ny, isCoo, args, createPartitionFunc);
952- auto p = builder.create <func::CallOp>(
953- loc, partitionFunc, TypeRange{IndexType::get (context)}, ValueRange (args));
930+ ny, isCoo, args.drop_back (nTrailingP), createPartitionFunc);
931+ auto p = builder.create <func::CallOp>(loc, partitionFunc,
932+ TypeRange{IndexType::get (context)},
933+ args.drop_back (nTrailingP));
954934
955935 SmallVector<Value> lowOperands{lo, p.getResult (0 )};
956936 lowOperands.append (args.begin () + xStartIdx, args.end ());
@@ -962,10 +942,6 @@ static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
962942 hi};
963943 highOperands.append (args.begin () + xStartIdx, args.end ());
964944 builder.create <func::CallOp>(loc, func, highOperands);
965-
966- // After the if-stmt.
967- builder.setInsertionPointAfter (ifOp);
968- builder.create <func::ReturnOp>(loc);
969945}
970946
971947// / Creates a function to perform insertion sort on the values in the range of
@@ -1054,6 +1030,116 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
10541030 builder.create <func::ReturnOp>(loc);
10551031}
10561032
1033+ // / Creates a function to perform quick sort or a hybrid quick sort on the
1034+ // / values in the range of index [lo, hi).
1035+ //
1036+ //
1037+ // When nTrailingP == 0, the generated IR corresponds to this C like algorithm:
1038+ // void quickSort(lo, hi, data) {
1039+ // if (lo + 1 < hi) {
1040+ // p = partition(low, high, data);
1041+ // quickSort(lo, p, data);
1042+ // quickSort(p + 1, hi, data);
1043+ // }
1044+ // }
1045+ //
1046+ // When nTrailingP == 1, the generated IR corresponds to this C like algorithm:
1047+ // void hybridQuickSort(lo, hi, data, depthLimit) {
1048+ // if (lo + 1 < hi) {
1049+ // len = hi - lo;
1050+ // if (len <= limit) {
1051+ // insertionSort(lo, hi, data);
1052+ // } else {
1053+ // depthLimit --;
1054+ // if (depthLimit <= 0) {
1055+ // heapSort(lo, hi, data);
1056+ // } else {
1057+ // p = partition(low, high, data);
1058+ // quickSort(lo, p, data);
1059+ // quickSort(p + 1, hi, data);
1060+ // }
1061+ // depthLimit ++;
1062+ // }
1063+ // }
1064+ // }
1065+ //
1066+ static void createQuickSortFunc (OpBuilder &builder, ModuleOp module ,
1067+ func::FuncOp func, uint64_t nx, uint64_t ny,
1068+ bool isCoo, uint32_t nTrailingP) {
1069+ assert (nTrailingP == 1 || nTrailingP == 0 );
1070+ bool isHybrid = (nTrailingP == 1 );
1071+ OpBuilder::InsertionGuard insertionGuard (builder);
1072+ Block *entryBlock = func.addEntryBlock ();
1073+ builder.setInsertionPointToStart (entryBlock);
1074+
1075+ Location loc = func.getLoc ();
1076+ ValueRange args = entryBlock->getArguments ();
1077+ Value lo = args[loIdx];
1078+ Value hi = args[hiIdx];
1079+ Value loCmp =
1080+ builder.create <arith::AddIOp>(loc, lo, constantIndex (builder, loc, 1 ));
1081+ Value cond =
1082+ builder.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loCmp, hi);
1083+ scf::IfOp ifOp = builder.create <scf::IfOp>(loc, cond, /* else=*/ false );
1084+
1085+ // The if-stmt true branch.
1086+ builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
1087+ Value pDepthLimit;
1088+ Value savedDepthLimit;
1089+ scf::IfOp depthIf;
1090+
1091+ if (isHybrid) {
1092+ Value len = builder.create <arith::SubIOp>(loc, hi, lo);
1093+ Value lenLimit = constantIndex (builder, loc, 30 );
1094+ Value lenCond = builder.create <arith::CmpIOp>(
1095+ loc, arith::CmpIPredicate::ule, len, lenLimit);
1096+ scf::IfOp lenIf = builder.create <scf::IfOp>(loc, lenCond, /* else=*/ true );
1097+
1098+ // When len <= limit.
1099+ builder.setInsertionPointToStart (&lenIf.getThenRegion ().front ());
1100+ FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc (
1101+ builder, func, TypeRange (), kSortStableFuncNamePrefix , nx, ny, isCoo,
1102+ args.drop_back (nTrailingP), createSortStableFunc);
1103+ builder.create <func::CallOp>(loc, insertionSortFunc, TypeRange (),
1104+ ValueRange (args.drop_back (nTrailingP)));
1105+
1106+ // When len > limit.
1107+ builder.setInsertionPointToStart (&lenIf.getElseRegion ().front ());
1108+ pDepthLimit = args.back ();
1109+ savedDepthLimit = builder.create <memref::LoadOp>(loc, pDepthLimit);
1110+ Value depthLimit = builder.create <arith::SubIOp>(
1111+ loc, savedDepthLimit, constantI64 (builder, loc, 1 ));
1112+ builder.create <memref::StoreOp>(loc, depthLimit, pDepthLimit);
1113+ Value depthCond =
1114+ builder.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
1115+ depthLimit, constantI64 (builder, loc, 0 ));
1116+ depthIf = builder.create <scf::IfOp>(loc, depthCond, /* else=*/ true );
1117+
1118+ // When depth exceeds limit.
1119+ builder.setInsertionPointToStart (&depthIf.getThenRegion ().front ());
1120+ FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc (
1121+ builder, func, TypeRange (), kHeapSortFuncNamePrefix , nx, ny, isCoo,
1122+ args.drop_back (nTrailingP), createHeapSortFunc);
1123+ builder.create <func::CallOp>(loc, heapSortFunc, TypeRange (),
1124+ ValueRange (args.drop_back (nTrailingP)));
1125+
1126+ // When depth doesn't exceed limit.
1127+ builder.setInsertionPointToStart (&depthIf.getElseRegion ().front ());
1128+ }
1129+
1130+ createQuickSort (builder, module , func, args, nx, ny, isCoo, nTrailingP);
1131+
1132+ if (isHybrid) {
1133+ // Restore depthLimit.
1134+ builder.setInsertionPointAfter (depthIf);
1135+ builder.create <memref::StoreOp>(loc, savedDepthLimit, pDepthLimit);
1136+ }
1137+
1138+ // After the if-stmt.
1139+ builder.setInsertionPointAfter (ifOp);
1140+ builder.create <func::ReturnOp>(loc);
1141+ }
1142+
10571143// / Implements the rewriting for operator sort and sort_coo.
10581144template <typename OpTy>
10591145LogicalResult matchAndRewriteSortOp (OpTy op, ValueRange xys, uint64_t nx,
@@ -1078,10 +1164,30 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
10781164 FuncGeneratorType funcGenerator;
10791165 uint32_t nTrailingP = 0 ;
10801166 switch (op.getAlgorithm ()) {
1081- case SparseTensorSortKind::HybridQuickSort:
1167+ case SparseTensorSortKind::HybridQuickSort: {
1168+ funcName = kHybridQuickSortFuncNamePrefix ;
1169+ funcGenerator = createQuickSortFunc;
1170+ nTrailingP = 1 ;
1171+ Value pDepthLimit = rewriter.create <memref::AllocaOp>(
1172+ loc, MemRefType::get ({}, rewriter.getI64Type ()));
1173+ operands.push_back (pDepthLimit);
1174+ // As a heuristics, set depthLimit = 2 * log2(n).
1175+ Value lo = operands[loIdx];
1176+ Value hi = operands[hiIdx];
1177+ Value len = rewriter.create <arith::IndexCastOp>(
1178+ loc, rewriter.getI64Type (),
1179+ rewriter.create <arith::SubIOp>(loc, hi, lo));
1180+ Value depthLimit = rewriter.create <arith::SubIOp>(
1181+ loc, constantI64 (rewriter, loc, 64 ),
1182+ rewriter.create <math::CountLeadingZerosOp>(loc, len));
1183+ depthLimit = rewriter.create <arith::ShLIOp>(loc, depthLimit,
1184+ constantI64 (rewriter, loc, 1 ));
1185+ rewriter.create <memref::StoreOp>(loc, depthLimit, pDepthLimit);
1186+ break ;
1187+ }
10821188 case SparseTensorSortKind::QuickSort:
1083- funcName = kSortNonstableFuncNamePrefix ;
1084- funcGenerator = createSortNonstableFunc ;
1189+ funcName = kQuickSortFuncNamePrefix ;
1190+ funcGenerator = createQuickSortFunc ;
10851191 break ;
10861192 case SparseTensorSortKind::InsertionSortStable:
10871193 funcName = kSortStableFuncNamePrefix ;
0 commit comments