Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 6577d4f

Browse files
Merge pull request #355 from facebookresearch/pr/copies_under
insertCopiesUnder: drop iteration over tensor dimensions of size 1 up front
2 parents ea23712 + d984f4e commit 6577d4f

File tree

5 files changed, with 46 additions and 31 deletions.

tc/core/polyhedral/cuda/codegen.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -592,8 +592,7 @@ void emitMappedTensorAccess(
592592
return;
593593
}
594594

595-
auto tensorId =
596-
context.scop().promotedDecls().at(promotionInfo.groupId).tensorId;
595+
auto tensorId = context.scop().promotedDecl(promotionInfo.groupId).tensorId;
597596

598597
// Here and below in comments: D = domain, O = original tensor, P = promoted
599598
// tensor, S = partial schedule, A = AST loops;

tc/core/polyhedral/cuda/memory_promotion_heuristic.cc

Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -99,41 +99,18 @@ void mapCopiesToThreads(MappedScop& mscop, bool unroll) {
9999

100100
// Map band dimensions to threads, in inverse order since the last member
101101
// iterates over the last subscript and is likely to result in coalescing.
102-
// Step over band members that iterate over size-1 arrays subscripts as
103-
// they would have been executed by a single thread.
104102
// If not all available thread ids are used, fix remaining to 1 thread.
105-
auto filter = node->elemAs<ScheduleTreeElemFilter>()->filter_;
106-
auto filterSets = isl::UnionAsVector<isl::union_set>(filter);
107-
size_t t = 0;
108-
for (int i = band->nMember() - 1;
109-
i >= 0 && t < mscop.numThreads.view.size();
110-
--i) {
111-
auto skip = std::all_of(
112-
filterSets.begin(), filterSets.end(), [&mscop, i](isl::set s) {
113-
auto groupId =
114-
s.get_space().unwrap().get_tuple_id(isl::dim_type::out);
115-
if (mscop.scop().promotedDecls().count(groupId) != 1) {
116-
std::stringstream ss;
117-
ss << "promoted group " << groupId << " has no declaration";
118-
throw promotion::PromotionLogicError(ss.str());
119-
}
120-
auto decl = mscop.scop().promotedDecls().at(groupId);
121-
return static_cast<size_t>(i) >= decl.sizes.size() ||
122-
decl.sizes[i] == 1;
123-
});
124-
if (skip) {
125-
continue;
126-
}
127-
103+
auto nToMap = std::min(band->nMember(), mscop.numThreads.view.size());
104+
for (size_t t = 0; t < nToMap; ++t) {
105+
auto pos = band->nMember() - 1 - t;
128106
mapToParameterWithExtent(
129107
root,
130108
bandNode,
131-
i,
109+
pos,
132110
mapping::ThreadId::makeId(t),
133111
mscop.numThreads.view[t]);
134-
++t;
135112
}
136-
mscop.mapRemaining<mapping::ThreadId>(bandNode, t);
113+
mscop.mapRemaining<mapping::ThreadId>(bandNode, nToMap);
137114

138115
// Unroll if requested.
139116
if (unroll) {

tc/core/polyhedral/memory_promotion.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,30 @@ isl::set tensorElementsSet(const Scop& scop, isl::id tensorId) {
422422
}
423423
return tensorElements;
424424
}
425+
426+
/*
427+
* "schedule" iterates over the elements of the tensor described by "decl".
428+
* Remove the schedule dimensions that correspond to tensor dimensions
429+
* of size 1.
430+
* Note that this function drops the name of the target space of "schedule",
431+
* but this space is irrelevant for the caller.
*
* The result is rebuilt from the domain space of "schedule" and the
* remaining affine expressions, with one output dimension per expression
* that was kept.
*
* NOTE(review): "decl.sizes" is indexed by every position of the affine
* expression list without a bounds check; this presumably relies on
* "schedule" having exactly one output dimension per entry of "decl.sizes".
* Confirm against the caller.
432+
*/
433+
isl::multi_aff dropDummyTensorDimensions(
434+
isl::multi_aff schedule,
435+
const Scop::PromotedDecl& decl) {
436+
auto list = schedule.get_aff_list();
437+
// Keep only the domain; the range is re-created below with the new size.
auto space = schedule.get_space().domain();
438+
439+
auto n = list.n();
440+
// Visit positions from last to first so that dropping an entry does not
// shift the positions that are still to be visited.
for (int i = n - 1; i >= 0; --i) {
441+
if (decl.sizes[i] == 1) {
442+
list = list.drop(i, 1);
443+
}
444+
}
445+
446+
// Build an (unnamed) target space with as many dimensions as entries kept.
space = space.from_domain().add_dims(isl::dim_type::out, list.n());
447+
return isl::multi_aff(space, list);
448+
}
425449
} // namespace
426450

427451
ScheduleTree* insertCopiesUnder(
@@ -449,6 +473,9 @@ ScheduleTree* insertCopiesUnder(
449473
isl::multi_aff::identity(promotionSpace.range().map_from_set());
450474
identityCopySchedule =
451475
identityCopySchedule.pullback(isl::multi_aff::range_map(promotionSpace));
476+
// Only iterate over significant tensor dimensions.
477+
auto decl = scop.promotedDecl(groupId);
478+
identityCopySchedule = dropDummyTensorDimensions(identityCopySchedule, decl);
452479
auto readSchedule = isl::multi_union_pw_aff(
453480
identityCopySchedule.set_tuple_id(isl::dim_type::in, readId));
454481
auto writeSchedule = isl::multi_union_pw_aff(

tc/core/polyhedral/scop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,12 @@ void Scop::promoteGroup(
201201
}
202202

203203
auto groupId = nextGroupIdForTensor(tensorId);
204-
insertCopiesUnder(*this, tree, *gr, tensorId, groupId);
205204
auto sizes = gr->approximationSizes();
206205
if (sizes.size() > 0 && forceLastExtentOdd && (sizes.back() % 2) == 0) {
207206
sizes.back() += 1;
208207
}
209208
promotedDecls_[groupId] = PromotedDecl{tensorId, sizes, kind};
209+
insertCopiesUnder(*this, tree, *gr, tensorId, groupId);
210210

211211
// FIXME: we can now store a unique pointer...
212212
auto group = std::shared_ptr<TensorReferenceGroup>(std::move(gr));

tc/core/polyhedral/scop.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include <functional>
1919
#include <memory>
20+
#include <sstream>
2021
#include <string>
2122
#include <unordered_map>
2223
#include <vector>
@@ -313,6 +314,17 @@ struct Scop {
313314
return promotedDecls_;
314315
}
315316

317+
// Return the promoted declaration information associated with
318+
// the given identifier of a promoted tensor reference group.
// The returned reference is the entry obtained from promotedDecls().
// Throws std::logic_error if promotedDecls() does not hold exactly one
// entry for "groupId".
// NOTE(review): std::logic_error is declared in <stdexcept>; only <sstream>
// is visibly added to the includes here -- confirm <stdexcept> is included.
319+
const PromotedDecl& promotedDecl(isl::id groupId) const {
320+
if (promotedDecls().count(groupId) != 1) {
321+
std::stringstream ss;
322+
ss << "promoted group " << groupId << " has no declaration";
323+
throw std::logic_error(ss.str());
324+
}
325+
return promotedDecls().at(groupId);
326+
}
327+
316328
const std::vector<std::pair<isl::union_set, PromotionInfo>>&
317329
activePromotions() const {
318330
return activePromotions_;

0 commit comments

Comments
 (0)