99#include < set>
1010#include < map>
1111#include < utility>
12+ #include < functional>
1213
14+ #include " taco/util/name_generator.h"
1315#include " taco/format.h"
1416#include " taco/error.h"
1517#include " taco/util/intrusive_ptr.h"
2123#include " taco/index_notation/index_notation_nodes_abstract.h"
2224#include " taco/ir_tags.h"
2325#include " taco/index_notation/provenance_graph.h"
26+ #include " taco/index_notation/properties.h"
2427
2528namespace taco {
2629
@@ -39,6 +42,8 @@ class IndexExpr;
3942class Assignment ;
4043class Access ;
4144
45+ class IterationAlgebra ;
46+
4247struct AccessNode ;
4348struct IndexVarIterationModifier ;
4449struct LiteralNode ;
@@ -49,8 +54,10 @@ struct SubNode;
4954struct MulNode ;
5055struct DivNode ;
5156struct CastNode ;
57+ struct CallNode ;
5258struct CallIntrinsicNode ;
5359struct ReductionNode ;
60+ struct IndexVarNode ;
5461
5562struct AssignmentNode ;
5663struct YieldNode ;
@@ -231,7 +238,7 @@ class Access : public IndexExpr {
231238 Access () = default ;
232239 Access (const Access&) = default ;
233240 Access (const AccessNode*);
234- Access (const TensorVar& tensorVar, const std::vector<IndexVar>& indices={},
241+ Access (const TensorVar& tensorVar, const std::vector<IndexVar>& indices={},
235242 const std::map<int , std::shared_ptr<IndexVarIterationModifier>>& modifiers={},
236243 bool isAccessingStructure=false );
237244
@@ -296,6 +303,11 @@ class Access : public IndexExpr {
296303 Assignment operator +=(const IndexExpr&);
297304
298305 typedef AccessNode Node;
306+
307+ // Equality and comparison are overridden on Access to perform a deep
308+ // comparison of the access rather than a pointer check.
309+ friend bool operator ==(const Access& a, const Access& b);
310+ friend bool operator <(const Access& a, const Access &b);
299311};
300312
301313
@@ -323,11 +335,14 @@ class Literal : public IndexExpr {
323335 Literal (std::complex <float >);
324336 Literal (std::complex <double >);
325337
326- static IndexExpr zero (Datatype);
338+ static Literal zero (Datatype);
327339
328340 // / Returns the literal value.
329341 template <typename T> T getVal () const ;
330342
343+ // / Returns an untyped pointer to the literal value
344+ void * getValPtr ();
345+
331346 typedef LiteralNode Node;
332347};
333348
@@ -447,6 +462,26 @@ class Cast : public IndexExpr {
447462 typedef CastNode Node;
448463};
449464
465+ // / A call to an operator
466+ class Call : public IndexExpr {
467+ public:
468+ Call () = default ;
469+ Call (const CallNode*);
470+ Call (const CallNode*, std::string name);
471+
472+ const std::vector<IndexExpr>& getArgs () const ;
473+ const std::function<ir::Expr(const std::vector<ir::Expr>&)> getFunc () const ;
474+ const IterationAlgebra& getAlgebra () const ;
475+ const std::vector<Property>& getProperties () const ;
476+ const std::string getName () const ;
477+ const std::map<std::vector<int >, std::function<ir::Expr(const std::vector<ir::Expr>&)>> getDefs () const ;
478+ const std::vector<int >& getDefinedArgs () const ;
479+
480+ typedef CallNode Node;
481+
482+ private:
483+ std::string name;
484+ };
450485
451486// / A call to an intrinsic.
452487// / ```
@@ -467,6 +502,8 @@ class CallIntrinsic : public IndexExpr {
467502 typedef CallIntrinsicNode Node;
468503};
469504
505+ std::ostream& operator <<(std::ostream&, const IndexVar&);
506+
470507// / Create calls to various intrinsics.
471508IndexExpr mod (IndexExpr, IndexExpr);
472509IndexExpr abs (IndexExpr);
@@ -982,17 +1019,27 @@ class IndexSetVar : public util::Comparable<IndexSetVar>, public IndexVarInterfa
9821019
9831020// / Index variables are used to index into tensors in index expressions, and
9841021// / they represent iteration over the tensor modes they index into.
985- class IndexVar : public util ::Comparable<IndexVar>, public IndexVarInterface {
1022+ class IndexVar : public IndexExpr , public IndexVarInterface {
1023+
9861024public:
9871025 IndexVar ();
9881026 ~IndexVar () = default ;
9891027 IndexVar (const std::string& name);
1028+ IndexVar (const std::string& name, const Datatype& type);
1029+ IndexVar (const IndexVarNode *);
9901030
9911031 // / Returns the name of the index variable.
9921032 std::string getName () const ;
9931033
1034+ // Need these to overshadow the comparisons in for the IndexExpr instrusive pointer
9941035 friend bool operator ==(const IndexVar&, const IndexVar&);
9951036 friend bool operator <(const IndexVar&, const IndexVar&);
1037+ friend bool operator !=(const IndexVar&, const IndexVar&);
1038+ friend bool operator >=(const IndexVar&, const IndexVar&);
1039+ friend bool operator <=(const IndexVar&, const IndexVar&);
1040+ friend bool operator >(const IndexVar&, const IndexVar&);
1041+
1042+ typedef IndexVarNode Node;
9961043
9971044 // / Indexing into an IndexVar returns a window into it.
9981045 WindowedIndexVar operator ()(int lo, int hi, int stride = 1 );
@@ -1049,11 +1096,12 @@ SuchThat suchthat(IndexStmt stmt, std::vector<IndexVarRel> predicate);
10491096class TensorVar : public util ::Comparable<TensorVar> {
10501097public:
10511098 TensorVar ();
1052- TensorVar (const Type& type);
1053- TensorVar (const std::string& name, const Type& type);
1054- TensorVar (const Type& type, const Format& format);
1055- TensorVar (const std::string& name, const Type& type, const Format& format);
1056- TensorVar (const int &id, const std::string& name, const Type& type, const Format& format);
1099+ TensorVar (const Type& type, const Literal& fill = Literal());
1100+ TensorVar (const std::string& name, const Type& type, const Literal& fill = Literal());
1101+ TensorVar (const Type& type, const Format& format, const Literal& fill = Literal());
1102+ TensorVar (const std::string& name, const Type& type, const Format& format, const Literal& fill = Literal());
1103+ TensorVar (const int &id, const std::string& name, const Type& type, const Format& format,
1104+ const Literal& fill = Literal());
10571105
10581106 // / Returns the ID of the tensor variable.
10591107 int getId () const ;
@@ -1074,6 +1122,12 @@ class TensorVar : public util::Comparable<TensorVar> {
10741122 // / and execute it's expression.
10751123 const Schedule& getSchedule () const ;
10761124
1125+ // / Gets the fill value of the tensor variable. May be left undefined.
1126+ const Literal& getFill () const ;
1127+
1128+ // / Set the fill value of the tensor variable
1129+ void setFill (const Literal& fill);
1130+
10771131 // / Set the name of the tensor variable.
10781132 void setName (std::string name);
10791133
@@ -1134,7 +1188,8 @@ bool isReductionNotation(IndexStmt, std::string* reason=nullptr);
11341188bool isReductionNotationScheduled (IndexStmt, ProvenanceGraph, std::string* reason=nullptr );
11351189
11361190// / Check whether the statement is in the concrete index notation dialect.
1137- // / This means every index variable has a forall node, there are no reduction
1191+ // / This means every index variable has a forall node, each index variable used
1192+ // / for computation is under a forall node for that variable, there are no reduction
11381193// / nodes, and that every reduction variable use is nested inside a compound
11391194// / assignment statement. You can optionally pass in a pointer to a string
11401195// / that the reason why it is not concrete notation is printed to.
@@ -1168,7 +1223,12 @@ std::vector<TensorVar> getResults(IndexStmt stmt);
11681223// / Returns the input tensors to the index statement, in the order they appear.
11691224std::vector<TensorVar> getArguments (IndexStmt stmt);
11701225
1171- // / Returns the temporaries in the index statement, in the order they appear.
1226+ // / Returns true iff all of the loops over free variables come before all of the loops over
1227+ // / reduction variables. Therefore, this returns true if the reduction controlled by the loops
1228+ // / does not a scatter.
1229+ bool allForFreeLoopsBeforeAllReductionLoops (IndexStmt stmt);
1230+
1231+ // / Returns the temporaries in the index statement, in the order they appear.
11721232std::vector<TensorVar> getTemporaries (IndexStmt stmt);
11731233
11741234// / Returns the attribute query results in the index statement, in the order
@@ -1220,7 +1280,15 @@ IndexExpr zero(IndexExpr, const std::set<Access>& zeroed);
12201280// / zero and then propagating and removing zeroes.
12211281IndexStmt zero (IndexStmt, const std::set<Access>& zeroed);
12221282
1223- // / Create an `other` tensor with the given name and format,
1283+ // / Infers the fill value of the input expression by applying properties if possible. If unable
1284+ // / to successfully infer the fill value of the result, returns the empty IndexExpr
1285+ IndexExpr inferFill (IndexExpr);
1286+
1287+ // / Returns true if there are no forall nodes in the indexStmt. Used to check
1288+ // / if the last loop is being lowered.
1289+ bool hasNoForAlls (IndexStmt);
1290+
1291+ // / Create an `other` tensor with the given name and format,
12241292// / and return tensor(indexVars) = other(indexVars) if otherIsOnRight,
12251293// / and otherwise returns other(indexVars) = tensor(indexVars).
12261294IndexStmt generatePackStmt (TensorVar tensor,
0 commit comments