3333#include " swift/AST/TypeMatcher.h"
3434#include " swift/AST/TypeRepr.h"
3535#include " llvm/ADT/SmallVector.h"
36+ #include " llvm/ADT/SetVector.h"
37+ #include " RequirementMachine.h"
3638#include " RewriteContext.h"
3739#include " RewriteSystem.h"
3840#include " Symbol.h"
@@ -1013,7 +1015,7 @@ ArrayRef<ProtocolDecl *>
10131015ProtocolDependenciesRequest::evaluate (Evaluator &evaluator,
10141016 ProtocolDecl *proto) const {
10151017 auto &ctx = proto->getASTContext ();
1016- SmallVector <ProtocolDecl *, 4 > result;
1018+ SmallSetVector <ProtocolDecl *, 4 > result;
10171019
10181020 // If we have a serialized requirement signature, deserialize it and
10191021 // look at conformance requirements.
@@ -1025,7 +1027,7 @@ ProtocolDependenciesRequest::evaluate(Evaluator &evaluator,
10251027 == RequirementMachineMode::Disabled)) {
10261028 for (auto req : proto->getRequirementSignature ().getRequirements ()) {
10271029 if (req.getKind () == RequirementKind::Conformance) {
1028- result.push_back (req.getProtocolDecl ());
1030+ result.insert (req.getProtocolDecl ());
10291031 }
10301032 }
10311033
@@ -1037,7 +1039,7 @@ ProtocolDependenciesRequest::evaluate(Evaluator &evaluator,
10371039 // signature. Look at the structural requirements instead.
10381040 for (auto req : proto->getStructuralRequirements ()) {
10391041 if (req.req .getKind () == RequirementKind::Conformance)
1040- result.push_back (req.req .getProtocolDecl ());
1042+ result.insert (req.req .getProtocolDecl ());
10411043 }
10421044
10431045 return ctx.AllocateCopy (result);
@@ -1047,11 +1049,17 @@ ProtocolDependenciesRequest::evaluate(Evaluator &evaluator,
10471049// Building rewrite rules from desugared requirements.
10481050//
10491051
1050- void RuleBuilder::addRequirements (ArrayRef<Requirement> requirements) {
1052+ // / For building a rewrite system for a generic signature from canonical
1053+ // / requirements.
1054+ void RuleBuilder::initWithGenericSignatureRequirements (
1055+ ArrayRef<Requirement> requirements) {
1056+ assert (!Initialized);
1057+ Initialized = 1 ;
1058+
10511059 // Collect all protocols transitively referenced from these requirements.
10521060 for (auto req : requirements) {
10531061 if (req.getKind () == RequirementKind::Conformance) {
1054- addProtocol (req.getProtocolDecl (), /* initialComponent= */ false );
1062+ addReferencedProtocol (req.getProtocolDecl ());
10551063 }
10561064 }
10571065
@@ -1062,11 +1070,17 @@ void RuleBuilder::addRequirements(ArrayRef<Requirement> requirements) {
10621070 addRequirement (req, /* proto=*/ nullptr , /* requirementID=*/ None);
10631071}
10641072
1065- void RuleBuilder::addRequirements (ArrayRef<StructuralRequirement> requirements) {
1073+ // / For building a rewrite system for a generic signature from user-written
1074+ // / requirements.
1075+ void RuleBuilder::initWithWrittenRequirements (
1076+ ArrayRef<StructuralRequirement> requirements) {
1077+ assert (!Initialized);
1078+ Initialized = 1 ;
1079+
10661080 // Collect all protocols transitively referenced from these requirements.
10671081 for (auto req : requirements) {
10681082 if (req.req .getKind () == RequirementKind::Conformance) {
1069- addProtocol (req.req .getProtocolDecl (), /* initialComponent= */ false );
1083+ addReferencedProtocol (req.req .getProtocolDecl ());
10701084 }
10711085 }
10721086
@@ -1077,16 +1091,117 @@ void RuleBuilder::addRequirements(ArrayRef<StructuralRequirement> requirements)
10771091 addRequirement (req, /* proto=*/ nullptr );
10781092}
10791093
1080- void RuleBuilder::addProtocols (ArrayRef<const ProtocolDecl *> protos) {
1094+ // / For building a rewrite system for a protocol connected component from
1095+ // / a previously-built requirement signature.
1096+ // /
1097+ // / Will trigger requirement signature computation if we haven't built
1098+ // / requirement signatures for this connected component yet, in which case we
1099+ // / will recursively end up building another rewrite system for this component
1100+ // / using initWithProtocolWrittenRequirements().
1101+ void RuleBuilder::initWithProtocolSignatureRequirements (
1102+ ArrayRef<const ProtocolDecl *> protos) {
1103+ assert (!Initialized);
1104+ Initialized = 1 ;
1105+
1106+ // Add all protocols to the referenced set, so that subsequent calls
1107+ // to addReferencedProtocol() with one of these protocols don't add
1108+ // them to the import list.
1109+ for (auto *proto : protos) {
1110+ ReferencedProtocols.insert (proto);
1111+ }
1112+
1113+ for (auto *proto : protos) {
1114+ if (Dump) {
1115+ llvm::dbgs () << " protocol " << proto->getName () << " {\n " ;
1116+ }
1117+
1118+ addPermanentProtocolRules (proto);
1119+
1120+ auto reqs = proto->getRequirementSignature ();
1121+ for (auto req : reqs.getRequirements ())
1122+ addRequirement (req.getCanonical (), proto, /* requirementID=*/ None);
1123+ for (auto alias : reqs.getTypeAliases ())
1124+ addTypeAlias (alias, proto);
1125+
1126+ for (auto *otherProto : proto->getProtocolDependencies ())
1127+ addReferencedProtocol (otherProto);
1128+
1129+ if (Dump) {
1130+ llvm::dbgs () << " }\n " ;
1131+ }
1132+ }
1133+
10811134 // Collect all protocols transitively referenced from this connected component
10821135 // of the protocol dependency graph.
1083- for (auto proto : protos) {
1084- addProtocol (proto, /* initialComponent=*/ true );
1136+ collectRulesFromReferencedProtocols ();
1137+ }
1138+
1139+ // / For building a rewrite system for a protocol connected component from
1140+ // / user-written requirements. Used when actually building requirement
1141+ // / signatures.
1142+ void RuleBuilder::initWithProtocolWrittenRequirements (
1143+ ArrayRef<const ProtocolDecl *> protos) {
1144+ assert (!Initialized);
1145+ Initialized = 1 ;
1146+
1147+ // Add all protocols to the referenced set, so that subsequent calls
1148+ // to addReferencedProtocol() with one of these protocols don't add
1149+ // them to the import list.
1150+ for (auto *proto : protos) {
1151+ ReferencedProtocols.insert (proto);
1152+ }
1153+
1154+ for (auto *proto : protos) {
1155+ if (Dump) {
1156+ llvm::dbgs () << " protocol " << proto->getName () << " {\n " ;
1157+ }
1158+
1159+ addPermanentProtocolRules (proto);
1160+
1161+ for (auto req : proto->getStructuralRequirements ())
1162+ addRequirement (req, proto);
1163+
1164+ for (auto req : proto->getTypeAliasRequirements ())
1165+ addRequirement (req.getCanonical (), proto, /* requirementID=*/ None);
1166+
1167+ for (auto *otherProto : proto->getProtocolDependencies ())
1168+ addReferencedProtocol (otherProto);
1169+
1170+ if (Dump) {
1171+ llvm::dbgs () << " }\n " ;
1172+ }
10851173 }
10861174
1175+ // Collect all protocols transitively referenced from this connected component
1176+ // of the protocol dependency graph.
10871177 collectRulesFromReferencedProtocols ();
10881178}
10891179
1180+ // / Add permanent rules for a protocol, consisting of:
1181+ // /
1182+ // / - The identity conformance rule [P].[P] => [P].
1183+ // / - An associated type introduction rule for each associated type.
1184+ // / - An inherited associated type introduction rule for each associated
1185+ // / type of each inherited protocol.
1186+ void RuleBuilder::addPermanentProtocolRules (const ProtocolDecl *proto) {
1187+ MutableTerm lhs;
1188+ lhs.add (Symbol::forProtocol (proto, Context));
1189+ lhs.add (Symbol::forProtocol (proto, Context));
1190+
1191+ MutableTerm rhs;
1192+ rhs.add (Symbol::forProtocol (proto, Context));
1193+
1194+ PermanentRules.emplace_back (lhs, rhs);
1195+
1196+ for (auto *assocType : proto->getAssociatedTypeMembers ())
1197+ addAssociatedType (assocType, proto);
1198+
1199+ for (auto *inheritedProto : Context.getInheritedProtocols (proto)) {
1200+ for (auto *assocType : inheritedProto->getAssociatedTypeMembers ())
1201+ addAssociatedType (assocType, proto);
1202+ }
1203+ }
1204+
10901205// / For an associated type T in a protocol P, we add a rewrite rule:
10911206// /
10921207// / [P].T => [P:T]
@@ -1264,75 +1379,58 @@ void RuleBuilder::addTypeAlias(const ProtocolTypeAlias &alias,
12641379 /* requirementID=*/ None);
12651380}
12661381
1267- // / Record information about a protocol if we have no seen it yet.
1268- void RuleBuilder::addProtocol (const ProtocolDecl *proto,
1269- bool initialComponent) {
1270- if (ProtocolMap.count (proto) > 0 )
1271- return ;
1272-
1273- ProtocolMap[proto] = initialComponent;
1274- Protocols.push_back (proto);
1382+ // / If we haven't seen this protocol yet, save it for later so that we can
1383+ // / import the rewrite rules from its connected component.
1384+ void RuleBuilder::addReferencedProtocol (const ProtocolDecl *proto) {
1385+ if (ReferencedProtocols.insert (proto).second )
1386+ ProtocolsToImport.push_back (proto);
12751387}
12761388
12771389// / Compute the transitive closure of the set of all protocols referenced from
12781390// / the right hand sides of conformance requirements, and convert their
12791391// / requirements to rewrite rules.
12801392void RuleBuilder::collectRulesFromReferencedProtocols () {
1393+ // Compute the transitive closure.
12811394 unsigned i = 0 ;
1282- while (i < Protocols .size ()) {
1283- auto *proto = Protocols [i++];
1395+ while (i < ProtocolsToImport .size ()) {
1396+ auto *proto = ProtocolsToImport [i++];
12841397 for (auto *depProto : proto->getProtocolDependencies ()) {
1285- addProtocol (depProto, /* initialComponent= */ false );
1398+ addReferencedProtocol (depProto);
12861399 }
12871400 }
12881401
1289- // Add rewrite rules for each protocol.
1290- for (auto *proto : Protocols) {
1402+ // If this is a rewrite system for a generic signature, add rewrite rules for
1403+ // each referenced protocol.
1404+ //
1405+ // if this is a rewrite system for a connected component of the protocol
1406+ // dependency graph, add rewrite rules for each referenced protocol not part
1407+ // of this connected component.
1408+
1409+ // First, collect all unique requirement machines, one for each connected
1410+ // component of each referenced protocol.
1411+ llvm::DenseSet<RequirementMachine *> machines;
1412+
1413+ // Now visit each subordinate requirement machine pull in its rules.
1414+ for (auto *proto : ProtocolsToImport) {
1415+ // This will trigger requirement signature computation for this protocol,
1416+ // if neccessary, which will cause us to re-enter into a new RuleBuilder
1417+ // instace under RuleBuilder::initWithProtocolWrittenRequirements().
12911418 if (Dump) {
1292- llvm::dbgs () << " protocol " << proto->getName () << " {\n " ;
1419+ llvm::dbgs () << " importing protocol " << proto->getName () << " {\n " ;
12931420 }
12941421
1295- // Add the identity conformance rule [P].[P] => [P].
1296- MutableTerm lhs;
1297- lhs.add (Symbol::forProtocol (proto, Context));
1298- lhs.add (Symbol::forProtocol (proto, Context));
1299-
1300- MutableTerm rhs;
1301- rhs.add (Symbol::forProtocol (proto, Context));
1302-
1303- PermanentRules.emplace_back (lhs, rhs);
1304-
1305- for (auto *assocType : proto->getAssociatedTypeMembers ())
1306- addAssociatedType (assocType, proto);
1307-
1308- for (auto *inheritedProto : Context.getInheritedProtocols (proto)) {
1309- for (auto *assocType : inheritedProto->getAssociatedTypeMembers ())
1310- addAssociatedType (assocType, proto);
1311- }
1312-
1313- // If this protocol is part of the initial connected component, we're
1314- // building requirement signatures for all protocols in this component,
1315- // and so we must start with the structural requirements.
1316- //
1317- // Otherwise, we should either already have a requirement signature, or
1318- // we can trigger the computation of the requirement signatures of the
1319- // next component recursively.
1320- if (ProtocolMap[proto]) {
1321- for (auto req : proto->getStructuralRequirements ())
1322- addRequirement (req, proto);
1323-
1324- for (auto req : proto->getTypeAliasRequirements ())
1325- addRequirement (req.getCanonical (), proto, /* requirementID=*/ None);
1326- } else {
1327- auto reqs = proto->getRequirementSignature ();
1328- for (auto req : reqs.getRequirements ())
1329- addRequirement (req.getCanonical (), proto, /* requirementID=*/ None);
1330- for (auto alias : reqs.getTypeAliases ())
1331- addTypeAlias (alias, proto);
1422+ auto *machine = Context.getRequirementMachine (proto);
1423+ if (!machines.insert (machine).second ) {
1424+ // We've already seen this connected component.
1425+ continue ;
13321426 }
13331427
1334- if (Dump) {
1335- llvm::dbgs () << " }\n " ;
1336- }
1428+ // We grab the machine's local rules, not *all* of its rules, to avoid
1429+ // duplicates in case multiple machines share a dependency on a downstream
1430+ // protocol connected component.
1431+ auto localRules = machine->getLocalRules ();
1432+ ImportedRules.insert (ImportedRules.end (),
1433+ localRules.begin (),
1434+ localRules.end ());
13371435 }
13381436}
0 commit comments