99#include < set>
1010#include < map>
1111#include < utility>
12+ #include < functional>
1213
14+ #include " taco/util/name_generator.h"
1315#include " taco/format.h"
1416#include " taco/error.h"
1517#include " taco/util/intrusive_ptr.h"
2123#include " taco/index_notation/index_notation_nodes_abstract.h"
2224#include " taco/ir_tags.h"
2325#include " taco/index_notation/provenance_graph.h"
26+ #include " taco/index_notation/properties.h"
2427
2528namespace taco {
2629
@@ -38,6 +41,8 @@ class IndexExpr;
3841class Assignment ;
3942class Access ;
4043
44+ class IterationAlgebra ;
45+
4146struct AccessNode ;
4247struct IndexVarIterationModifier ;
4348struct LiteralNode ;
@@ -48,8 +53,10 @@ struct SubNode;
4853struct MulNode ;
4954struct DivNode ;
5055struct CastNode ;
56+ struct CallNode ;
5157struct CallIntrinsicNode ;
5258struct ReductionNode ;
59+ struct IndexVarNode ;
5360
5461struct AssignmentNode ;
5562struct YieldNode ;
@@ -224,7 +231,7 @@ class Access : public IndexExpr {
224231 Access () = default ;
225232 Access (const Access&) = default ;
226233 Access (const AccessNode*);
227- Access (const TensorVar& tensorVar, const std::vector<IndexVar>& indices={},
234+ Access (const TensorVar& tensorVar, const std::vector<IndexVar>& indices={},
228235 const std::map<int , std::shared_ptr<IndexVarIterationModifier>>& modifiers={},
229236 bool isAccessingStructure=false );
230237
@@ -289,6 +296,11 @@ class Access : public IndexExpr {
289296 Assignment operator +=(const IndexExpr&);
290297
291298 typedef AccessNode Node;
299+
300+ // Equality and comparison are overridden on Access to perform a deep
301+ // comparison of the access rather than a pointer check.
302+ friend bool operator ==(const Access& a, const Access& b);
303+ friend bool operator <(const Access& a, const Access &b);
292304};
293305
294306
@@ -316,11 +328,14 @@ class Literal : public IndexExpr {
316328 Literal (std::complex <float >);
317329 Literal (std::complex <double >);
318330
319- static IndexExpr zero (Datatype);
331+ static Literal zero (Datatype);
320332
321333 // / Returns the literal value.
322334 template <typename T> T getVal () const ;
323335
336+ // / Returns an untyped pointer to the literal value
337+ void * getValPtr ();
338+
324339 typedef LiteralNode Node;
325340};
326341
@@ -440,6 +455,26 @@ class Cast : public IndexExpr {
440455 typedef CastNode Node;
441456};
442457
458+ // / A call to an operator
459+ class Call : public IndexExpr {
460+ public:
461+ Call () = default ;
462+ Call (const CallNode*);
463+ Call (const CallNode*, std::string name);
464+
465+ const std::vector<IndexExpr>& getArgs () const ;
466+ const std::function<ir::Expr(const std::vector<ir::Expr>&)> getFunc () const ;
467+ const IterationAlgebra& getAlgebra () const ;
468+ const std::vector<Property>& getProperties () const ;
469+ const std::string getName () const ;
470+ const std::map<std::vector<int >, std::function<ir::Expr(const std::vector<ir::Expr>&)>> getDefs () const ;
471+ const std::vector<int >& getDefinedArgs () const ;
472+
473+ typedef CallNode Node;
474+
475+ private:
476+ std::string name;
477+ };
443478
444479// / A call to an intrinsic.
445480// / ```
@@ -460,6 +495,8 @@ class CallIntrinsic : public IndexExpr {
460495 typedef CallIntrinsicNode Node;
461496};
462497
498+ std::ostream& operator <<(std::ostream&, const IndexVar&);
499+
463500// / Create calls to various intrinsics.
464501IndexExpr mod (IndexExpr, IndexExpr);
465502IndexExpr abs (IndexExpr);
@@ -951,17 +988,27 @@ class IndexSetVar : public util::Comparable<IndexSetVar>, public IndexVarInterfa
951988
952989// / Index variables are used to index into tensors in index expressions, and
953990// / they represent iteration over the tensor modes they index into.
954- class IndexVar : public util ::Comparable<IndexVar>, public IndexVarInterface {
991+ class IndexVar : public IndexExpr , public IndexVarInterface {
992+
955993public:
956994 IndexVar ();
957995 ~IndexVar () = default ;
958996 IndexVar (const std::string& name);
997+ IndexVar (const std::string& name, const Datatype& type);
998+ IndexVar (const IndexVarNode *);
959999
9601000 // / Returns the name of the index variable.
9611001 std::string getName () const ;
9621002
1003+ // Need these to overshadow the comparisons in for the IndexExpr instrusive pointer
9631004 friend bool operator ==(const IndexVar&, const IndexVar&);
9641005 friend bool operator <(const IndexVar&, const IndexVar&);
1006+ friend bool operator !=(const IndexVar&, const IndexVar&);
1007+ friend bool operator >=(const IndexVar&, const IndexVar&);
1008+ friend bool operator <=(const IndexVar&, const IndexVar&);
1009+ friend bool operator >(const IndexVar&, const IndexVar&);
1010+
1011+ typedef IndexVarNode Node;
9651012
9661013 // / Indexing into an IndexVar returns a window into it.
9671014 WindowedIndexVar operator ()(int lo, int hi, int stride = 1 );
@@ -1018,11 +1065,12 @@ SuchThat suchthat(IndexStmt stmt, std::vector<IndexVarRel> predicate);
10181065class TensorVar : public util ::Comparable<TensorVar> {
10191066public:
10201067 TensorVar ();
1021- TensorVar (const Type& type);
1022- TensorVar (const std::string& name, const Type& type);
1023- TensorVar (const Type& type, const Format& format);
1024- TensorVar (const std::string& name, const Type& type, const Format& format);
1025- TensorVar (const int &id, const std::string& name, const Type& type, const Format& format);
1068+ TensorVar (const Type& type, const Literal& fill = Literal());
1069+ TensorVar (const std::string& name, const Type& type, const Literal& fill = Literal());
1070+ TensorVar (const Type& type, const Format& format, const Literal& fill = Literal());
1071+ TensorVar (const std::string& name, const Type& type, const Format& format, const Literal& fill = Literal());
1072+ TensorVar (const int &id, const std::string& name, const Type& type, const Format& format,
1073+ const Literal& fill = Literal());
10261074
10271075 // / Returns the ID of the tensor variable.
10281076 int getId () const ;
@@ -1043,6 +1091,12 @@ class TensorVar : public util::Comparable<TensorVar> {
10431091 // / and execute it's expression.
10441092 const Schedule& getSchedule () const ;
10451093
1094+ // / Gets the fill value of the tensor variable. May be left undefined.
1095+ const Literal& getFill () const ;
1096+
1097+ // / Set the fill value of the tensor variable
1098+ void setFill (const Literal& fill);
1099+
10461100 // / Set the name of the tensor variable.
10471101 void setName (std::string name);
10481102
@@ -1099,7 +1153,8 @@ bool isEinsumNotation(IndexStmt, std::string* reason=nullptr);
10991153bool isReductionNotation (IndexStmt, std::string* reason=nullptr );
11001154
11011155// / Check whether the statement is in the concrete index notation dialect.
1102- // / This means every index variable has a forall node, there are no reduction
1156+ // / This means every index variable has a forall node, each index variable used
1157+ // / for computation is under a forall node for that variable, there are no reduction
11031158// / nodes, and that every reduction variable use is nested inside a compound
11041159// / assignment statement. You can optionally pass in a pointer to a string
11051160// / that the reason why it is not concrete notation is printed to.
@@ -1121,7 +1176,12 @@ std::vector<TensorVar> getResults(IndexStmt stmt);
11211176// / Returns the input tensors to the index statement, in the order they appear.
11221177std::vector<TensorVar> getArguments (IndexStmt stmt);
11231178
1124- // / Returns the temporaries in the index statement, in the order they appear.
1179+ // / Returns true iff all of the loops over free variables come before all of the loops over
1180+ // / reduction variables. Therefore, this returns true if the reduction controlled by the loops
1181+ // / does not a scatter.
1182+ bool allForFreeLoopsBeforeAllReductionLoops (IndexStmt stmt);
1183+
1184+ // / Returns the temporaries in the index statement, in the order they appear.
11251185std::vector<TensorVar> getTemporaries (IndexStmt stmt);
11261186
11271187// / Returns the attribute query results in the index statement, in the order
@@ -1173,7 +1233,15 @@ IndexExpr zero(IndexExpr, const std::set<Access>& zeroed);
11731233// / zero and then propagating and removing zeroes.
11741234IndexStmt zero (IndexStmt, const std::set<Access>& zeroed);
11751235
1176- // / Create an `other` tensor with the given name and format,
1236+ // / Infers the fill value of the input expression by applying properties if possible. If unable
1237+ // / to successfully infer the fill value of the result, returns the empty IndexExpr
1238+ IndexExpr inferFill (IndexExpr);
1239+
1240+ // / Returns true if there are no forall nodes in the indexStmt. Used to check
1241+ // / if the last loop is being lowered.
1242+ bool hasNoForAlls (IndexStmt);
1243+
1244+ // / Create an `other` tensor with the given name and format,
11771245// / and return tensor(indexVars) = other(indexVars) if otherIsOnRight,
11781246// / and otherwise returns other(indexVars) = tensor(indexVars).
11791247IndexStmt generatePackStmt (TensorVar tensor,
0 commit comments