Skip to content

Commit 44847b7

Browse files
authored
Merge pull request #524 from remysucre/master
Intersect skewed sparse iterators by galloping
2 parents 4c24193 + 7a05d63 commit 44847b7

15 files changed

+436
-83
lines changed

include/taco/index_notation/index_notation.h

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,23 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
634634
/// reorder takes a new ordering for a set of index variables that are directly nested in the iteration order
635635
IndexStmt reorder(std::vector<IndexVar> reorderedvars) const;
636636

637+
/// The mergeby transformation specifies how to merge iterators on
638+
/// the given index variable. By default, if an iterator is used for windowing
639+
/// it will be merged with the "gallop" strategy.
640+
/// All other iterators are merged with the "two finger" strategy.
641+
/// The two finger strategy merges by advancing each iterator one at a time,
642+
/// while the gallop strategy implements the exponential search algorithm.
643+
///
644+
/// Preconditions:
645+
/// This command applies to variables involving sparse iterators only;
646+
/// it is a no-op if the variable invovles any dense iterators.
647+
/// Any variable can be merged with the two finger strategy, whereas gallop
648+
/// only applies to a variable if its merge lattice has a single point
649+
/// (i.e. an intersection). For example, if a variable involves multiplications
650+
/// only, it can be merged with gallop.
651+
/// Furthermore, all iterators must be ordered for gallop to apply.
652+
IndexStmt mergeby(IndexVar i, MergeStrategy strategy) const;
653+
637654
/// The parallelize
638655
/// transformation tags an index variable for parallel execution. The
639656
/// transformation takes as an argument the type of parallel hardware
@@ -829,13 +846,14 @@ class Forall : public IndexStmt {
829846
Forall() = default;
830847
Forall(const ForallNode*);
831848
Forall(IndexVar indexVar, IndexStmt stmt);
832-
Forall(IndexVar indexVar, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0);
849+
Forall(IndexVar indexVar, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0);
833850

834851
IndexVar getIndexVar() const;
835852
IndexStmt getStmt() const;
836853

837854
ParallelUnit getParallelUnit() const;
838855
OutputRaceStrategy getOutputRaceStrategy() const;
856+
MergeStrategy getMergeStrategy() const;
839857

840858
size_t getUnrollFactor() const;
841859

@@ -844,7 +862,7 @@ class Forall : public IndexStmt {
844862

845863
/// Create a forall index statement.
846864
Forall forall(IndexVar i, IndexStmt stmt);
847-
Forall forall(IndexVar i, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0);
865+
Forall forall(IndexVar i, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0);
848866

849867

850868
/// A where statment has a producer statement that binds a tensor variable in

include/taco/index_notation/index_notation_nodes.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,15 +398,16 @@ struct YieldNode : public IndexStmtNode {
398398
};
399399

400400
struct ForallNode : public IndexStmtNode {
401-
ForallNode(IndexVar indexVar, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0)
402-
: indexVar(indexVar), stmt(stmt), parallel_unit(parallel_unit), output_race_strategy(output_race_strategy), unrollFactor(unrollFactor) {}
401+
ForallNode(IndexVar indexVar, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0)
402+
: indexVar(indexVar), stmt(stmt), merge_strategy(merge_strategy), parallel_unit(parallel_unit), output_race_strategy(output_race_strategy), unrollFactor(unrollFactor) {}
403403

404404
void accept(IndexStmtVisitorStrict* v) const {
405405
v->visit(this);
406406
}
407407

408408
IndexVar indexVar;
409409
IndexStmt stmt;
410+
MergeStrategy merge_strategy;
410411
ParallelUnit parallel_unit;
411412
OutputRaceStrategy output_race_strategy;
412413
size_t unrollFactor = 0;

include/taco/index_notation/transformations.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class AddSuchThatPredicates;
2222
class Parallelize;
2323
class TopoReorder;
2424
class SetAssembleStrategy;
25+
class SetMergeStrategy;
2526

2627
/// A transformation is an optimization that transforms a statement in the
2728
/// concrete index notation into a new statement that computes the same result
@@ -36,6 +37,7 @@ class Transformation {
3637
Transformation(TopoReorder);
3738
Transformation(AddSuchThatPredicates);
3839
Transformation(SetAssembleStrategy);
40+
Transformation(SetMergeStrategy);
3941

4042
IndexStmt apply(IndexStmt stmt, std::string *reason = nullptr) const;
4143

@@ -206,6 +208,25 @@ class SetAssembleStrategy : public TransformationInterface {
206208
/// Print a SetAssembleStrategy command.
207209
std::ostream &operator<<(std::ostream &, const SetAssembleStrategy&);
208210

211+
class SetMergeStrategy : public TransformationInterface {
212+
public:
213+
SetMergeStrategy(IndexVar i, MergeStrategy strategy);
214+
215+
IndexVar geti() const;
216+
MergeStrategy getMergeStrategy() const;
217+
218+
IndexStmt apply(IndexStmt stmt, std::string *reason = nullptr) const;
219+
220+
void print(std::ostream &os) const;
221+
222+
private:
223+
struct Content;
224+
std::shared_ptr<Content> content;
225+
};
226+
227+
/// Print a SetMergeStrategy command.
228+
std::ostream &operator<<(std::ostream &, const SetMergeStrategy&);
229+
209230
// Autoscheduling functions
210231

211232
/**

include/taco/ir_tags.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ enum class AssembleStrategy {
3333
};
3434
extern const char *AssembleStrategy_NAMES[];
3535

36+
/// MergeStrategy::TwoFinger merges iterators by incrementing one at a time
37+
/// MergeStrategy::Galloping merges iterators by exponential search (galloping)
38+
enum class MergeStrategy {
39+
TwoFinger, Gallop
40+
};
41+
extern const char *MergeStrategy_NAMES[];
42+
3643
}
3744

3845
#endif //TACO_IR_TAGS_H

include/taco/lower/lowerer_impl_imperative.h

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -146,15 +146,18 @@ class LowererImplImperative : public LowererImpl {
146146
* \param statement
147147
* A concrete index notation statement to compute at the points in the
148148
* sparse iteration space described by the merge lattice.
149+
* \param mergeStrategy
150+
* A strategy for merging iterators. One of TwoFinger or Gallop.
149151
*
150152
* \return
151153
* IR code to compute the forall loop.
152154
*/
153155
virtual ir::Stmt lowerMergeLattice(MergeLattice lattice, IndexVar coordinateVar,
154156
IndexStmt statement,
155-
const std::set<Access>& reducedAccesses);
157+
const std::set<Access>& reducedAccesses,
158+
MergeStrategy mergeStrategy);
156159

157-
virtual ir::Stmt resolveCoordinate(std::vector<Iterator> mergers, ir::Expr coordinate, bool emitVarDecl);
160+
virtual ir::Stmt resolveCoordinate(std::vector<Iterator> mergers, ir::Expr coordinate, bool emitVarDecl, bool mergeWithMax);
158161

159162
/**
160163
* Lower the merge point at the top of the given lattice to code that iterates
@@ -169,23 +172,29 @@ class LowererImplImperative : public LowererImpl {
169172
* coordinate the merge point is at.
170173
* A concrete index notation statement to compute at the points in the
171174
* sparse iteration space region described by the merge point.
175+
* \param mergeWithMax
176+
* A boolean indicating whether coordinates should be combined with MAX instead of MIN.
177+
* MAX is needed when the iterators are merged with the Gallop strategy.
172178
*/
173179
virtual ir::Stmt lowerMergePoint(MergeLattice pointLattice,
174180
ir::Expr coordinate, IndexVar coordinateVar, IndexStmt statement,
175-
const std::set<Access>& reducedAccesses, bool resolvedCoordDeclared);
181+
const std::set<Access>& reducedAccesses, bool resolvedCoordDeclared,
182+
MergeStrategy mergestrategy);
176183

177184
/// Lower a merge lattice to cases.
178185
virtual ir::Stmt lowerMergeCases(ir::Expr coordinate, IndexVar coordinateVar, IndexStmt stmt,
179186
MergeLattice lattice,
180-
const std::set<Access>& reducedAccesses);
187+
const std::set<Access>& reducedAccesses,
188+
MergeStrategy mergeStrategy);
181189

182190
/// Lower a forall loop body.
183191
virtual ir::Stmt lowerForallBody(ir::Expr coordinate, IndexStmt stmt,
184192
std::vector<Iterator> locaters,
185193
std::vector<Iterator> inserters,
186194
std::vector<Iterator> appenders,
187195
MergeLattice caseLattice,
188-
const std::set<Access>& reducedAccesses);
196+
const std::set<Access>& reducedAccesses,
197+
MergeStrategy mergeStrategy);
189198

190199

191200
/// Lower a where statement.
@@ -375,7 +384,7 @@ class LowererImplImperative : public LowererImpl {
375384

376385
/// Conditionally increment iterator position variables.
377386
ir::Stmt codeToIncIteratorVars(ir::Expr coordinate, IndexVar coordinateVar,
378-
std::vector<Iterator> iterators, std::vector<Iterator> mergers);
387+
std::vector<Iterator> iterators, std::vector<Iterator> mergers, MergeStrategy strategy);
379388

380389
ir::Stmt codeToLoadCoordinatesFromPosIterators(std::vector<Iterator> iterators, bool declVars);
381390

@@ -410,7 +419,8 @@ class LowererImplImperative : public LowererImpl {
410419
/// Lowers a merge lattice to cases assuming there are no more loops to be emitted in stmt.
411420
/// Will emit checks for explicit zeros for each mode iterator and each locator in the lattice.
412421
ir::Stmt lowerMergeCasesWithExplicitZeroChecks(ir::Expr coordinate, IndexVar coordinateVar, IndexStmt stmt,
413-
MergeLattice lattice, const std::set<Access>& reducedAccesses);
422+
MergeLattice lattice, const std::set<Access>& reducedAccesses,
423+
MergeStrategy mergeStrategy);
414424

415425
/// Constructs cases comparing the coordVar for each iterator to the resolved coordinate.
416426
/// Returns a vector where coordComparisons[i] corresponds to a case for iters[i]
@@ -444,7 +454,7 @@ class LowererImplImperative : public LowererImpl {
444454
/// The map must be of iterators to exprs of boolean types
445455
std::vector<ir::Stmt> lowerCasesFromMap(std::map<Iterator, ir::Expr> iteratorToCondition,
446456
ir::Expr coordinate, IndexStmt stmt, const MergeLattice& lattice,
447-
const std::set<Access>& reducedAccesses);
457+
const std::set<Access>& reducedAccesses, MergeStrategy mergeStrategy);
448458

449459
/// Constructs an expression which checks if this access is "zero"
450460
ir::Expr constructCheckForAccessZero(Access);

src/codegen/codegen_c.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,28 @@ const string cHeaders =
6262
"int cmp(const void *a, const void *b) {\n"
6363
" return *((const int*)a) - *((const int*)b);\n"
6464
"}\n"
65+
// Increment arrayStart until array[arrayStart] >= target or arrayStart >= arrayEnd
66+
// using an exponential search algorithm: https://en.wikipedia.org/wiki/Exponential_search.
67+
"int taco_gallop(int *array, int arrayStart, int arrayEnd, int target) {\n"
68+
" if (array[arrayStart] >= target || arrayStart >= arrayEnd) {\n"
69+
" return arrayStart;\n"
70+
" }\n"
71+
" int step = 1;\n"
72+
" int curr = arrayStart;\n"
73+
" while (curr + step < arrayEnd && array[curr + step] < target) {\n"
74+
" curr += step;\n"
75+
" step = step * 2;\n"
76+
" }\n"
77+
"\n"
78+
" step = step / 2;\n"
79+
" while (step > 0) {\n"
80+
" if (curr + step < arrayEnd && array[curr + step] < target) {\n"
81+
" curr += step;\n"
82+
" }\n"
83+
" step = step / 2;\n"
84+
" }\n"
85+
" return curr+1;\n"
86+
"}\n"
6587
"int taco_binarySearchAfter(int *array, int arrayStart, int arrayEnd, int target) {\n"
6688
" if (array[arrayStart] >= target) {\n"
6789
" return arrayStart;\n"

src/index_notation/index_notation.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1907,6 +1907,15 @@ IndexStmt IndexStmt::reorder(std::vector<IndexVar> reorderedvars) const {
19071907
return transformed;
19081908
}
19091909

1910+
IndexStmt IndexStmt::mergeby(IndexVar i, MergeStrategy strategy) const {
1911+
string reason;
1912+
IndexStmt transformed = SetMergeStrategy(i, strategy).apply(*this, &reason);
1913+
if (!transformed.defined()) {
1914+
taco_uerror << reason;
1915+
}
1916+
return transformed;
1917+
}
1918+
19101919
IndexStmt IndexStmt::parallelize(IndexVar i, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy) const {
19111920
string reason;
19121921
IndexStmt transformed = Parallelize(i, parallel_unit, output_race_strategy).apply(*this, &reason);
@@ -2017,7 +2026,7 @@ IndexStmt IndexStmt::unroll(IndexVar i, size_t unrollFactor) const {
20172026

20182027
void visit(const ForallNode* node) {
20192028
if (node->indexVar == i) {
2020-
stmt = Forall(i, rewrite(node->stmt), node->parallel_unit, node->output_race_strategy, unrollFactor);
2029+
stmt = Forall(i, rewrite(node->stmt), node->merge_strategy, node->parallel_unit, node->output_race_strategy, unrollFactor);
20212030
}
20222031
else {
20232032
IndexNotationRewriter::visit(node);
@@ -2125,11 +2134,11 @@ Forall::Forall(const ForallNode* n) : IndexStmt(n) {
21252134
}
21262135

21272136
Forall::Forall(IndexVar indexVar, IndexStmt stmt)
2128-
: Forall(indexVar, stmt, ParallelUnit::NotParallel, OutputRaceStrategy::IgnoreRaces) {
2137+
: Forall(indexVar, stmt, MergeStrategy::TwoFinger, ParallelUnit::NotParallel, OutputRaceStrategy::IgnoreRaces) {
21292138
}
21302139

2131-
Forall::Forall(IndexVar indexVar, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor)
2132-
: Forall(new ForallNode(indexVar, stmt, parallel_unit, output_race_strategy, unrollFactor)) {
2140+
Forall::Forall(IndexVar indexVar, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor)
2141+
: Forall(new ForallNode(indexVar, stmt, merge_strategy, parallel_unit, output_race_strategy, unrollFactor)) {
21332142
}
21342143

21352144
IndexVar Forall::getIndexVar() const {
@@ -2148,6 +2157,10 @@ OutputRaceStrategy Forall::getOutputRaceStrategy() const {
21482157
return getNode(*this)->output_race_strategy;
21492158
}
21502159

2160+
MergeStrategy Forall::getMergeStrategy() const {
2161+
return getNode(*this)->merge_strategy;
2162+
}
2163+
21512164
size_t Forall::getUnrollFactor() const {
21522165
return getNode(*this)->unrollFactor;
21532166
}
@@ -2156,8 +2169,8 @@ Forall forall(IndexVar i, IndexStmt stmt) {
21562169
return Forall(i, stmt);
21572170
}
21582171

2159-
Forall forall(IndexVar i, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor) {
2160-
return Forall(i, stmt, parallel_unit, output_race_strategy, unrollFactor);
2172+
Forall forall(IndexVar i, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor) {
2173+
return Forall(i, stmt, merge_strategy, parallel_unit, output_race_strategy, unrollFactor);
21612174
}
21622175

21632176
template <> bool isa<Forall>(IndexStmt s) {
@@ -3938,7 +3951,7 @@ struct Zero : public IndexNotationRewriterStrict {
39383951
stmt = op;
39393952
}
39403953
else {
3941-
stmt = new ForallNode(op->indexVar, body, op->parallel_unit, op->output_race_strategy, op->unrollFactor);
3954+
stmt = new ForallNode(op->indexVar, body, op->merge_strategy, op->parallel_unit, op->output_race_strategy, op->unrollFactor);
39423955
}
39433956
}
39443957

src/index_notation/index_notation_rewriter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ void IndexNotationRewriter::visit(const ForallNode* op) {
185185
stmt = op;
186186
}
187187
else {
188-
stmt = new ForallNode(op->indexVar, s, op->parallel_unit, op->output_race_strategy, op->unrollFactor);
188+
stmt = new ForallNode(op->indexVar, s, op->merge_strategy, op->parallel_unit, op->output_race_strategy, op->unrollFactor);
189189
}
190190
}
191191

@@ -406,7 +406,7 @@ struct ReplaceIndexVars : public IndexNotationRewriter {
406406
stmt = op;
407407
}
408408
else {
409-
stmt = new ForallNode(iv, s, op->parallel_unit, op->output_race_strategy,
409+
stmt = new ForallNode(iv, s, op->merge_strategy, op->parallel_unit, op->output_race_strategy,
410410
op->unrollFactor);
411411
}
412412
}

0 commit comments

Comments
 (0)