@@ -2002,6 +2002,211 @@ bool swift::isAsyncDecl(ConcreteDeclRef declRef) {
20022002 return false ;
20032003}
20042004
2005+ AbstractFunctionDecl *swift::enclosingUnsafeInheritsExecutor (
2006+ const DeclContext *dc) {
2007+ for (; dc; dc = dc->getParent ()) {
2008+ if (auto func = dyn_cast<AbstractFunctionDecl>(dc)) {
2009+ if (func->getAttrs ().hasAttribute <UnsafeInheritExecutorAttr>()) {
2010+ return const_cast <AbstractFunctionDecl *>(func);
2011+ }
2012+
2013+ return nullptr ;
2014+ }
2015+
2016+ if (isa<AbstractClosureExpr>(dc))
2017+ return nullptr ;
2018+
2019+ if (dc->isTypeContext ())
2020+ return nullptr ;
2021+ }
2022+
2023+ return nullptr ;
2024+ }
2025+
2026+ // / Adjust the location used for diagnostics about #isolation to account for
2027+ // / the fact that they show up in macro expansions.
2028+ // /
2029+ // / Returns a pair containing the updated location and whether it's part of
2030+ // / a default argument.
2031+ static std::pair<SourceLoc, bool > adjustPoundIsolationDiagLoc (
2032+ CurrentContextIsolationExpr *isolationExpr,
2033+ ModuleDecl *module
2034+ ) {
2035+ // Not part of a macro expansion.
2036+ SourceLoc diagLoc = isolationExpr->getLoc ();
2037+ auto sourceFile = module ->getSourceFileContainingLocation (diagLoc);
2038+ if (!sourceFile)
2039+ return { diagLoc, false };
2040+ auto macroExpansionRange = sourceFile->getMacroInsertionRange ();
2041+ if (macroExpansionRange.Start .isInvalid ())
2042+ return { diagLoc, false };
2043+
2044+ diagLoc = macroExpansionRange.Start ;
2045+
2046+ // If this is from a default argument, note that and go one more
2047+ // level "out" to the place where the default argument was
2048+ // introduced.
2049+ auto expansionSourceFile = module ->getSourceFileContainingLocation (diagLoc);
2050+ if (!expansionSourceFile ||
2051+ expansionSourceFile->Kind != SourceFileKind::DefaultArgument)
2052+ return { diagLoc, false };
2053+
2054+ return {
2055+ expansionSourceFile->getNodeInEnclosingSourceFile ().getStartLoc (),
2056+ true
2057+ };
2058+ }
2059+
2060+ void swift::replaceUnsafeInheritExecutorWithDefaultedIsolationParam (
2061+ AbstractFunctionDecl *func, InFlightDiagnostic &diag) {
2062+ auto attr = func->getAttrs ().getAttribute <UnsafeInheritExecutorAttr>();
2063+ assert (attr && " Caller didn't validate the presence of the attribute" );
2064+
2065+ // Look for the place where we should insert the new 'isolation' parameter.
2066+ // We insert toward the back, but skip over any parameters that have function
2067+ // type.
2068+ unsigned insertionPos = func->getParameters ()->size ();
2069+ while (insertionPos > 0 ) {
2070+ Type paramType = func->getParameters ()->get (insertionPos - 1 )->getInterfaceType ();
2071+ if (paramType->lookThroughSingleOptionalType ()->is <AnyFunctionType>()) {
2072+ --insertionPos;
2073+ continue ;
2074+ }
2075+
2076+ break ;
2077+ }
2078+
2079+ // Determine the text to insert. We put the commas before and after, then
2080+ // slice them away depending on whether we have parameters before or after.
2081+ StringRef newParameterText = " , isolation: isolated (any Actor)? = #isolation, " ;
2082+ if (insertionPos == 0 )
2083+ newParameterText = newParameterText.drop_front (2 );
2084+ if (insertionPos == func->getParameters ()->size ())
2085+ newParameterText = newParameterText.drop_back (2 );
2086+
2087+ // Determine where to insert the new parameter.
2088+ SourceLoc insertionLoc;
2089+ if (insertionPos < func->getParameters ()->size ()) {
2090+ insertionLoc = func->getParameters ()->get (insertionPos)->getStartLoc ();
2091+ } else {
2092+ insertionLoc = func->getParameters ()->getRParenLoc ();
2093+ }
2094+
2095+ diag.fixItRemove (attr->getRangeWithAt ());
2096+ diag.fixItInsert (insertionLoc, newParameterText);
2097+ }
2098+
2099+ // / Whether this declaration context is in the _Concurrency module.
2100+ static bool inConcurrencyModule (const DeclContext *dc) {
2101+ return dc->getParentModule ()->getName ().str ().equals (" _Concurrency" );
2102+ }
2103+
2104+ void swift::introduceUnsafeInheritExecutorReplacements (
2105+ const DeclContext *dc, SourceLoc loc, SmallVectorImpl<ValueDecl *> &decls) {
2106+ if (decls.empty ())
2107+ return ;
2108+
2109+ auto isReplaceable = [&](ValueDecl *decl) {
2110+ return isa<FuncDecl>(decl) && inConcurrencyModule (decl->getDeclContext ()) &&
2111+ decl->getDeclContext ()->isModuleScopeContext ();
2112+ };
2113+
2114+ // Make sure at least some of the entries are functions in the _Concurrency
2115+ // module.
2116+ ModuleDecl *concurrencyModule = nullptr ;
2117+ DeclBaseName baseName;
2118+ for (auto decl: decls) {
2119+ if (isReplaceable (decl)) {
2120+ concurrencyModule = decl->getDeclContext ()->getParentModule ();
2121+ baseName = decl->getName ().getBaseName ();
2122+ break ;
2123+ }
2124+ }
2125+ if (!concurrencyModule)
2126+ return ;
2127+
2128+ // Ignore anything with a special name.
2129+ if (baseName.isSpecial ())
2130+ return ;
2131+
2132+ // Look for entities with the _unsafeInheritExecutor_ prefix on the name.
2133+ ASTContext &ctx = decls.front ()->getASTContext ();
2134+ Identifier newIdentifier = ctx.getIdentifier (
2135+ (" _unsafeInheritExecutor_" + baseName.getIdentifier ().str ()).str ());
2136+
2137+ NameLookupOptions lookupOptions = defaultUnqualifiedLookupOptions;
2138+ LookupResult lookup = TypeChecker::lookupUnqualified (
2139+ const_cast <DeclContext *>(dc), DeclNameRef (newIdentifier), loc,
2140+ lookupOptions);
2141+ if (!lookup)
2142+ return ;
2143+
2144+ // Drop all of the _Concurrency entries in favor of the ones found by this
2145+ // lookup.
2146+ decls.erase (std::remove_if (decls.begin (), decls.end (), [&](ValueDecl *decl) {
2147+ return isReplaceable (decl);
2148+ }),
2149+ decls.end ());
2150+ for (const auto &lookupResult: lookup) {
2151+ if (auto decl = lookupResult.getValueDecl ())
2152+ decls.push_back (decl);
2153+ }
2154+ }
2155+
2156+ void swift::introduceUnsafeInheritExecutorReplacements (
2157+ const DeclContext *dc, Type base, SourceLoc loc, LookupResult &lookup) {
2158+ if (lookup.empty ())
2159+ return ;
2160+
2161+ auto baseNominal = base->getAnyNominal ();
2162+ if (!baseNominal || !inConcurrencyModule (baseNominal))
2163+ return ;
2164+
2165+ auto isReplaceable = [&](ValueDecl *decl) {
2166+ return isa<FuncDecl>(decl) && inConcurrencyModule (decl->getDeclContext ());
2167+ };
2168+
2169+ // Make sure at least some of the entries are functions in the _Concurrency
2170+ // module.
2171+ ModuleDecl *concurrencyModule = nullptr ;
2172+ DeclBaseName baseName;
2173+ for (auto &result: lookup) {
2174+ auto decl = result.getValueDecl ();
2175+ if (isReplaceable (decl)) {
2176+ concurrencyModule = decl->getDeclContext ()->getParentModule ();
2177+ baseName = decl->getBaseName ();
2178+ break ;
2179+ }
2180+ }
2181+ if (!concurrencyModule)
2182+ return ;
2183+
2184+ // Ignore anything with a special name.
2185+ if (baseName.isSpecial ())
2186+ return ;
2187+
2188+ // Look for entities with the _unsafeInheritExecutor_ prefix on the name.
2189+ ASTContext &ctx = base->getASTContext ();
2190+ Identifier newIdentifier = ctx.getIdentifier (
2191+ (" _unsafeInheritExecutor_" + baseName.getIdentifier ().str ()).str ());
2192+
2193+ LookupResult replacementLookup = TypeChecker::lookupMember (
2194+ const_cast <DeclContext *>(dc), base, DeclNameRef (newIdentifier), loc,
2195+ defaultMemberLookupOptions);
2196+ if (replacementLookup.innerResults ().empty ())
2197+ return ;
2198+
2199+ // Drop all of the _Concurrency entries in favor of the ones found by this
2200+ // lookup.
2201+ lookup.filter ([&](const LookupResultEntry &entry, bool ) {
2202+ return !isReplaceable (entry.getValueDecl ());
2203+ });
2204+
2205+ for (const auto &entry: replacementLookup.innerResults ()) {
2206+ lookup.add (entry, /* isOuter=*/ false );
2207+ }
2208+ }
2209+
20052210// / Check if it is safe for the \c globalActor qualifier to be removed from
20062211// / \c ty, when the function value of that type is isolated to that actor.
20072212// /
@@ -3745,6 +3950,23 @@ namespace {
37453950 if (isolationExpr->getActor ())
37463951 return ;
37473952
3953+ // #isolation does not work within an `@_unsafeInheritExecutor` function.
3954+ if (auto func = enclosingUnsafeInheritsExecutor (getDeclContext ())) {
3955+ // This expression is always written as a macro #isolation in source,
3956+ // so find the enclosing macro expansion expression's location.
3957+ SourceLoc diagLoc;
3958+ bool inDefaultArgument;
3959+ std::tie (diagLoc, inDefaultArgument) = adjustPoundIsolationDiagLoc (
3960+ isolationExpr, getDeclContext ()->getParentModule ());
3961+
3962+ bool inConcurrencyModule = ::inConcurrencyModule (getDeclContext ());
3963+ auto diag = ctx.Diags .diagnose (diagLoc,
3964+ diag::isolation_in_inherits_executor,
3965+ inDefaultArgument);
3966+ diag.limitBehaviorIf (inConcurrencyModule, DiagnosticBehavior::Warning);
3967+ replaceUnsafeInheritExecutorWithDefaultedIsolationParam (func, diag);
3968+ }
3969+
37483970 auto loc = isolationExpr->getLoc ();
37493971 auto isolation = getActorIsolationOfContext (
37503972 const_cast <DeclContext *>(getDeclContext ()),
0 commit comments