@@ -1999,6 +1999,211 @@ bool swift::isAsyncDecl(ConcreteDeclRef declRef) {
19991999 return false ;
20002000}
20012001
2002+ AbstractFunctionDecl *swift::enclosingUnsafeInheritsExecutor (
2003+ const DeclContext *dc) {
2004+ for (; dc; dc = dc->getParent ()) {
2005+ if (auto func = dyn_cast<AbstractFunctionDecl>(dc)) {
2006+ if (func->getAttrs ().hasAttribute <UnsafeInheritExecutorAttr>()) {
2007+ return const_cast <AbstractFunctionDecl *>(func);
2008+ }
2009+
2010+ return nullptr ;
2011+ }
2012+
2013+ if (isa<AbstractClosureExpr>(dc))
2014+ return nullptr ;
2015+
2016+ if (dc->isTypeContext ())
2017+ return nullptr ;
2018+ }
2019+
2020+ return nullptr ;
2021+ }
2022+
2023+ // / Adjust the location used for diagnostics about #isolation to account for
2024+ // / the fact that they show up in macro expansions.
2025+ // /
2026+ // / Returns a pair containing the updated location and whether it's part of
2027+ // / a default argument.
2028+ static std::pair<SourceLoc, bool > adjustPoundIsolationDiagLoc (
2029+ CurrentContextIsolationExpr *isolationExpr,
2030+ ModuleDecl *module
2031+ ) {
2032+ // Not part of a macro expansion.
2033+ SourceLoc diagLoc = isolationExpr->getLoc ();
2034+ auto sourceFile = module ->getSourceFileContainingLocation (diagLoc);
2035+ if (!sourceFile)
2036+ return { diagLoc, false };
2037+ auto macroExpansionRange = sourceFile->getMacroInsertionRange ();
2038+ if (macroExpansionRange.Start .isInvalid ())
2039+ return { diagLoc, false };
2040+
2041+ diagLoc = macroExpansionRange.Start ;
2042+
2043+ // If this is from a default argument, note that and go one more
2044+ // level "out" to the place where the default argument was
2045+ // introduced.
2046+ auto expansionSourceFile = module ->getSourceFileContainingLocation (diagLoc);
2047+ if (!expansionSourceFile ||
2048+ expansionSourceFile->Kind != SourceFileKind::DefaultArgument)
2049+ return { diagLoc, false };
2050+
2051+ return {
2052+ expansionSourceFile->getNodeInEnclosingSourceFile ().getStartLoc (),
2053+ true
2054+ };
2055+ }
2056+
2057+ void swift::replaceUnsafeInheritExecutorWithDefaultedIsolationParam (
2058+ AbstractFunctionDecl *func, InFlightDiagnostic &diag) {
2059+ auto attr = func->getAttrs ().getAttribute <UnsafeInheritExecutorAttr>();
2060+ assert (attr && " Caller didn't validate the presence of the attribute" );
2061+
2062+ // Look for the place where we should insert the new 'isolation' parameter.
2063+ // We insert toward the back, but skip over any parameters that have function
2064+ // type.
2065+ unsigned insertionPos = func->getParameters ()->size ();
2066+ while (insertionPos > 0 ) {
2067+ Type paramType = func->getParameters ()->get (insertionPos - 1 )->getInterfaceType ();
2068+ if (paramType->lookThroughSingleOptionalType ()->is <AnyFunctionType>()) {
2069+ --insertionPos;
2070+ continue ;
2071+ }
2072+
2073+ break ;
2074+ }
2075+
2076+ // Determine the text to insert. We put the commas before and after, then
2077+ // slice them away depending on whether we have parameters before or after.
2078+ StringRef newParameterText = " , isolation: isolated (any Actor)? = #isolation, " ;
2079+ if (insertionPos == 0 )
2080+ newParameterText = newParameterText.drop_front (2 );
2081+ if (insertionPos == func->getParameters ()->size ())
2082+ newParameterText = newParameterText.drop_back (2 );
2083+
2084+ // Determine where to insert the new parameter.
2085+ SourceLoc insertionLoc;
2086+ if (insertionPos < func->getParameters ()->size ()) {
2087+ insertionLoc = func->getParameters ()->get (insertionPos)->getStartLoc ();
2088+ } else {
2089+ insertionLoc = func->getParameters ()->getRParenLoc ();
2090+ }
2091+
2092+ diag.fixItRemove (attr->getRangeWithAt ());
2093+ diag.fixItInsert (insertionLoc, newParameterText);
2094+ }
2095+
2096+ // / Whether this declaration context is in the _Concurrency module.
2097+ static bool inConcurrencyModule (const DeclContext *dc) {
2098+ return dc->getParentModule ()->getName ().str ().equals (" _Concurrency" );
2099+ }
2100+
2101+ void swift::introduceUnsafeInheritExecutorReplacements (
2102+ const DeclContext *dc, SourceLoc loc, SmallVectorImpl<ValueDecl *> &decls) {
2103+ if (decls.empty ())
2104+ return ;
2105+
2106+ auto isReplaceable = [&](ValueDecl *decl) {
2107+ return isa<FuncDecl>(decl) && inConcurrencyModule (decl->getDeclContext ()) &&
2108+ decl->getDeclContext ()->isModuleScopeContext ();
2109+ };
2110+
2111+ // Make sure at least some of the entries are functions in the _Concurrency
2112+ // module.
2113+ ModuleDecl *concurrencyModule = nullptr ;
2114+ DeclBaseName baseName;
2115+ for (auto decl: decls) {
2116+ if (isReplaceable (decl)) {
2117+ concurrencyModule = decl->getDeclContext ()->getParentModule ();
2118+ baseName = decl->getName ().getBaseName ();
2119+ break ;
2120+ }
2121+ }
2122+ if (!concurrencyModule)
2123+ return ;
2124+
2125+ // Ignore anything with a special name.
2126+ if (baseName.isSpecial ())
2127+ return ;
2128+
2129+ // Look for entities with the _unsafeInheritExecutor_ prefix on the name.
2130+ ASTContext &ctx = decls.front ()->getASTContext ();
2131+ Identifier newIdentifier = ctx.getIdentifier (
2132+ (" _unsafeInheritExecutor_" + baseName.getIdentifier ().str ()).str ());
2133+
2134+ NameLookupOptions lookupOptions = defaultUnqualifiedLookupOptions;
2135+ LookupResult lookup = TypeChecker::lookupUnqualified (
2136+ const_cast <DeclContext *>(dc), DeclNameRef (newIdentifier), loc,
2137+ lookupOptions);
2138+ if (!lookup)
2139+ return ;
2140+
2141+ // Drop all of the _Concurrency entries in favor of the ones found by this
2142+ // lookup.
2143+ decls.erase (std::remove_if (decls.begin (), decls.end (), [&](ValueDecl *decl) {
2144+ return isReplaceable (decl);
2145+ }),
2146+ decls.end ());
2147+ for (const auto &lookupResult: lookup) {
2148+ if (auto decl = lookupResult.getValueDecl ())
2149+ decls.push_back (decl);
2150+ }
2151+ }
2152+
2153+ void swift::introduceUnsafeInheritExecutorReplacements (
2154+ const DeclContext *dc, Type base, SourceLoc loc, LookupResult &lookup) {
2155+ if (lookup.empty ())
2156+ return ;
2157+
2158+ auto baseNominal = base->getAnyNominal ();
2159+ if (!baseNominal || !inConcurrencyModule (baseNominal))
2160+ return ;
2161+
2162+ auto isReplaceable = [&](ValueDecl *decl) {
2163+ return isa<FuncDecl>(decl) && inConcurrencyModule (decl->getDeclContext ());
2164+ };
2165+
2166+ // Make sure at least some of the entries are functions in the _Concurrency
2167+ // module.
2168+ ModuleDecl *concurrencyModule = nullptr ;
2169+ DeclBaseName baseName;
2170+ for (auto &result: lookup) {
2171+ auto decl = result.getValueDecl ();
2172+ if (isReplaceable (decl)) {
2173+ concurrencyModule = decl->getDeclContext ()->getParentModule ();
2174+ baseName = decl->getBaseName ();
2175+ break ;
2176+ }
2177+ }
2178+ if (!concurrencyModule)
2179+ return ;
2180+
2181+ // Ignore anything with a special name.
2182+ if (baseName.isSpecial ())
2183+ return ;
2184+
2185+ // Look for entities with the _unsafeInheritExecutor_ prefix on the name.
2186+ ASTContext &ctx = base->getASTContext ();
2187+ Identifier newIdentifier = ctx.getIdentifier (
2188+ (" _unsafeInheritExecutor_" + baseName.getIdentifier ().str ()).str ());
2189+
2190+ LookupResult replacementLookup = TypeChecker::lookupMember (
2191+ const_cast <DeclContext *>(dc), base, DeclNameRef (newIdentifier), loc,
2192+ defaultMemberLookupOptions);
2193+ if (replacementLookup.innerResults ().empty ())
2194+ return ;
2195+
2196+ // Drop all of the _Concurrency entries in favor of the ones found by this
2197+ // lookup.
2198+ lookup.filter ([&](const LookupResultEntry &entry, bool ) {
2199+ return !isReplaceable (entry.getValueDecl ());
2200+ });
2201+
2202+ for (const auto &entry: replacementLookup.innerResults ()) {
2203+ lookup.add (entry, /* isOuter=*/ false );
2204+ }
2205+ }
2206+
20022207// / Check if it is safe for the \c globalActor qualifier to be removed from
20032208// / \c ty, when the function value of that type is isolated to that actor.
20042209// /
@@ -3748,6 +3953,23 @@ namespace {
37483953 if (isolationExpr->getActor ())
37493954 return ;
37503955
3956+ // #isolation does not work within an `@_unsafeInheritExecutor` function.
3957+ if (auto func = enclosingUnsafeInheritsExecutor (getDeclContext ())) {
3958+ // This expression is always written as a macro #isolation in source,
3959+ // so find the enclosing macro expansion expression's location.
3960+ SourceLoc diagLoc;
3961+ bool inDefaultArgument;
3962+ std::tie (diagLoc, inDefaultArgument) = adjustPoundIsolationDiagLoc (
3963+ isolationExpr, getDeclContext ()->getParentModule ());
3964+
3965+ bool inConcurrencyModule = ::inConcurrencyModule (getDeclContext ());
3966+ auto diag = ctx.Diags .diagnose (diagLoc,
3967+ diag::isolation_in_inherits_executor,
3968+ inDefaultArgument);
3969+ diag.limitBehaviorIf (inConcurrencyModule, DiagnosticBehavior::Warning);
3970+ replaceUnsafeInheritExecutorWithDefaultedIsolationParam (func, diag);
3971+ }
3972+
37513973 auto loc = isolationExpr->getLoc ();
37523974 auto isolation = getActorIsolationOfContext (
37533975 const_cast <DeclContext *>(getDeclContext ()),
0 commit comments