@@ -306,8 +306,11 @@ struct Scop {
306306 void promoteEverythingAt (std::vector<size_t > pos);
307307
308308 struct PromotedDecl {
309+ enum class Kind { SharedMem, Register };
310+
309311 isl::id tensorId;
310312 std::vector<size_t > sizes;
313+ Kind kind;
311314 };
312315
313316 struct PromotionInfo {
@@ -321,9 +324,8 @@ struct Scop {
321324 return promotedDecls_;
322325 }
323326
324- const std::
325- unordered_map<isl::id, std::vector<PromotionInfo>, isl::IslIdIslHash>&
326- activePromotions () const {
327+ const std::vector<std::pair<isl::union_set, PromotionInfo>>&
328+ activePromotions () const {
327329 return activePromotions_;
328330 }
329331
@@ -356,7 +358,8 @@ struct Scop {
356358 // Assumes such argument exists.
357359 const Halide::OutputImageParam& findArgument (isl::id id) const ;
358360
359- // Promote a tensor reference group to shared memory, inserting the copy
361+ // Promote a tensor reference group to a storage of a given "kind",
362+ // inserting the copy
360363 // statements below the given node. Inserts an Extension node below the give
361364 // node, unless there is already another Extension node which introduces
362365 // copies. The Extension node has a unique Sequence child, whose children
@@ -368,11 +371,11 @@ struct Scop {
368371 // If "forceLastExtentOdd" is set, the last extent in the declaration is
369372 // incremented if it is even. This serves as a simple heuristic to reduce
370373 // shared memory bank conflicts.
371- void promoteGroupToShared (
374+ void promoteGroup (
375+ PromotedDecl::Kind kind,
372376 isl::id tensorId,
373377 std::unique_ptr<TensorReferenceGroup>&& gr,
374378 detail::ScheduleTree* tree,
375- const std::unordered_set<isl::id, isl::IslIdIslHash>& activeStmts,
376379 isl::union_map schedule,
377380 bool forceLastExtentOdd = false );
378381
@@ -463,9 +466,10 @@ struct Scop {
463466 std::unordered_map<isl::id, size_t , isl::IslIdIslHash> groupCounts_;
464467 // groupId -> (tensorId, groupSizes)
465468 std::unordered_map<isl::id, PromotedDecl, isl::IslIdIslHash> promotedDecls_;
466- // stmtId -> (group, partial schedule, groupId)
467- std::unordered_map<isl::id, std::vector<PromotionInfo>, isl::IslIdIslHash>
468- activePromotions_;
469+ // (domain, group, partial schedule, groupId)
470+ // Note that domain is a non-unique key, i.e. multiple groups can be listed
471+ // for the same domain, or for partially intersecting domains.
472+ std::vector<std::pair<isl::union_set, PromotionInfo>> activePromotions_;
469473};
470474
471475std::ostream& operator <<(std::ostream& os, const Scop&);
0 commit comments