@@ -549,6 +549,10 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
549549 // / Takes any index notation and concretizes unknowns to make it concrete notation
550550 IndexStmt concretize () const ;
551551
552+ // / Takes any index notation and concretizes unknowns to make it concrete notation
553+ // / given a Provenance Graph of indexVars
554+ IndexStmt concretizeScheduled (ProvenanceGraph provGraph, std::vector<IndexVar> forallIndexVarList) const ;
555+
552556 // / The \code{split} transformation splits (strip-mines) an index
553557 // / variable into two nested index variables, where the size of the
554558 // / inner index variable is constant. The size of the outer index
@@ -681,6 +685,12 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
681685 // / reorder computations to increase locality
682686 IndexStmt precompute (IndexExpr expr, IndexVar i, IndexVar iw, TensorVar workspace) const ;
683687
688+ // / The precompute transformation is described in kjolstad2019
689+ // / allows us to leverage scratchpad memories and
690+ // / reorder computations to increase locality
691+ IndexStmt precompute (IndexExpr expr, std::vector<IndexVar> i_vars,
692+ std::vector<IndexVar> iw_vars, TensorVar workspace) const ;
693+
684694 // / bound specifies a compile-time constraint on an index variable's
685695 // / iteration space that allows knowledge of the
686696 // / size or structured sparsity pattern of the inputs to be
@@ -1119,6 +1129,10 @@ bool isEinsumNotation(IndexStmt, std::string* reason=nullptr);
11191129// / notation is printed to.
11201130bool isReductionNotation (IndexStmt, std::string* reason=nullptr );
11211131
1132+ // / Check whether the statement is in the reduction index notation dialect
1133+ // / given a schedule described by the Provenance Graph
1134+ bool isReductionNotationScheduled (IndexStmt, ProvenanceGraph, std::string* reason=nullptr );
1135+
11221136// / Check whether the statement is in the concrete index notation dialect.
11231137// / This means every index variable has a forall node, there are no reduction
11241138// / nodes, and that every reduction variable use is nested inside a compound
@@ -1136,6 +1150,18 @@ IndexStmt makeReductionNotation(IndexStmt);
11361150// / as needed.
11371151IndexStmt makeConcreteNotation (IndexStmt);
11381152
1153+
1154+ // / Convert einsum notation to reduction notation, by applying Einstein's
1155+ // / summation convention to sum non-free/reduction variables over their term
1156+ // / taking into account a schedule given by the Provenance Graph.
1157+ Assignment makeReductionNotationScheduled (Assignment, ProvenanceGraph);
1158+ IndexStmt makeReductionNotationScheduled (IndexStmt, ProvenanceGraph);
1159+
1160+ // / Convert reduction notation to concrete notation, by inserting forall nodes,
1161+ // / replacing reduction nodes by compound assignments, and inserting temporaries
1162+ // / as needed while taking into account a schedule given by the Provenance Graph.
1163+ IndexStmt makeConcreteNotationScheduled (IndexStmt, ProvenanceGraph, std::vector<IndexVar> forallIndexVars);
1164+
11391165// / Returns the results of the index statement, in the order they appear.
11401166std::vector<TensorVar> getResults (IndexStmt stmt);
11411167
0 commit comments