Skip to content

Commit fcb7d76

Browse files
committed
RequirementMachine: Adding shape abstractions
Adding abstractions to check terms for shape symbol and remove the shape symbol from the end of the sequence of symbols, rather than manually manipulating the end() sequence externally.
1 parent e50f3db commit fcb7d76

File tree

6 files changed

+39
-14
lines changed

6 files changed

+39
-14
lines changed

lib/AST/RequirementMachine/GenericSignatureQueries.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -755,16 +755,16 @@ RequirementMachine::getReducedShapeTerm(Type type) const {
755755
// Get the term T', which is the reduced shape of T.
756756
if (term.size() != 2 ||
757757
term[0].getKind() != Symbol::Kind::GenericParam ||
758-
term[1].getKind() != Symbol::Kind::Shape) {
758+
!term.hasShape()) {
759759
ABORT([&](auto &out) {
760760
out << "Invalid reduced shape\n";
761761
out << "Type: " << type << "\n";
762762
out << "Term: " << term;
763763
});
764764
}
765765

766-
MutableTerm reducedTerm(term.begin(), term.end() - 1);
767-
return reducedTerm;
766+
term.removeShape();
767+
return term;
768768
}
769769

770770
Type RequirementMachine::getReducedShape(Type type,

lib/AST/RequirementMachine/InterfaceType.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ RewriteContext::getRelativeTermForType(CanType typeWitness,
424424
// Get the substitution S corresponding to τ_0_n.
425425
unsigned index = getGenericParamIndex(typeWitness->getRootGenericParam());
426426
result = MutableTerm(substitutions[index]);
427-
ASSERT(result.back().getKind() != Symbol::Kind::Shape);
427+
ASSERT(!result.hasShape());
428428

429429
// If the substitution is a term consisting of a single protocol symbol
430430
// [P], save P for later.
@@ -485,7 +485,7 @@ Type PropertyMap::getTypeFromSubstitutionSchema(
485485
auto substitution = substitutions[index];
486486

487487
bool isShapePosition = (pos == TypePosition::Shape);
488-
bool isShapeTerm = (substitution.back() == Symbol::forShape(Context));
488+
bool isShapeTerm = substitution.hasShape();
489489
if (isShapePosition != isShapeTerm) {
490490
ABORT([&](auto &out) {
491491
out << "Shape vs. type mixup\n\n";
@@ -504,8 +504,8 @@ Type PropertyMap::getTypeFromSubstitutionSchema(
504504
// Undo the thing where the count type of a PackExpansionType
505505
// becomes a shape term.
506506
if (isShapeTerm) {
507-
MutableTerm mutTerm(substitution.begin(), substitution.end() - 1);
508-
substitution = Term::get(mutTerm, Context);
507+
MutableTerm noShape = substitution.termWithoutShape();
508+
substitution = Term::get(noShape, Context);
509509
}
510510

511511
// Prepend the prefix of the lookup key to the substitution.

lib/AST/RequirementMachine/RequirementBuilder.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -257,11 +257,10 @@ void RequirementBuilder::addRequirementRules(ArrayRef<unsigned> rules) {
257257

258258
ASSERT(rule.getLHS().back().getKind() != Symbol::Kind::Protocol);
259259

260-
if (constraintTerm.back().getKind() == Symbol::Kind::Shape) {
261-
ASSERT(rule.getRHS().back().getKind() == Symbol::Kind::Shape);
260+
if (constraintTerm.hasShape()) {
261+
ASSERT(rule.getRHS().hasShape());
262262
// Strip off the shape symbol from the constraint term.
263-
constraintTerm = MutableTerm(constraintTerm.begin(),
264-
constraintTerm.end() - 1);
263+
constraintTerm.removeShape();
265264
}
266265

267266
if (constraintTerm.front().getKind() == Symbol::Kind::PackElement) {
@@ -332,10 +331,10 @@ void RequirementBuilder::processConnectedComponents() {
332331
for (auto &pair : Components) {
333332
MutableTerm subjectTerm(pair.first);
334333
RequirementKind kind;
335-
if (subjectTerm.back().getKind() == Symbol::Kind::Shape) {
334+
if (subjectTerm.hasShape()) {
336335
kind = RequirementKind::SameShape;
337336
// Strip off the shape symbol from the subject term.
338-
subjectTerm = MutableTerm(subjectTerm.begin(), subjectTerm.end() - 1);
337+
subjectTerm.removeShape();
339338
} else {
340339
kind = RequirementKind::SameType;
341340
if (subjectTerm.front().getKind() == Symbol::Kind::PackElement) {

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,7 @@ void RewriteSystem::verifyRewriteRules(ValidityPolicy policy) const {
634634

635635
if (rhs.size() == 1 && rhs[0].getKind() == Symbol::Kind::Shape) {
636636
// We can have a rule like T.[shape] => [shape].
637-
ASSERT_RULE(lhs.back().getKind() == Symbol::Kind::Shape);
637+
ASSERT_RULE(lhs.hasShape());
638638
} else {
639639
// Otherwise, LHS and RHS must have the same domain.
640640
auto lhsDomain = lhs.getRootProtocol();

lib/AST/RequirementMachine/Term.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,17 @@ Symbol Term::back() const {
6868
return Ptr->getElements().back();
6969
}
7070

71+
bool Term::hasShape() const {
72+
return back().getKind() == Symbol::Kind::Shape;
73+
}
74+
75+
MutableTerm Term::termWithoutShape() const {
76+
if (hasShape())
77+
return MutableTerm(begin(), end() - 1);
78+
else
79+
return MutableTerm(begin(), end());
80+
}
81+
7182
Symbol Term::operator[](size_t index) const {
7283
return Ptr->getElements()[index];
7384
}
@@ -224,6 +235,15 @@ std::optional<int> MutableTerm::compare(const MutableTerm &other,
224235
return compareImpl(begin(), end(), other.begin(), other.end(), ctx);
225236
}
226237

238+
bool MutableTerm::hasShape() const {
239+
return back().getKind() == Symbol::Kind::Shape;
240+
}
241+
242+
void MutableTerm::removeShape() {
243+
if (hasShape())
244+
Symbols.pop_back();
245+
}
246+
227247
/// Replace the subterm in the range [from,to) of this term with \p rhs.
228248
void MutableTerm::rewriteSubTerm(Symbol *from, Symbol *to, Term rhs) {
229249
auto oldSize = size();

lib/AST/RequirementMachine/Term.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ class Term final {
5959

6060
Symbol back() const;
6161

62+
bool hasShape() const;
63+
MutableTerm termWithoutShape() const;
64+
6265
Symbol operator[](size_t index) const;
6366

6467
/// Returns an opaque pointer that uniquely identifies this term.
@@ -184,6 +187,9 @@ class MutableTerm final {
184187
return Symbols.back();
185188
}
186189

190+
bool hasShape() const;
191+
void removeShape();
192+
187193
Symbol operator[](size_t index) const {
188194
return Symbols[index];
189195
}

0 commit comments

Comments
 (0)