@@ -892,6 +892,7 @@ class PrintAST : public ASTVisitor<PrintAST> {
892892 bool openBracket = true , bool closeBracket = true );
893893 void printGenericDeclGenericParams (GenericContext *decl);
894894 void printDeclGenericRequirements (GenericContext *decl);
895+ void printPrimaryAssociatedTypes (ProtocolDecl *decl);
895896 void printBodyIfNecessary (const AbstractFunctionDecl *decl);
896897
897898 void printEnumElement (EnumElementDecl *elt);
@@ -1380,7 +1381,8 @@ struct RequirementPrintLocation {
13801381// / function does: asking "where should this requirement be printed?" and then
13811382// / callers check if the location is the ATD.
13821383static RequirementPrintLocation
1383- bestRequirementPrintLocation (ProtocolDecl *proto, const Requirement &req) {
1384+ bestRequirementPrintLocation (ProtocolDecl *proto, const Requirement &req,
1385+ PrintOptions opts, bool inheritanceClause) {
13841386 auto protoSelf = proto->getProtocolSelfType ();
13851387 // Returns the most relevant decl within proto connected to outerType (or null
13861388 // if one doesn't exist), and whether the type is an "direct use",
@@ -1397,6 +1399,7 @@ bestRequirementPrintLocation(ProtocolDecl *proto, const Requirement &req) {
13971399 return true ;
13981400 } else if (auto DMT = t->getAs <DependentMemberType>()) {
13991401 auto assocType = DMT->getAssocType ();
1402+
14001403 if (assocType && assocType->getProtocol () == proto) {
14011404 relevantDecl = assocType;
14021405 foundType = t;
@@ -1411,6 +1414,17 @@ bestRequirementPrintLocation(ProtocolDecl *proto, const Requirement &req) {
14111414 // If we didn't find anything, relevantDecl and foundType will be null, as
14121415 // desired.
14131416 auto directUse = foundType && outerType->isEqual (foundType);
1417+
1418+ // Prefer to attach requirements to associated type declarations,
1419+ // unless the associated type is a primary associated type and
1420+ // we're printing primary associated types using the new syntax.
1421+ if (!directUse &&
1422+ relevantDecl &&
1423+ opts.PrintPrimaryAssociatedTypes &&
1424+ isa<AssociatedTypeDecl>(relevantDecl) &&
1425+ cast<AssociatedTypeDecl>(relevantDecl)->isPrimary ())
1426+ relevantDecl = proto;
1427+
14141428 return std::make_pair (relevantDecl, directUse);
14151429 };
14161430
@@ -1481,7 +1495,8 @@ void PrintAST::printInheritedFromRequirementSignature(ProtocolDecl *proto,
14811495 return false ;
14821496 }
14831497
1484- auto location = bestRequirementPrintLocation (proto, req);
1498+ auto location = bestRequirementPrintLocation (proto, req, Options,
1499+ /* inheritanceClause=*/ true );
14851500 return location.AttachedTo == attachingTo && !location.InWhereClause ;
14861501 });
14871502}
@@ -1496,7 +1511,8 @@ void PrintAST::printWhereClauseFromRequirementSignature(ProtocolDecl *proto,
14961511 proto->getRequirementSignature ().getRequirements ()),
14971512 flags,
14981513 [&](const Requirement &req) {
1499- auto location = bestRequirementPrintLocation (proto, req);
1514+ auto location = bestRequirementPrintLocation (proto, req, Options,
1515+ /* inheritanceClause=*/ false );
15001516 return location.AttachedTo == attachingTo && location.InWhereClause ;
15011517 });
15021518}
@@ -2969,6 +2985,22 @@ static void suppressingFeatureUnsafeInheritExecutor(PrintOptions &options,
29692985 options.ExcludeAttrList .resize (originalExcludeAttrCount);
29702986}
29712987
2988+ static bool usesFeaturePrimaryAssociatedTypes (Decl *decl) {
2989+ if (auto *protoDecl = dyn_cast<ProtocolDecl>(decl)) {
2990+ if (protoDecl->getPrimaryAssociatedTypes ().size () > 0 )
2991+ return true ;
2992+ }
2993+
2994+ return false ;
2995+ }
2996+
2997+ static void suppressingFeaturePrimaryAssociatedTypes (PrintOptions &options,
2998+ llvm::function_ref<void ()> action) {
2999+ bool originalPrintPrimaryAssociatedTypes = options.PrintPrimaryAssociatedTypes ;
3000+ options.PrintPrimaryAssociatedTypes = false ;
3001+ action ();
3002+ options.PrintPrimaryAssociatedTypes = originalPrintPrimaryAssociatedTypes;
3003+ }
29723004
29733005// / Suppress the printing of a particular feature.
29743006static void suppressingFeature (PrintOptions &options, Feature feature,
@@ -3485,6 +3517,38 @@ void PrintAST::visitClassDecl(ClassDecl *decl) {
34853517 }
34863518}
34873519
3520+ void PrintAST::printPrimaryAssociatedTypes (ProtocolDecl *decl) {
3521+ auto primaryAssocTypes = decl->getPrimaryAssociatedTypes ();
3522+ if (primaryAssocTypes.empty ())
3523+ return ;
3524+
3525+ Printer.printStructurePre (PrintStructureKind::DeclGenericParameterClause);
3526+
3527+ Printer << " <" ;
3528+ llvm::interleave (
3529+ primaryAssocTypes,
3530+ [&](AssociatedTypeDecl *assocType) {
3531+ Printer.callPrintStructurePre (PrintStructureKind::GenericParameter,
3532+ assocType);
3533+ Printer.printName (assocType->getName (),
3534+ PrintNameContext::GenericParameter);
3535+
3536+ printInheritedFromRequirementSignature (decl, assocType);
3537+
3538+ if (assocType->hasDefaultDefinitionType ()) {
3539+ Printer << " = " ;
3540+ assocType->getDefaultDefinitionType ().print (Printer, Options);
3541+ }
3542+
3543+ Printer.printStructurePost (PrintStructureKind::GenericParameter,
3544+ assocType);
3545+ },
3546+ [&] { Printer << " , " ; });
3547+ Printer << " >" ;
3548+
3549+ Printer.printStructurePost (PrintStructureKind::DeclGenericParameterClause);
3550+ }
3551+
34883552void PrintAST::visitProtocolDecl (ProtocolDecl *decl) {
34893553 printDocumentationComment (decl);
34903554 printAttributes (decl);
@@ -3502,6 +3566,10 @@ void PrintAST::visitProtocolDecl(ProtocolDecl *decl) {
35023566 Printer.printName (decl->getName ());
35033567 });
35043568
3569+ if (Options.PrintPrimaryAssociatedTypes ) {
3570+ printPrimaryAssociatedTypes (decl);
3571+ }
3572+
35053573 printInheritedFromRequirementSignature (decl, decl);
35063574
35073575 // The trailing where clause is a syntactic thing, which isn't serialized
@@ -4997,6 +5065,14 @@ bool Decl::shouldPrintInContext(const PrintOptions &PO) const {
49975065 return PO.PrintIfConfig ;
49985066 }
49995067
5068+ if (auto *ATD = dyn_cast<AssociatedTypeDecl>(this )) {
5069+ // If PO.PrintPrimaryAssociatedTypes is on, primary associated
5070+ // types are printed as part of the protocol declaration itself,
5071+ // so skip them here.
5072+ if (ATD->isPrimary () && PO.PrintPrimaryAssociatedTypes )
5073+ return false ;
5074+ }
5075+
50005076 // Print everything else.
50015077 return true ;
50025078}
0 commit comments