@@ -3382,23 +3382,40 @@ static bool hasInverse(
33823382 if (auto *extension = dyn_cast<ExtensionDecl>(decl)) {
33833383 if (auto *nominal = extension->getSelfNominalTypeDecl ())
33843384 return hasInverse (nominal, ip, isRelevantInverse);
3385+ return false ;
33853386 }
33863387
3387- if (auto *TD = dyn_cast<TypeDecl>(decl))
3388- return isRelevantInverse (TD->hasInverseMarking (ip));
3388+ auto hasInverseInType = [&](Type type) {
3389+ return type.findIf ([&](Type type) -> bool {
3390+ if (auto *typeDecl = getTypeDecl (type))
3391+ return hasInverse (typeDecl, ip, isRelevantInverse);
3392+ return false ;
3393+ });
3394+ };
33893395
3390- if (auto value = dyn_cast<ValueDecl>(decl)) {
3391- // Check for noncopyable types in the types of this declaration.
3392- if (Type type = value->getInterfaceType ()) {
3393- bool foundInverse = type.findIf ([&](Type type) -> bool {
3394- if (auto *typeDecl = getTypeDecl (type))
3395- return hasInverse (typeDecl, ip, isRelevantInverse);
3396- return false ;
3397- });
3396+ if (auto *TD = dyn_cast<TypeDecl>(decl)) {
3397+ if (auto *alias = dyn_cast<TypeAliasDecl>(TD))
3398+ return hasInverseInType (alias->getUnderlyingType ());
33983399
3399- if (foundInverse)
3400+ if (auto *NTD = dyn_cast<NominalTypeDecl>(TD)) {
3401+ if (isRelevantInverse (NTD->hasInverseMarking (ip)))
34003402 return true ;
34013403 }
3404+
3405+ if (auto *P = dyn_cast<ProtocolDecl>(TD)) {
3406+ // Check the protocol's associated types too.
3407+ return llvm::any_of (
3408+ P->getAssociatedTypeMembers (), [&](AssociatedTypeDecl *ATD) {
3409+ return isRelevantInverse (ATD->hasInverseMarking (ip));
3410+ });
3411+ }
3412+
3413+ return false ;
3414+ }
3415+
3416+ if (auto *VD = dyn_cast<ValueDecl>(decl)) {
3417+ if (VD->hasInterfaceType ())
3418+ return hasInverseInType (VD->getInterfaceType ());
34023419 }
34033420
34043421 return false ;
0 commit comments