Skip to content

Commit 4738f56

Browse files
committed
Fix precompute multidim indexvar renaming and add multidim indexvars to parser
1 parent e1e0506 commit 4738f56

File tree

6 files changed

+199
-51
lines changed

6 files changed

+199
-51
lines changed

include/taco/parser/schedule_parser.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,13 @@ namespace parser {
1212
// [ [ "reorder", "i", "j" ], [ "precompute", "D(i,j)*E(j,k)", "j", "j_pre" ] ]
1313
std::vector<std::vector<std::string>> ScheduleParser(const std::string);
1414

15+
std::vector<std::string> varListParser(const std::string);
16+
1517
// serialize the result of a parse (for debugging)
1618
std::string serializeParsedSchedule(std::vector<std::vector<std::string>>);
1719

20+
21+
1822
}}
1923

2024
#endif //TACO_EINSUM_PARSER_H

src/ir/workspace_rewriter.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,14 @@ struct WorkspaceRewriter : ir::IRRewriter {
2929
string gpName = temp.getName() + to_string(op->mode + 1) + "_dimension";
3030

3131
if (temp.defined() && gpName == op->name) {
32-
taco_iassert(temporarySizeMap.find(temp) != temporarySizeMap.end()) << "Cannot rewrite workspace Dimension GetProperty "
33-
"due to tensorVar not in expression map";
32+
taco_iassert(temporarySizeMap.find(temp) != temporarySizeMap.end()) << "Cannot rewrite workspace Dimension "
33+
"GetProperty due to tensorVar not in expression map";
3434
auto tempExprList = temporarySizeMap.at(temp);
3535

36-
taco_iassert((int)tempExprList.size() > op->mode) << "Cannot rewrite workspace Dimension GetProperty "
37-
"due to mode not in expresison map ";
36+
taco_iassert((int)tempExprList.size() > op->mode) << "Cannot rewrite workspace (" << op->tensor <<
37+
") Dimension GetProperty "
38+
"due to mode (" << op->mode << ") not in expression map (size = "
39+
<< tempExprList.size() << ")";
3840
expr = tempExprList.at(op->mode);
3941
return;
4042
}

src/lower/lowerer_impl_imperative.cpp

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1800,7 +1800,8 @@ Stmt LowererImplImperative::lowerForallBody(Expr coordinate, IndexStmt stmt,
18001800

18011801
Expr LowererImplImperative::getTemporarySize(Where where) {
18021802
TensorVar temporary = where.getTemporary();
1803-
Dimension temporarySize = temporary.getType().getShape().getDimension(0);
1803+
int temporaryOrder = temporary.getType().getShape().getOrder();
1804+
18041805
Access temporaryAccess = getResultAccesses(where.getProducer()).first[0];
18051806
std::vector<IndexVar> indexVars = temporaryAccess.getIndexVars();
18061807

@@ -1820,20 +1821,28 @@ Expr LowererImplImperative::getTemporarySize(Where where) {
18201821
return size;
18211822
}
18221823

1823-
if (temporarySize.isFixed()) {
1824-
auto size = ir::Literal::make(temporarySize.getSize());
1825-
temporarySizeMap[temporary] = {size};
1826-
return size;
1827-
}
1824+
vector<Expr> sizeVector;
1825+
Expr finalSize;
1826+
for (int i = 0; i < temporaryOrder; i++) {
1827+
Dimension temporarySize = temporary.getType().getShape().getDimension(i);
1828+
Expr size;
1829+
if (temporarySize.isFixed()) {
1830+
size = ir::Literal::make(temporarySize.getSize());
18281831

1829-
if (temporarySize.isIndexVarSized()) {
1830-
IndexVar var = temporarySize.getIndexVarSize();
1831-
vector<Expr> bounds = provGraph.deriveIterBounds(var, definedIndexVarsOrdered, underivedBounds,
1832-
indexVarToExprMap, iterators);
1833-
auto size = ir::Sub::make(bounds[1], bounds[0]);
1834-
temporarySizeMap[temporary] = {size};
1835-
return size;
1832+
} else if (temporarySize.isIndexVarSized()) {
1833+
IndexVar var = temporarySize.getIndexVarSize();
1834+
vector<Expr> bounds = provGraph.deriveIterBounds(var, definedIndexVarsOrdered, underivedBounds,
1835+
indexVarToExprMap, iterators);
1836+
size = ir::Sub::make(bounds[1], bounds[0]);
1837+
}
1838+
sizeVector.push_back(size);
1839+
if (i == 0)
1840+
finalSize = size;
1841+
else
1842+
finalSize = ir::Mul::make(finalSize, size);
18361843
}
1844+
temporarySizeMap[temporary] = sizeVector;
1845+
return finalSize;
18371846

18381847
taco_ierror; // TODO
18391848
return Expr();
@@ -2077,8 +2086,8 @@ vector<Stmt> LowererImplImperative::codeToInitializeTemporaryParallel(Where wher
20772086
values = ir::Var::make(temporaryAll.getName(),
20782087
temporaryAll.getType().getDataType(),
20792088
true, false);
2080-
taco_iassert(temporaryAll.getType().getOrder() == 1) << " Temporary order was "
2081-
<< temporaryAll.getType().getOrder(); // TODO
2089+
// taco_iassert(temporaryAll.getType().getOrder() == 1) << " Temporary order was "
2090+
// << temporaryAll.getType().getOrder(); // TODO
20822091
Expr size = getTemporarySize(where);
20832092
Expr sizeAll = ir::Mul::make(size, ir::Call::make("omp_get_max_threads", {}, size.type()));
20842093

src/parser/schedule_parser.cpp

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ vector<vector<string>> ScheduleParser(const string argValue) {
2828
parser::Lexer lexer(argValue);
2929
parser::Token tok;
3030
parenthesesCnt = 0;
31+
int curlyParenthesesCnt = 0;
32+
3133
for(tok = lexer.getToken(); tok != parser::Token::eot; tok = lexer.getToken()) {
3234
switch(tok) {
3335
case parser::Token::lparen:
@@ -49,7 +51,10 @@ vector<vector<string>> ScheduleParser(const string argValue) {
4951
parenthesesCnt--;
5052
break;
5153
case parser::Token::comma:
52-
if(parenthesesCnt == 0) {
54+
if (curlyParenthesesCnt > 0) {
55+
// multiple indexes inside of a {} list; pass it through
56+
current_element += lexer.tokenString(tok);
57+
} else if(parenthesesCnt == 0) {
5358
// new schedule directive
5459
current_schedule.push_back(current_element);
5560
parsed.push_back(current_schedule);
@@ -65,6 +70,17 @@ vector<vector<string>> ScheduleParser(const string argValue) {
6570
break;
6671
}
6772
break;
73+
case parser::Token::lcurly:
74+
// Keep track of curly brackets for list arguments
75+
current_element += lexer.tokenString(tok);
76+
curlyParenthesesCnt++;
77+
break;
78+
case parser::Token::rcurly:
79+
taco_uassert(curlyParenthesesCnt > 0) << "mismatched curly parentheses (too many right-curly-parens, "
80+
"negative nesting level) in schedule expression '" << argValue << "'";
81+
current_element += lexer.tokenString(tok);
82+
curlyParenthesesCnt--;
83+
break;
6884
// things where .getIdentifier() makes sense
6985
case parser::Token::identifier:
7086
case parser::Token::int_scalar:
@@ -87,6 +103,71 @@ vector<vector<string>> ScheduleParser(const string argValue) {
87103
return parsed;
88104
}
89105

106+
/// Parses command line lists for the scheduling directive 'precompute(expr, i_vars, iw_vars)'
107+
/// The lists are used for i_vars and iw_vars
108+
vector<string> varListParser(const string argValue) {
109+
vector<string> parsed;
110+
string current_element;
111+
parser::Lexer lexer(argValue);
112+
parser::Token tok;
113+
int curlyParenthesesCnt = 0;
114+
115+
for(tok = lexer.getToken(); tok != parser::Token::eot; tok = lexer.getToken()) {
116+
switch(tok) {
117+
case parser::Token::comma:
118+
if (curlyParenthesesCnt > 0) {
119+
// multiple indexes inside of a {} list; pass it through
120+
parsed.push_back(current_element);
121+
current_element = "";
122+
} else {
123+
// probably multiple indexes inside of an IndexExpr; pass it through
124+
current_element += lexer.tokenString(tok);
125+
break;
126+
}
127+
break;
128+
case parser::Token::lcurly:
129+
// Keep track of curly brackets for list arguments
130+
current_element = "";
131+
curlyParenthesesCnt++;
132+
break;
133+
case parser::Token::rcurly:
134+
taco_uassert(curlyParenthesesCnt > 0) << "mismatched curly parentheses (too many right-curly-parens, "
135+
"negative nesting level) in schedule expression '" << argValue << "'";
136+
if (curlyParenthesesCnt == 1) {
137+
parsed.push_back(current_element);
138+
current_element = "";
139+
}
140+
curlyParenthesesCnt--;
141+
break;
142+
case parser::Token::lparen:
143+
// ignore parenthesis
144+
break;
145+
case parser::Token::rparen:
146+
// ignore parenthesis
147+
break;
148+
// things where .getIdentifier() makes sense
149+
case parser::Token::identifier:
150+
case parser::Token::int_scalar:
151+
case parser::Token::uint_scalar:
152+
case parser::Token::float_scalar:
153+
case parser::Token::complex_scalar:
154+
current_element += lexer.getIdentifier();
155+
break;
156+
// .tokenstring() works for the remaining cases
157+
default:
158+
current_element += lexer.tokenString(tok);
159+
break;
160+
}
161+
}
162+
taco_uassert(curlyParenthesesCnt == 0) << "imbalanced curly brackets (too few right-curly brackets) in"
163+
" schedule expression '" << argValue << "'";
164+
if(current_element.length() > 0)
165+
parsed.push_back(current_element);
166+
return parsed;
167+
}
168+
169+
170+
90171
string serializeParsedSchedule(vector<vector<string>> parsed) {
91172
std::stringstream ss;
92173
ss << "[ ";

test/tests-workspaces.cpp

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ TEST(workspaces, tile_vecElemMul_NoTail) {
3636
.split(i_bounded, i0, i1, 4)
3737
.precompute(precomputedExpr, i1, i1, precomputed);
3838

39-
// cout << stmt << endl;
40-
4139
A.compile(stmt);
4240
A.assemble();
4341
A.compute();
@@ -159,16 +157,11 @@ TEST(workspaces, tile_denseMatMul) {
159157

160158
IndexStmt stmt = A.getAssignment().concretize();
161159
TensorVar precomputed("precomputed", Type(Float64, {Dimension(i1)}), taco::dense);
162-
cout << "------STMT1------" << endl;
163160
stmt = stmt.bound(i, i_bounded, 16, BoundType::MaxExact)
164161
.split(i_bounded, i0, i1, 4);
165-
cout << stmt << endl;
166162

167-
cout << "------STMT2------" << endl;
168163
stmt = stmt.precompute(precomputedExpr, i1, i1, precomputed);
169164

170-
cout << stmt << endl;
171-
172165
A.compile(stmt.concretize());
173166
A.assemble();
174167
A.compute();
@@ -313,7 +306,7 @@ TEST(workspaces, precompute4D_multireduce) {
313306
ASSERT_TENSOR_EQ(A, expected);
314307
}
315308

316-
TEST(workspaces, precompute3D_MspV) {
309+
TEST(workspaces, precompute3D_TspV) {
317310
int N = 16;
318311
Tensor<double> A("A", {N, N}, Format{Dense, Dense});
319312
Tensor<double> B("B", {N, N, N, N}, Format{Dense, Dense, Dense, Dense});
@@ -339,8 +332,7 @@ TEST(workspaces, precompute3D_MspV) {
339332
TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N, (size_t)N}), Format{Dense, Dense, Dense});
340333
stmt = stmt.precompute(precomputedExpr, {i, j, k}, {i, j, k}, ws);
341334
stmt = stmt.concretize();
342-
cout << "----------STMT----------" << endl;
343-
cout << stmt << endl;
335+
344336
A.compile(stmt);
345337
A.assemble();
346338
A.compute();
@@ -380,14 +372,10 @@ TEST(workspaces, precompute3D_multipleWS) {
380372
IndexStmt stmt = A.getAssignment().concretize();
381373
TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N, (size_t)N}), Format{Dense, Dense, Dense});
382374
TensorVar t("t", Type(Float64, {(size_t) N, (size_t)N}), Format{Dense, Dense});
383-
cout << "----------STMT1----------" << endl;
384375
stmt = stmt.precompute(precomputedExpr, {i, j, k}, {i, j, k}, ws);
385-
cout << stmt << endl;
386376

387-
cout << "----------STMT2----------" << endl;
388377
stmt = stmt.precompute(ws(i, j, k) * c(k), {i, j}, {i, j}, t);
389378
stmt = stmt.concretize();
390-
cout << stmt << endl;
391379

392380
A.compile(stmt);
393381
A.assemble();
@@ -402,3 +390,45 @@ TEST(workspaces, precompute3D_multipleWS) {
402390

403391
}
404392

393+
// Precomputes B(i,j,k,l) * c(l) into a 3-D dense workspace whose index
// variables are renamed (i,j,k -> iw,jw,kw), then checks the scheduled result
// against the same expression computed without the precompute.
TEST(workspaces, precompute3D_renamedIVars_TspV) {
  int N = 16;
  Tensor<double> A("A", {N, N}, Format{Dense, Dense});
  Tensor<double> B("B", {N, N, N, N}, Format{Dense, Dense, Dense, Dense});
  Tensor<double> c("c", {N}, Format{Sparse});

  // Fill c with c(p) = p and B with B(p,q,r,s) = p + q.
  for (int p = 0; p < N; p++) {
    c.insert({p}, (double) p);
    for (int q = 0; q < N; q++) {
      for (int r = 0; r < N; r++) {
        for (int s = 0; s < N; s++) {
          B.insert({p, q, r, s}, (double) p + q);
        }
      }
    }
  }

  IndexVar i("i"), j("j"), k("k"), l("l");
  IndexExpr precomputedExpr = B(i, j, k, l) * c(l);
  A(i, j) = precomputedExpr * c(k);

  // Workspace tensor with one dense mode per precomputed index variable.
  TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N, (size_t)N}),
               Format{Dense, Dense, Dense});

  // Index variables used on the workspace side instead of i, j, k.
  IndexVar iw("iw"), jw("jw"), kw("kw");

  IndexStmt stmt = A.getAssignment().concretize();
  stmt = stmt.precompute(precomputedExpr, {i, j, k}, {iw, jw, kw}, ws);
  stmt = stmt.concretize();

  A.compile(stmt);
  A.assemble();
  A.compute();

  // Reference result: same expression, default schedule.
  Tensor<double> expected("expected", {N, N}, Format{Dense, Dense});
  expected(i, j) = (B(i, j, k, l) * c(l)) * c(k);
  expected.compile();
  expected.assemble();
  expected.compute();

  ASSERT_TENSOR_EQ(A, expected);
}

tools/taco.cpp

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -383,23 +383,37 @@ static bool setSchedulingCommands(vector<vector<string>> scheduleCommands, parse
383383
stmt = stmt.divide(findVar(i), divide1, divide2, divideFactor);
384384
} else if (command == "precompute") {
385385
string exprStr, i, iw, name;
386+
vector<string> i_vars, iw_vars;
387+
388+
for (auto& s : scheduleCommand) { cout << s << ", ";}
389+
cout << endl;
386390
taco_uassert(scheduleCommand.size() == 3 || scheduleCommand.size() == 4)
387391
<< "'precompute' scheduling directive takes 3 or 4 parameters: "
388-
<< "precompute(expr, i, iw [, workspace_name])";
392+
<< "precompute(expr, i, iw [, workspace_name]) or precompute(expr, {i_vars}, {iw_vars} [, workspace_name])"
393+
<< scheduleCommand.size();
394+
389395
exprStr = scheduleCommand[0];
390-
i = scheduleCommand[1];
391-
iw = scheduleCommand[2];
396+
// i = scheduleCommand[1];
397+
// iw = scheduleCommand[2];
398+
i_vars = parser::varListParser(scheduleCommand[1]);
399+
iw_vars = parser::varListParser(scheduleCommand[2]);
400+
392401
if (scheduleCommand.size() == 4)
393402
name = scheduleCommand[3];
394403
else
395404
name = "workspace";
396405

397-
IndexVar orig = findVar(i);
398-
IndexVar pre;
399-
try {
400-
pre = findVar(iw);
401-
} catch (TacoException &e) {
402-
pre = IndexVar(iw);
406+
vector<IndexVar> origs;
407+
vector<IndexVar> pres;
408+
for (auto& i : i_vars) {
409+
origs.push_back(findVar(i));
410+
}
411+
for (auto& iw : iw_vars) {
412+
try {
413+
pres.push_back(findVar(iw));
414+
} catch (TacoException &e) {
415+
pres.push_back(IndexVar(iw));
416+
}
403417
}
404418

405419
struct GetExpr : public IndexNotationVisitor {
@@ -456,17 +470,25 @@ static bool setSchedulingCommands(vector<vector<string>> scheduleCommands, parse
456470
visitor.setExprStr(exprStr);
457471
stmt.accept(&visitor);
458472

459-
Dimension dim;
473+
vector<Dimension> dims;
460474
auto domains = stmt.getIndexVarDomains();
461-
auto it = domains.find(orig);
462-
if (it != domains.end()) {
463-
dim = it->second;
464-
} else {
465-
dim = Dimension(orig);
475+
for (auto& orig : origs) {
476+
auto it = domains.find(orig);
477+
if (it != domains.end()) {
478+
dims.push_back(it->second);
479+
} else {
480+
dims.push_back(Dimension(orig));
481+
}
466482
}
467483

468-
TensorVar workspace(name, Type(Float64, {dim}), Dense);
469-
stmt = stmt.precompute(visitor.expr, orig, pre, workspace);
484+
std::vector<ModeFormatPack> modeFormatPacks(dims.size(), Dense);
485+
Format format(modeFormatPacks);
486+
TensorVar workspace(name, Type(Float64, dims), format);
487+
cout << "Parser ORDER: " << dims.size() << endl;
488+
for (auto& d : dims) {cout << d << ", ";}
489+
cout << endl;
490+
491+
stmt = stmt.precompute(visitor.expr, origs, pres, workspace);
470492

471493
} else if (command == "reorder") {
472494
taco_uassert(scheduleCommand.size() > 1) << "'reorder' scheduling directive needs at least 2 parameters: reorder(outermost, ..., innermost)";

0 commit comments

Comments
 (0)