Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit d7442e5

Browse files
Merge pull request #570 from facebookresearch/pr/tuple
use tuple_id related convenience functions
2 parents bb59391 + 33f51a0 commit d7442e5

File tree

3 files changed

+18
-20
lines changed

3 files changed

+18
-20
lines changed

tc/core/polyhedral/cuda/codegen.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ void emitRegisterAccess(
438438
void emitGlobalAccess(
439439
isl::multi_pw_aff access,
440440
const CodegenStatementContext& context) {
441-
LdgWrapper ldgWrapper(context, access.get_tuple_id(isl::dim_type::out));
441+
LdgWrapper ldgWrapper(context, access.get_range_tuple_id());
442442
emitAccess(access, context);
443443
}
444444
} // namespace
@@ -674,7 +674,7 @@ void emitMappedTensorAccess(
674674
auto access =
675675
makeMultiAffAccess(tensorId, subscripts, context); // MA :: D -> O
676676
auto promotion = promotionInfo.group->promotion(); // MA :: [S -> O] -> P
677-
promotion = promotion.set_tuple_id(isl::dim_type::out, promotionInfo.groupId);
677+
promotion = promotion.set_range_tuple_id(promotionInfo.groupId);
678678
auto iteratorMap = context.iteratorMap(); // PMA :: A -> D
679679
auto schedule =
680680
isl::map::from_union_map(promotionInfo.outerSchedule.intersect_domain(

tc/core/polyhedral/cuda/codegen.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ struct CodegenStatementContext : CodegenContext {
114114
return this->nodeInfoMap.at(astNodeId).build;
115115
}
116116
isl::id statementId() const {
117-
return this->iteratorMap().get_tuple_id(isl::dim_type::out);
117+
return this->iteratorMap().get_range_tuple_id();
118118
}
119119
isl::set domain() const {
120120
return isl::map::from(this->iteratorMap()).range();

tc/core/polyhedral/memory_promotion.cc

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ std::unique_ptr<TensorReferenceGroup> TensorReferenceGroup::makeSingleton(
8787
isl::map scopedAccess,
8888
AccessType type) {
8989
auto ref = std::unique_ptr<TensorReference>(new TensorReference);
90-
auto refId = scopedAccess.get_space().domain().unwrap().get_tuple_id(
91-
isl::dim_type::out);
90+
auto refId =
91+
scopedAccess.get_space().domain().unwrap().get_map_range_tuple_id();
9292
scopedAccess = scopedAccess.domain_factor_domain();
9393
ref->originalAccess = originalAccess.domain_factor_domain();
9494
ref->scopedAccess = scopedAccess;
@@ -306,7 +306,7 @@ void addSingletonReferenceGroups(
306306
continue;
307307
}
308308

309-
auto tensorId = a.get_tuple_id(isl::dim_type::out);
309+
auto tensorId = a.get_range_tuple_id();
310310
if (unapproximatable.count(tensorId) != 0) {
311311
continue;
312312
}
@@ -474,21 +474,20 @@ ScheduleTree* insertCopiesUnder(
474474
// Take the set of all tensor elements.
475475
auto tensorElements = tensorElementsSet(scop, tensorId);
476476

477-
auto promotion =
478-
isl::map(group.promotion()).set_tuple_id(isl::dim_type::out, groupId);
477+
auto promotion = isl::map(group.promotion()).set_range_tuple_id(groupId);
479478
auto promotionSpace = promotion.get_space();
480479

481480
auto identityCopySchedule =
482481
isl::multi_aff::identity(promotionSpace.range().map_from_set());
483-
identityCopySchedule =
484-
identityCopySchedule.pullback(isl::multi_aff::range_map(promotionSpace));
485482
// Only iterate over significant tensor dimensions.
486483
auto decl = scop.promotedDecl(groupId);
487484
identityCopySchedule = dropDummyTensorDimensions(identityCopySchedule, decl);
488-
auto readSchedule = isl::multi_union_pw_aff(
489-
identityCopySchedule.set_tuple_id(isl::dim_type::in, readId));
490-
auto writeSchedule = isl::multi_union_pw_aff(
491-
identityCopySchedule.set_tuple_id(isl::dim_type::in, writeId));
485+
auto readSpace = promotionSpace.wrap().set_set_tuple_id(readId);
486+
auto writeSpace = promotionSpace.wrap().set_set_tuple_id(writeId);
487+
auto readSchedule = isl::multi_union_pw_aff(identityCopySchedule.pullback(
488+
isl::multi_aff::wrapped_range_map(readSpace)));
489+
auto writeSchedule = isl::multi_union_pw_aff(identityCopySchedule.pullback(
490+
isl::multi_aff::wrapped_range_map(writeSpace)));
492491

493492
auto readBandNode = ScheduleTree::makeBand(readSchedule);
494493
auto writeBandNode = ScheduleTree::makeBand(writeSchedule);
@@ -508,18 +507,17 @@ ScheduleTree* insertCopiesUnder(
508507
auto promotedFootprint = group.promotedFootprint().set_tuple_id(groupId);
509508
auto scheduleUniverse =
510509
isl::set::universe(promotionSpace.domain().unwrap().domain());
511-
auto arrayId =
512-
promotionSpace.domain().unwrap().get_tuple_id(isl::dim_type::out);
510+
auto arrayId = promotionSpace.domain().unwrap().get_map_range_tuple_id();
513511
auto approximatedRead =
514512
group.approximateScopedAccesses().intersect_range(tensorElements).wrap();
515513
approximatedRead = approximatedRead.product(promotedFootprint);
516-
auto readExtension = extension.intersect_range(approximatedRead)
517-
.set_tuple_id(isl::dim_type::out, readId);
514+
auto readExtension =
515+
extension.intersect_range(approximatedRead).set_range_tuple_id(readId);
518516
auto writtenElements =
519517
group.scopedWrites().intersect_range(tensorElements).wrap();
520518
writtenElements = writtenElements.product(promotedFootprint);
521-
auto writeExtension = extension.intersect_range(writtenElements)
522-
.set_tuple_id(isl::dim_type::out, writeId);
519+
auto writeExtension =
520+
extension.intersect_range(writtenElements).set_range_tuple_id(writeId);
523521

524522
auto readFilterNode = ScheduleTree::makeFilter(
525523
isl::set::universe(readExtension.get_space().range()),

0 commit comments

Comments
 (0)