Skip to content

Commit 2322259

Browse files
committed
Merge branch 'array_algebra' into master_array_algebra
2 parents 96031e5 + 3fc8a46 commit 2322259

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+5364
-320
lines changed

include/taco.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "taco/tensor.h"
55
#include "taco/format.h"
6+
#include "taco/index_notation/tensor_operator.h"
67
#include "taco/index_notation/index_notation.h"
78

89
#endif

include/taco/index_notation/index_notation.h

Lines changed: 79 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
#include <set>
1010
#include <map>
1111
#include <utility>
12+
#include <functional>
1213

14+
#include "taco/util/name_generator.h"
1315
#include "taco/format.h"
1416
#include "taco/error.h"
1517
#include "taco/util/intrusive_ptr.h"
@@ -21,6 +23,7 @@
2123
#include "taco/index_notation/index_notation_nodes_abstract.h"
2224
#include "taco/ir_tags.h"
2325
#include "taco/index_notation/provenance_graph.h"
26+
#include "taco/index_notation/properties.h"
2427

2528
namespace taco {
2629

@@ -38,6 +41,8 @@ class IndexExpr;
3841
class Assignment;
3942
class Access;
4043

44+
class IterationAlgebra;
45+
4146
struct AccessNode;
4247
struct IndexVarIterationModifier;
4348
struct LiteralNode;
@@ -48,8 +53,10 @@ struct SubNode;
4853
struct MulNode;
4954
struct DivNode;
5055
struct CastNode;
56+
struct CallNode;
5157
struct CallIntrinsicNode;
5258
struct ReductionNode;
59+
struct IndexVarNode;
5360

5461
struct AssignmentNode;
5562
struct YieldNode;
@@ -224,7 +231,7 @@ class Access : public IndexExpr {
224231
Access() = default;
225232
Access(const Access&) = default;
226233
Access(const AccessNode*);
227-
Access(const TensorVar& tensorVar, const std::vector<IndexVar>& indices={},
234+
Access(const TensorVar& tensorVar, const std::vector<IndexVar>& indices={},
228235
const std::map<int, std::shared_ptr<IndexVarIterationModifier>>& modifiers={},
229236
bool isAccessingStructure=false);
230237

@@ -289,6 +296,11 @@ class Access : public IndexExpr {
289296
Assignment operator+=(const IndexExpr&);
290297

291298
typedef AccessNode Node;
299+
300+
// Equality and comparison are overridden on Access to perform a deep
301+
// comparison of the access rather than a pointer check.
302+
friend bool operator==(const Access& a, const Access& b);
303+
friend bool operator<(const Access& a, const Access &b);
292304
};
293305

294306

@@ -316,11 +328,14 @@ class Literal : public IndexExpr {
316328
Literal(std::complex<float>);
317329
Literal(std::complex<double>);
318330

319-
static IndexExpr zero(Datatype);
331+
static Literal zero(Datatype);
320332

321333
/// Returns the literal value.
322334
template <typename T> T getVal() const;
323335

336+
/// Returns an untyped pointer to the literal value
337+
void* getValPtr();
338+
324339
typedef LiteralNode Node;
325340
};
326341

@@ -440,6 +455,26 @@ class Cast : public IndexExpr {
440455
typedef CastNode Node;
441456
};
442457

458+
/// A call to an operator
459+
class Call: public IndexExpr {
460+
public:
461+
Call() = default;
462+
Call(const CallNode*);
463+
Call(const CallNode*, std::string name);
464+
465+
const std::vector<IndexExpr>& getArgs() const;
466+
const std::function<ir::Expr(const std::vector<ir::Expr>&)> getFunc() const;
467+
const IterationAlgebra& getAlgebra() const;
468+
const std::vector<Property>& getProperties() const;
469+
const std::string getName() const;
470+
const std::map<std::vector<int>, std::function<ir::Expr(const std::vector<ir::Expr>&)>> getDefs() const;
471+
const std::vector<int>& getDefinedArgs() const;
472+
473+
typedef CallNode Node;
474+
475+
private:
476+
std::string name;
477+
};
443478

444479
/// A call to an intrinsic.
445480
/// ```
@@ -460,6 +495,8 @@ class CallIntrinsic : public IndexExpr {
460495
typedef CallIntrinsicNode Node;
461496
};
462497

498+
std::ostream& operator<<(std::ostream&, const IndexVar&);
499+
463500
/// Create calls to various intrinsics.
464501
IndexExpr mod(IndexExpr, IndexExpr);
465502
IndexExpr abs(IndexExpr);
@@ -951,17 +988,27 @@ class IndexSetVar : public util::Comparable<IndexSetVar>, public IndexVarInterfa
951988

952989
/// Index variables are used to index into tensors in index expressions, and
953990
/// they represent iteration over the tensor modes they index into.
954-
class IndexVar : public util::Comparable<IndexVar>, public IndexVarInterface {
991+
class IndexVar : public IndexExpr, public IndexVarInterface {
992+
955993
public:
956994
IndexVar();
957995
~IndexVar() = default;
958996
IndexVar(const std::string& name);
997+
IndexVar(const std::string& name, const Datatype& type);
998+
IndexVar(const IndexVarNode *);
959999

9601000
/// Returns the name of the index variable.
9611001
std::string getName() const;
9621002

1003+
// Need these to overshadow the comparisons in for the IndexExpr instrusive pointer
9631004
friend bool operator==(const IndexVar&, const IndexVar&);
9641005
friend bool operator<(const IndexVar&, const IndexVar&);
1006+
friend bool operator!=(const IndexVar&, const IndexVar&);
1007+
friend bool operator>=(const IndexVar&, const IndexVar&);
1008+
friend bool operator<=(const IndexVar&, const IndexVar&);
1009+
friend bool operator>(const IndexVar&, const IndexVar&);
1010+
1011+
typedef IndexVarNode Node;
9651012

9661013
/// Indexing into an IndexVar returns a window into it.
9671014
WindowedIndexVar operator()(int lo, int hi, int stride = 1);
@@ -1018,11 +1065,12 @@ SuchThat suchthat(IndexStmt stmt, std::vector<IndexVarRel> predicate);
10181065
class TensorVar : public util::Comparable<TensorVar> {
10191066
public:
10201067
TensorVar();
1021-
TensorVar(const Type& type);
1022-
TensorVar(const std::string& name, const Type& type);
1023-
TensorVar(const Type& type, const Format& format);
1024-
TensorVar(const std::string& name, const Type& type, const Format& format);
1025-
TensorVar(const int &id, const std::string& name, const Type& type, const Format& format);
1068+
TensorVar(const Type& type, const Literal& fill = Literal());
1069+
TensorVar(const std::string& name, const Type& type, const Literal& fill = Literal());
1070+
TensorVar(const Type& type, const Format& format, const Literal& fill = Literal());
1071+
TensorVar(const std::string& name, const Type& type, const Format& format, const Literal& fill = Literal());
1072+
TensorVar(const int &id, const std::string& name, const Type& type, const Format& format,
1073+
const Literal& fill = Literal());
10261074

10271075
/// Returns the ID of the tensor variable.
10281076
int getId() const;
@@ -1043,6 +1091,12 @@ class TensorVar : public util::Comparable<TensorVar> {
10431091
/// and execute it's expression.
10441092
const Schedule& getSchedule() const;
10451093

1094+
/// Gets the fill value of the tensor variable. May be left undefined.
1095+
const Literal& getFill() const;
1096+
1097+
/// Set the fill value of the tensor variable
1098+
void setFill(const Literal& fill);
1099+
10461100
/// Set the name of the tensor variable.
10471101
void setName(std::string name);
10481102

@@ -1099,7 +1153,8 @@ bool isEinsumNotation(IndexStmt, std::string* reason=nullptr);
10991153
bool isReductionNotation(IndexStmt, std::string* reason=nullptr);
11001154

11011155
/// Check whether the statement is in the concrete index notation dialect.
1102-
/// This means every index variable has a forall node, there are no reduction
1156+
/// This means every index variable has a forall node, each index variable used
1157+
/// for computation is under a forall node for that variable, there are no reduction
11031158
/// nodes, and that every reduction variable use is nested inside a compound
11041159
/// assignment statement. You can optionally pass in a pointer to a string
11051160
/// that the reason why it is not concrete notation is printed to.
@@ -1121,7 +1176,12 @@ std::vector<TensorVar> getResults(IndexStmt stmt);
11211176
/// Returns the input tensors to the index statement, in the order they appear.
11221177
std::vector<TensorVar> getArguments(IndexStmt stmt);
11231178

1124-
/// Returns the temporaries in the index statement, in the order they appear.
1179+
/// Returns true iff all of the loops over free variables come before all of the loops over
1180+
/// reduction variables. Therefore, this returns true if the reduction controlled by the loops
1181+
/// does not a scatter.
1182+
bool allForFreeLoopsBeforeAllReductionLoops(IndexStmt stmt);
1183+
1184+
/// Returns the temporaries in the index statement, in the order they appear.
11251185
std::vector<TensorVar> getTemporaries(IndexStmt stmt);
11261186

11271187
/// Returns the attribute query results in the index statement, in the order
@@ -1173,7 +1233,15 @@ IndexExpr zero(IndexExpr, const std::set<Access>& zeroed);
11731233
/// zero and then propagating and removing zeroes.
11741234
IndexStmt zero(IndexStmt, const std::set<Access>& zeroed);
11751235

1176-
/// Create an `other` tensor with the given name and format,
1236+
/// Infers the fill value of the input expression by applying properties if possible. If unable
1237+
/// to successfully infer the fill value of the result, returns the empty IndexExpr
1238+
IndexExpr inferFill(IndexExpr);
1239+
1240+
/// Returns true if there are no forall nodes in the indexStmt. Used to check
1241+
/// if the last loop is being lowered.
1242+
bool hasNoForAlls(IndexStmt);
1243+
1244+
/// Create an `other` tensor with the given name and format,
11771245
/// and return tensor(indexVars) = other(indexVars) if otherIsOnRight,
11781246
/// and otherwise returns other(indexVars) = tensor(indexVars).
11791247
IndexStmt generatePackStmt(TensorVar tensor,

include/taco/index_notation/index_notation_nodes.h

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,21 @@
44
#include <vector>
55
#include <memory>
66
#include <functional>
7+
#include <numeric>
8+
#include <functional>
79

10+
#include "taco/type.h"
11+
#include "taco/util/collections.h"
12+
#include "taco/util/comparable.h"
813
#include "taco/type.h"
914
#include "taco/tensor.h"
1015
#include "taco/index_notation/index_notation.h"
1116
#include "taco/index_notation/index_notation_nodes_abstract.h"
1217
#include "taco/index_notation/index_notation_visitor.h"
1318
#include "taco/index_notation/intrinsic.h"
1419
#include "taco/util/strings.h"
20+
#include "iteration_algebra.h"
21+
#include "properties.h"
1522

1623
namespace taco {
1724

@@ -55,6 +62,15 @@ struct AccessWindow : IndexVarIterationModifier {
5562
friend bool operator==(const AccessWindow& a, const AccessWindow& b) {
5663
return a.lo == b.lo && a.hi == b.hi && a.stride == b.stride;
5764
}
65+
friend bool operator<(const AccessWindow& a, const AccessWindow& b) {
66+
if (a.lo != b.lo) {
67+
return a.lo < b.lo;
68+
}
69+
if (a.hi != b.hi) {
70+
return a.hi < b.hi;
71+
}
72+
return a.stride < b.stride;
73+
}
5874
};
5975

6076
// An AccessNode also carries the information about an index set for an IndexVar +
@@ -68,10 +84,16 @@ struct IndexSet : IndexVarIterationModifier {
6884
friend bool operator==(const IndexSet& a, const IndexSet& b) {
6985
return *a.set == *b.set && a.tensor == b.tensor;
7086
}
87+
friend bool operator<(const IndexSet& a, const IndexSet& b) {
88+
if (*a.set < *b.set) {
89+
return *a.set < *b.set;
90+
}
91+
return a.tensor < b.tensor;
92+
}
7193
};
7294

7395
struct AccessNode : public IndexExprNode {
74-
AccessNode(TensorVar tensorVar, const std::vector<IndexVar>& indices,
96+
AccessNode(TensorVar tensorVar, const std::vector<IndexVar>& indices,
7597
const std::map<int, std::shared_ptr<IndexVarIterationModifier>> &modifiers,
7698
bool isAccessingStructure)
7799
: IndexExprNode(isAccessingStructure ? Bool : tensorVar.getType().getDataType()),
@@ -143,7 +165,6 @@ struct LiteralNode : public IndexExprNode {
143165
void* val;
144166
};
145167

146-
147168
struct UnaryExprNode : public IndexExprNode {
148169
IndexExpr a;
149170

@@ -263,6 +284,57 @@ struct CallIntrinsicNode : public IndexExprNode {
263284
std::vector<IndexExpr> args;
264285
};
265286

287+
struct CallNode : public IndexExprNode {
288+
typedef std::function<ir::Expr(const std::vector<ir::Expr>&)> OpImpl;
289+
typedef std::function<IterationAlgebra(const std::vector<IndexExpr>&)> AlgebraImpl;
290+
291+
CallNode(std::string name, const std::vector<IndexExpr>& args, OpImpl lowerFunc,
292+
const IterationAlgebra& iterAlg,
293+
const std::vector<Property>& properties,
294+
const std::map<std::vector<int>, OpImpl>& regionDefinitions,
295+
const std::vector<int>& definedRegions);
296+
297+
CallNode(std::string name, const std::vector<IndexExpr>& args, OpImpl lowerFunc,
298+
const IterationAlgebra& iterAlg,
299+
const std::vector<Property>& properties,
300+
const std::map<std::vector<int>, OpImpl>& regionDefinitions);
301+
302+
void accept(IndexExprVisitorStrict* v) const {
303+
v->visit(this);
304+
}
305+
306+
std::string name;
307+
std::vector<IndexExpr> args;
308+
OpImpl defaultLowerFunc;
309+
IterationAlgebra iterAlg;
310+
std::vector<Property> properties;
311+
std::map<std::vector<int>, OpImpl> regionDefinitions;
312+
313+
// Needed to track which inputs have been exhausted so the lowerer can know which lower func to use
314+
std::vector<int> definedRegions;
315+
316+
private:
317+
static Datatype inferReturnType(OpImpl f, const std::vector<IndexExpr>& inputs) {
318+
std::function<ir::Expr(IndexExpr)> getExprs = [](IndexExpr arg) { return ir::Var::make("t", arg.getDataType()); };
319+
std::vector<ir::Expr> exprs = util::map(inputs, getExprs);
320+
321+
if(exprs.empty()) {
322+
return taco::Datatype();
323+
}
324+
325+
return f(exprs).type();
326+
}
327+
328+
static std::vector<int> definedIndices(std::vector<IndexExpr> args) {
329+
std::vector<int> v;
330+
for(int i = 0; i < (int) args.size(); ++i) {
331+
if(args[i].defined()) {
332+
v.push_back(i);
333+
}
334+
}
335+
return v;
336+
}
337+
};
266338

267339
struct ReductionNode : public IndexExprNode {
268340
ReductionNode(IndexExpr op, IndexVar var, IndexExpr a);
@@ -277,6 +349,27 @@ struct ReductionNode : public IndexExprNode {
277349
IndexExpr a;
278350
};
279351

352+
struct IndexVarNode : public IndexExprNode, public util::Comparable<IndexVarNode> {
353+
IndexVarNode() = delete;
354+
IndexVarNode(const std::string& name, const Datatype& type);
355+
356+
void accept(IndexExprVisitorStrict* v) const {
357+
v->visit(this);
358+
}
359+
360+
std::string getName() const;
361+
362+
friend bool operator==(const IndexVarNode& a, const IndexVarNode& b);
363+
friend bool operator<(const IndexVarNode& a, const IndexVarNode& b);
364+
365+
private:
366+
struct Content;
367+
std::shared_ptr<Content> content;
368+
};
369+
370+
struct IndexVarNode::Content {
371+
std::string name;
372+
};
280373

281374
// Index Statements
282375
struct AssignmentNode : public IndexStmtNode {

include/taco/index_notation/index_notation_printer.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@ class IndexNotationPrinter : public IndexNotationVisitorStrict {
2525
void visit(const MulNode*);
2626
void visit(const DivNode*);
2727
void visit(const CastNode*);
28+
void visit(const CallNode*);
2829
void visit(const CallIntrinsicNode*);
2930
void visit(const ReductionNode*);
31+
void visit(const IndexVarNode*);
3032

3133
// Tensor Expressions
3234
void visit(const AssignmentNode*);

0 commit comments

Comments
 (0)