Skip to content

Commit d0654a8

Browse files
authored
Merge pull request #508 from tensor-compiler/master_array_algebra
Merge Master array algebra into Master
2 parents 7cd2f29 + 0b38d98 commit d0654a8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

72 files changed

+5556
-346
lines changed

include/taco.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "taco/tensor.h"
55
#include "taco/format.h"
6+
#include "taco/index_notation/tensor_operator.h"
67
#include "taco/index_notation/index_notation.h"
78

89
#endif

include/taco/index_notation/index_notation.h

Lines changed: 79 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
#include <set>
1010
#include <map>
1111
#include <utility>
12+
#include <functional>
1213

14+
#include "taco/util/name_generator.h"
1315
#include "taco/format.h"
1416
#include "taco/error.h"
1517
#include "taco/util/intrusive_ptr.h"
@@ -21,6 +23,7 @@
2123
#include "taco/index_notation/index_notation_nodes_abstract.h"
2224
#include "taco/ir_tags.h"
2325
#include "taco/index_notation/provenance_graph.h"
26+
#include "taco/index_notation/properties.h"
2427

2528
namespace taco {
2629

@@ -39,6 +42,8 @@ class IndexExpr;
3942
class Assignment;
4043
class Access;
4144

45+
class IterationAlgebra;
46+
4247
struct AccessNode;
4348
struct IndexVarIterationModifier;
4449
struct LiteralNode;
@@ -49,8 +54,10 @@ struct SubNode;
4954
struct MulNode;
5055
struct DivNode;
5156
struct CastNode;
57+
struct CallNode;
5258
struct CallIntrinsicNode;
5359
struct ReductionNode;
60+
struct IndexVarNode;
5461

5562
struct AssignmentNode;
5663
struct YieldNode;
@@ -231,7 +238,7 @@ class Access : public IndexExpr {
231238
Access() = default;
232239
Access(const Access&) = default;
233240
Access(const AccessNode*);
234-
Access(const TensorVar& tensorVar, const std::vector<IndexVar>& indices={},
241+
Access(const TensorVar& tensorVar, const std::vector<IndexVar>& indices={},
235242
const std::map<int, std::shared_ptr<IndexVarIterationModifier>>& modifiers={},
236243
bool isAccessingStructure=false);
237244

@@ -296,6 +303,11 @@ class Access : public IndexExpr {
296303
Assignment operator+=(const IndexExpr&);
297304

298305
typedef AccessNode Node;
306+
307+
// Equality and comparison are overridden on Access to perform a deep
308+
// comparison of the access rather than a pointer check.
309+
friend bool operator==(const Access& a, const Access& b);
310+
friend bool operator<(const Access& a, const Access &b);
299311
};
300312

301313

@@ -323,11 +335,14 @@ class Literal : public IndexExpr {
323335
Literal(std::complex<float>);
324336
Literal(std::complex<double>);
325337

326-
static IndexExpr zero(Datatype);
338+
static Literal zero(Datatype);
327339

328340
/// Returns the literal value.
329341
template <typename T> T getVal() const;
330342

343+
/// Returns an untyped pointer to the literal value
344+
void* getValPtr();
345+
331346
typedef LiteralNode Node;
332347
};
333348

@@ -447,6 +462,26 @@ class Cast : public IndexExpr {
447462
typedef CastNode Node;
448463
};
449464

465+
/// A call to an operator
466+
class Call: public IndexExpr {
467+
public:
468+
Call() = default;
469+
Call(const CallNode*);
470+
Call(const CallNode*, std::string name);
471+
472+
const std::vector<IndexExpr>& getArgs() const;
473+
const std::function<ir::Expr(const std::vector<ir::Expr>&)> getFunc() const;
474+
const IterationAlgebra& getAlgebra() const;
475+
const std::vector<Property>& getProperties() const;
476+
const std::string getName() const;
477+
const std::map<std::vector<int>, std::function<ir::Expr(const std::vector<ir::Expr>&)>> getDefs() const;
478+
const std::vector<int>& getDefinedArgs() const;
479+
480+
typedef CallNode Node;
481+
482+
private:
483+
std::string name;
484+
};
450485

451486
/// A call to an intrinsic.
452487
/// ```
@@ -467,6 +502,8 @@ class CallIntrinsic : public IndexExpr {
467502
typedef CallIntrinsicNode Node;
468503
};
469504

505+
std::ostream& operator<<(std::ostream&, const IndexVar&);
506+
470507
/// Create calls to various intrinsics.
471508
IndexExpr mod(IndexExpr, IndexExpr);
472509
IndexExpr abs(IndexExpr);
@@ -982,17 +1019,27 @@ class IndexSetVar : public util::Comparable<IndexSetVar>, public IndexVarInterfa
9821019

9831020
/// Index variables are used to index into tensors in index expressions, and
9841021
/// they represent iteration over the tensor modes they index into.
985-
class IndexVar : public util::Comparable<IndexVar>, public IndexVarInterface {
1022+
class IndexVar : public IndexExpr, public IndexVarInterface {
1023+
9861024
public:
9871025
IndexVar();
9881026
~IndexVar() = default;
9891027
IndexVar(const std::string& name);
1028+
IndexVar(const std::string& name, const Datatype& type);
1029+
IndexVar(const IndexVarNode *);
9901030

9911031
/// Returns the name of the index variable.
9921032
std::string getName() const;
9931033

1034+
// Need these to overshadow the comparisons in for the IndexExpr instrusive pointer
9941035
friend bool operator==(const IndexVar&, const IndexVar&);
9951036
friend bool operator<(const IndexVar&, const IndexVar&);
1037+
friend bool operator!=(const IndexVar&, const IndexVar&);
1038+
friend bool operator>=(const IndexVar&, const IndexVar&);
1039+
friend bool operator<=(const IndexVar&, const IndexVar&);
1040+
friend bool operator>(const IndexVar&, const IndexVar&);
1041+
1042+
typedef IndexVarNode Node;
9961043

9971044
/// Indexing into an IndexVar returns a window into it.
9981045
WindowedIndexVar operator()(int lo, int hi, int stride = 1);
@@ -1049,11 +1096,12 @@ SuchThat suchthat(IndexStmt stmt, std::vector<IndexVarRel> predicate);
10491096
class TensorVar : public util::Comparable<TensorVar> {
10501097
public:
10511098
TensorVar();
1052-
TensorVar(const Type& type);
1053-
TensorVar(const std::string& name, const Type& type);
1054-
TensorVar(const Type& type, const Format& format);
1055-
TensorVar(const std::string& name, const Type& type, const Format& format);
1056-
TensorVar(const int &id, const std::string& name, const Type& type, const Format& format);
1099+
TensorVar(const Type& type, const Literal& fill = Literal());
1100+
TensorVar(const std::string& name, const Type& type, const Literal& fill = Literal());
1101+
TensorVar(const Type& type, const Format& format, const Literal& fill = Literal());
1102+
TensorVar(const std::string& name, const Type& type, const Format& format, const Literal& fill = Literal());
1103+
TensorVar(const int &id, const std::string& name, const Type& type, const Format& format,
1104+
const Literal& fill = Literal());
10571105

10581106
/// Returns the ID of the tensor variable.
10591107
int getId() const;
@@ -1074,6 +1122,12 @@ class TensorVar : public util::Comparable<TensorVar> {
10741122
/// and execute it's expression.
10751123
const Schedule& getSchedule() const;
10761124

1125+
/// Gets the fill value of the tensor variable. May be left undefined.
1126+
const Literal& getFill() const;
1127+
1128+
/// Set the fill value of the tensor variable
1129+
void setFill(const Literal& fill);
1130+
10771131
/// Set the name of the tensor variable.
10781132
void setName(std::string name);
10791133

@@ -1134,7 +1188,8 @@ bool isReductionNotation(IndexStmt, std::string* reason=nullptr);
11341188
bool isReductionNotationScheduled(IndexStmt, ProvenanceGraph, std::string* reason=nullptr);
11351189

11361190
/// Check whether the statement is in the concrete index notation dialect.
1137-
/// This means every index variable has a forall node, there are no reduction
1191+
/// This means every index variable has a forall node, each index variable used
1192+
/// for computation is under a forall node for that variable, there are no reduction
11381193
/// nodes, and that every reduction variable use is nested inside a compound
11391194
/// assignment statement. You can optionally pass in a pointer to a string
11401195
/// that the reason why it is not concrete notation is printed to.
@@ -1168,7 +1223,12 @@ std::vector<TensorVar> getResults(IndexStmt stmt);
11681223
/// Returns the input tensors to the index statement, in the order they appear.
11691224
std::vector<TensorVar> getArguments(IndexStmt stmt);
11701225

1171-
/// Returns the temporaries in the index statement, in the order they appear.
1226+
/// Returns true iff all of the loops over free variables come before all of the loops over
1227+
/// reduction variables. Therefore, this returns true if the reduction controlled by the loops
1228+
/// does not a scatter.
1229+
bool allForFreeLoopsBeforeAllReductionLoops(IndexStmt stmt);
1230+
1231+
/// Returns the temporaries in the index statement, in the order they appear.
11721232
std::vector<TensorVar> getTemporaries(IndexStmt stmt);
11731233

11741234
/// Returns the attribute query results in the index statement, in the order
@@ -1220,7 +1280,15 @@ IndexExpr zero(IndexExpr, const std::set<Access>& zeroed);
12201280
/// zero and then propagating and removing zeroes.
12211281
IndexStmt zero(IndexStmt, const std::set<Access>& zeroed);
12221282

1223-
/// Create an `other` tensor with the given name and format,
1283+
/// Infers the fill value of the input expression by applying properties if possible. If unable
1284+
/// to successfully infer the fill value of the result, returns the empty IndexExpr
1285+
IndexExpr inferFill(IndexExpr);
1286+
1287+
/// Returns true if there are no forall nodes in the indexStmt. Used to check
1288+
/// if the last loop is being lowered.
1289+
bool hasNoForAlls(IndexStmt);
1290+
1291+
/// Create an `other` tensor with the given name and format,
12241292
/// and return tensor(indexVars) = other(indexVars) if otherIsOnRight,
12251293
/// and otherwise returns other(indexVars) = tensor(indexVars).
12261294
IndexStmt generatePackStmt(TensorVar tensor,

include/taco/index_notation/index_notation_nodes.h

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,21 @@
44
#include <vector>
55
#include <memory>
66
#include <functional>
7+
#include <numeric>
8+
#include <functional>
79

10+
#include "taco/type.h"
11+
#include "taco/util/collections.h"
12+
#include "taco/util/comparable.h"
813
#include "taco/type.h"
914
#include "taco/tensor.h"
1015
#include "taco/index_notation/index_notation.h"
1116
#include "taco/index_notation/index_notation_nodes_abstract.h"
1217
#include "taco/index_notation/index_notation_visitor.h"
1318
#include "taco/index_notation/intrinsic.h"
1419
#include "taco/util/strings.h"
20+
#include "iteration_algebra.h"
21+
#include "properties.h"
1522

1623
namespace taco {
1724

@@ -55,6 +62,15 @@ struct AccessWindow : IndexVarIterationModifier {
5562
friend bool operator==(const AccessWindow& a, const AccessWindow& b) {
5663
return a.lo == b.lo && a.hi == b.hi && a.stride == b.stride;
5764
}
65+
friend bool operator<(const AccessWindow& a, const AccessWindow& b) {
66+
if (a.lo != b.lo) {
67+
return a.lo < b.lo;
68+
}
69+
if (a.hi != b.hi) {
70+
return a.hi < b.hi;
71+
}
72+
return a.stride < b.stride;
73+
}
5874
};
5975

6076
// An AccessNode also carries the information about an index set for an IndexVar +
@@ -68,10 +84,16 @@ struct IndexSet : IndexVarIterationModifier {
6884
friend bool operator==(const IndexSet& a, const IndexSet& b) {
6985
return *a.set == *b.set && a.tensor == b.tensor;
7086
}
87+
friend bool operator<(const IndexSet& a, const IndexSet& b) {
88+
if (*a.set < *b.set) {
89+
return *a.set < *b.set;
90+
}
91+
return a.tensor < b.tensor;
92+
}
7193
};
7294

7395
struct AccessNode : public IndexExprNode {
74-
AccessNode(TensorVar tensorVar, const std::vector<IndexVar>& indices,
96+
AccessNode(TensorVar tensorVar, const std::vector<IndexVar>& indices,
7597
const std::map<int, std::shared_ptr<IndexVarIterationModifier>> &modifiers,
7698
bool isAccessingStructure)
7799
: IndexExprNode(isAccessingStructure ? Bool : tensorVar.getType().getDataType()),
@@ -143,7 +165,6 @@ struct LiteralNode : public IndexExprNode {
143165
void* val;
144166
};
145167

146-
147168
struct UnaryExprNode : public IndexExprNode {
148169
IndexExpr a;
149170

@@ -263,6 +284,57 @@ struct CallIntrinsicNode : public IndexExprNode {
263284
std::vector<IndexExpr> args;
264285
};
265286

287+
struct CallNode : public IndexExprNode {
288+
typedef std::function<ir::Expr(const std::vector<ir::Expr>&)> OpImpl;
289+
typedef std::function<IterationAlgebra(const std::vector<IndexExpr>&)> AlgebraImpl;
290+
291+
CallNode(std::string name, const std::vector<IndexExpr>& args, OpImpl lowerFunc,
292+
const IterationAlgebra& iterAlg,
293+
const std::vector<Property>& properties,
294+
const std::map<std::vector<int>, OpImpl>& regionDefinitions,
295+
const std::vector<int>& definedRegions);
296+
297+
CallNode(std::string name, const std::vector<IndexExpr>& args, OpImpl lowerFunc,
298+
const IterationAlgebra& iterAlg,
299+
const std::vector<Property>& properties,
300+
const std::map<std::vector<int>, OpImpl>& regionDefinitions);
301+
302+
void accept(IndexExprVisitorStrict* v) const {
303+
v->visit(this);
304+
}
305+
306+
std::string name;
307+
std::vector<IndexExpr> args;
308+
OpImpl defaultLowerFunc;
309+
IterationAlgebra iterAlg;
310+
std::vector<Property> properties;
311+
std::map<std::vector<int>, OpImpl> regionDefinitions;
312+
313+
// Needed to track which inputs have been exhausted so the lowerer can know which lower func to use
314+
std::vector<int> definedRegions;
315+
316+
private:
317+
static Datatype inferReturnType(OpImpl f, const std::vector<IndexExpr>& inputs) {
318+
std::function<ir::Expr(IndexExpr)> getExprs = [](IndexExpr arg) { return ir::Var::make("t", arg.getDataType()); };
319+
std::vector<ir::Expr> exprs = util::map(inputs, getExprs);
320+
321+
if(exprs.empty()) {
322+
return taco::Datatype();
323+
}
324+
325+
return f(exprs).type();
326+
}
327+
328+
static std::vector<int> definedIndices(std::vector<IndexExpr> args) {
329+
std::vector<int> v;
330+
for(int i = 0; i < (int) args.size(); ++i) {
331+
if(args[i].defined()) {
332+
v.push_back(i);
333+
}
334+
}
335+
return v;
336+
}
337+
};
266338

267339
struct ReductionNode : public IndexExprNode {
268340
ReductionNode(IndexExpr op, IndexVar var, IndexExpr a);
@@ -277,6 +349,27 @@ struct ReductionNode : public IndexExprNode {
277349
IndexExpr a;
278350
};
279351

352+
struct IndexVarNode : public IndexExprNode, public util::Comparable<IndexVarNode> {
353+
IndexVarNode() = delete;
354+
IndexVarNode(const std::string& name, const Datatype& type);
355+
356+
void accept(IndexExprVisitorStrict* v) const {
357+
v->visit(this);
358+
}
359+
360+
std::string getName() const;
361+
362+
friend bool operator==(const IndexVarNode& a, const IndexVarNode& b);
363+
friend bool operator<(const IndexVarNode& a, const IndexVarNode& b);
364+
365+
private:
366+
struct Content;
367+
std::shared_ptr<Content> content;
368+
};
369+
370+
struct IndexVarNode::Content {
371+
std::string name;
372+
};
280373

281374
// Index Statements
282375
struct AssignmentNode : public IndexStmtNode {

include/taco/index_notation/index_notation_printer.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@ class IndexNotationPrinter : public IndexNotationVisitorStrict {
2525
void visit(const MulNode*);
2626
void visit(const DivNode*);
2727
void visit(const CastNode*);
28+
void visit(const CallNode*);
2829
void visit(const CallIntrinsicNode*);
2930
void visit(const ReductionNode*);
31+
void visit(const IndexVarNode*);
3032

3133
// Tensor Expressions
3234
void visit(const AssignmentNode*);

0 commit comments

Comments
 (0)