@@ -4903,10 +4903,10 @@ GenericParameterReferenceInfo ValueDecl::findExistentialSelfReferences(
49034903
49044904InverseMarking::Mark
49054905TypeDecl::hasInverseMarking (InvertibleProtocolKind target) const {
4906- if (auto P = dyn_cast<ProtocolDecl >(this ))
4907- return P ->hasInverseMarking (target);
4906+ if (auto NTD = dyn_cast<NominalTypeDecl >(this ))
4907+ return NTD ->hasInverseMarking (target);
49084908
4909- return getMarking (target). getInverse ( );
4909+ return InverseMarking::Mark (InverseMarking::Kind::None );
49104910}
49114911
49124912InverseMarking TypeDecl::getMarking (InvertibleProtocolKind ip) const {
@@ -6620,31 +6620,199 @@ bool ProtocolDecl::inheritsFrom(const ProtocolDecl *super) const {
66206620 });
66216621}
66226622
6623+ static void findInheritedType (
6624+ InheritedTypes inherited,
6625+ llvm::function_ref<bool (Type, NullablePtr<TypeRepr>)> isMatch) {
6626+ for (size_t i = 0 ; i < inherited.size (); i++) {
6627+ auto type = inherited.getResolvedType (i, TypeResolutionStage::Structural);
6628+ if (!type)
6629+ continue ;
6630+
6631+ if (isMatch (type, inherited.getTypeRepr (i)))
6632+ break ;
6633+ }
6634+ }
6635+
6636+ static InverseMarking::Mark
6637+ findInverseInInheritance (InheritedTypes inherited,
6638+ InvertibleProtocolKind target) {
6639+ auto isInverseOfTarget = [&](Type t) {
6640+ if (auto pct = t->getAs <ProtocolCompositionType>())
6641+ return pct->getInverses ().contains (target);
6642+ return false ;
6643+ };
6644+
6645+ InverseMarking::Mark inverse;
6646+ findInheritedType (inherited,
6647+ [&](Type inheritedTy, NullablePtr<TypeRepr> repr) {
6648+ if (!isInverseOfTarget (inheritedTy))
6649+ return false ;
6650+
6651+ inverse = InverseMarking::Mark (
6652+ InverseMarking::Kind::Explicit,
6653+ repr.isNull () ? SourceLoc () : repr.get ()->getLoc ());
6654+ return true ;
6655+ });
6656+ return inverse;
6657+ }
6658+
6659+ bool NominalTypeDecl::hasMarking (InvertibleProtocolKind target) const {
6660+ InverseMarking::Mark mark;
6661+
6662+ std::function<bool (Type)> isTarget = [&](Type t) -> bool {
6663+ if (auto kp = t->getKnownProtocol ()) {
6664+ if (auto ip = getInvertibleProtocolKind (*kp))
6665+ return *ip == target;
6666+ } else if (auto pct = t->getAs <ProtocolCompositionType>()) {
6667+ return llvm::any_of (pct->getMembers (), isTarget);
6668+ }
6669+
6670+ return false ;
6671+ };
6672+
6673+ findInheritedType (getInherited (),
6674+ [&](Type inheritedTy, NullablePtr<TypeRepr> repr) {
6675+ if (!isTarget (inheritedTy))
6676+ return false ;
6677+
6678+ mark = InverseMarking::Mark (
6679+ InverseMarking::Kind::Explicit,
6680+ repr.isNull () ? SourceLoc () : repr.get ()->getLoc ());
6681+ return true ;
6682+ });
6683+ return mark;
6684+ }
6685+
66236686InverseMarking::Mark
6624- ProtocolDecl::hasInverseMarking (InvertibleProtocolKind target) const {
6687+ NominalTypeDecl::hasInverseMarking (InvertibleProtocolKind target) const {
6688+ switch (target) {
6689+ case InvertibleProtocolKind::Copyable:
6690+ // Handle the legacy '@_moveOnly' for types they can validly appear.
6691+ // TypeCheckAttr handles the illegal situations for us.
6692+ if (auto attr = getAttrs ().getAttribute <MoveOnlyAttr>())
6693+ if (isa<StructDecl, EnumDecl, ClassDecl>(this ))
6694+ return InverseMarking::Mark (InverseMarking::Kind::LegacyExplicit,
6695+ attr->getLocation ());
6696+ break ;
6697+
6698+ case InvertibleProtocolKind::Escapable:
6699+ // Handle the legacy '@_nonEscapable' attribute
6700+ if (auto attr = getAttrs ().getAttribute <NonEscapableAttr>()) {
6701+ assert ((isa<ClassDecl, StructDecl, EnumDecl>(this )));
6702+ return InverseMarking::Mark (InverseMarking::Kind::LegacyExplicit,
6703+ attr->getLocation ());
6704+ }
6705+ break ;
6706+ }
6707+
66256708 auto &ctx = getASTContext ();
66266709
66276710 // Legacy support stops here.
66286711 if (!ctx.LangOpts .hasFeature (Feature::NoncopyableGenerics))
6712+ return InverseMarking::Mark (InverseMarking::Kind::None);
6713+
6714+ // Claim that the tuple decl has an inferred ~TARGET marking.
6715+ if (isa<BuiltinTupleDecl>(this ))
6716+ return InverseMarking::Mark (InverseMarking::Kind::Inferred);
6717+
6718+ if (auto P = dyn_cast<ProtocolDecl>(this ))
6719+ return P->hasInverseMarking (target);
6720+
6721+ // Search the inheritance clause first.
6722+ if (auto inverse = findInverseInInheritance (getInherited (), target))
6723+ return inverse;
6724+
6725+ // Check the generic parameters for an explicit ~TARGET marking
6726+ // which would result in an Inferred ~TARGET marking for this context.
6727+ auto *gpList = getParsedGenericParams ();
6728+ if (!gpList)
66296729 return InverseMarking::Mark ();
66306730
6631- auto inheritedTypes = getInherited ();
6632- for (unsigned i = 0 ; i < inheritedTypes.size (); ++i) {
6633- auto type =
6634- inheritedTypes.getResolvedType (i, TypeResolutionStage::Structural);
6635- if (!type)
6731+ auto isInverseTarget = [&](Type t) -> bool {
6732+ if (auto pct = t->getAs <ProtocolCompositionType>())
6733+ return pct->getInverses ().contains (target);
6734+ return false ;
6735+ };
6736+
6737+ auto resolveRequirement = [&](unsigned reqIdx) -> std::optional<Requirement> {
6738+ WhereClauseOwner owner (const_cast <NominalTypeDecl *>(this ));
6739+ auto req = ctx.evaluator (
6740+ RequirementRequest{owner, reqIdx, TypeResolutionStage::Structural},
6741+ [&]() {
6742+ return Requirement (RequirementKind::SameType, ErrorType::get (ctx),
6743+ ErrorType::get (ctx));
6744+ });
6745+
6746+ if (req.hasError ())
6747+ return std::nullopt ;
6748+
6749+ return req;
6750+ };
6751+
6752+ llvm::SmallSet<GenericTypeParamDecl *, 4 > params;
6753+
6754+ // Scan the inheritance clauses of generic parameters only for an inverse.
6755+ for (GenericTypeParamDecl *param : gpList->getParams ()) {
6756+ auto inverse = findInverseInInheritance (param->getInherited (), target);
6757+
6758+ // Inverse is inferred from one of the generic parameters.
6759+ if (inverse)
6760+ return inverse.with (InverseMarking::Kind::Inferred);
6761+
6762+ params.insert (param);
6763+ }
6764+
6765+ // Next, scan the where clause and return the result.
6766+ auto whereClause = getTrailingWhereClause ();
6767+ if (!whereClause)
6768+ return InverseMarking::Mark ();
6769+
6770+ auto requirements = whereClause->getRequirements ();
6771+ for (unsigned i : indices (requirements)) {
6772+ auto requirementRepr = requirements[i];
6773+ if (requirementRepr.getKind () != RequirementReprKind::TypeConstraint)
66366774 continue ;
66376775
6638- auto *repr = inheritedTypes.getTypeRepr (i);
6776+ auto *constraintRepr =
6777+ dyn_cast<InverseTypeRepr>(requirementRepr.getConstraintRepr ());
6778+ if (!constraintRepr || constraintRepr->isInvalid ())
6779+ continue ;
66396780
6640- if (auto *composition = type->getAs <ProtocolCompositionType>()) {
6641- // Found ~<target> in the protocol inheritance clause.
6642- if (composition->getInverses ().contains (target))
6643- return InverseMarking::Mark (InverseMarking::Kind::Explicit,
6644- repr ? repr->getLoc () : SourceLoc ());
6645- }
6781+ auto req = resolveRequirement (i);
6782+ if (!req)
6783+ continue ;
6784+
6785+ if (req->getKind () != RequirementKind::Conformance)
6786+ continue ;
6787+
6788+ auto subject = req->getFirstType ();
6789+ if (!subject->isTypeParameter ())
6790+ continue ;
6791+
6792+ // Skip outer params and implicit ones.
6793+ auto *param = subject->getRootGenericParam ()->getDecl ();
6794+ if (!param || !params.contains (param))
6795+ continue ;
6796+
6797+ if (isInverseTarget (req->getSecondType ()))
6798+ return InverseMarking::Mark (InverseMarking::Kind::Inferred,
6799+ constraintRepr->getLoc ());
66466800 }
66476801
6802+ return InverseMarking::Mark ();
6803+ }
6804+
6805+ InverseMarking::Mark
6806+ ProtocolDecl::hasInverseMarking (InvertibleProtocolKind target) const {
6807+ auto &ctx = getASTContext ();
6808+
6809+ // Legacy support stops here.
6810+ if (!ctx.LangOpts .hasFeature (Feature::NoncopyableGenerics))
6811+ return InverseMarking::Mark ();
6812+
6813+ if (auto inverse = findInverseInInheritance (getInherited (), target))
6814+ return inverse;
6815+
66486816 auto *whereClause = getTrailingWhereClause ();
66496817 if (!whereClause)
66506818 return InverseMarking::Mark ();
0 commit comments