105105#include " swift/AST/ProtocolConformance.h"
106106#include " swift/AST/TypeCheckRequests.h"
107107#include " swift/AST/Types.h"
108+ #include " swift/SIL/ApplySite.h"
108109#include " swift/SIL/PrettyStackTrace.h"
109110#include " swift/SIL/AbstractionPatternGenerators.h"
110111#include " swift/SIL/SILArgument.h"
@@ -6185,14 +6186,75 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap(
61856186 }
61866187
61876188 auto *linearMapArg = thunk->getArgumentsWithoutIndirectResults ().back ();
6188- auto *apply = thunkSGF.B .createApply (loc, linearMapArg, SubstitutionMap (),
6189- arguments);
61906189
6191- // Get return elements.
6192- SmallVector<SILValue, 4 > results;
61936190 // Extract all direct results.
61946191 SmallVector<SILValue, 4 > directResults;
6195- extractAllElements (apply, loc, thunkSGF.B , directResults);
6192+ FullApplySite fas;
6193+ if (fromType->isCoroutine ()) {
6194+ SmallVector<SILValue, 1 > yields;
6195+ // Start inner coroutine execution till the suspend point
6196+ SubstitutionMap subs = thunk->getForwardingSubstitutionMap ();
6197+ SILType substFnType = linearMapArg->getType ().substGenericArgs (
6198+ thunkSGF.getModule (), subs, thunk->getTypeExpansionContext ());
6199+ auto tokenAndCleanups = thunkSGF.emitBeginApplyWithRethrow (
6200+ loc, linearMapArg, substFnType,
6201+ SubstitutionMap (), arguments, yields);
6202+ auto token = std::get<0 >(tokenAndCleanups);
6203+ auto abortCleanup = std::get<1 >(tokenAndCleanups);
6204+ auto allocation = std::get<2 >(tokenAndCleanups);
6205+ auto deallocCleanup = std::get<3 >(tokenAndCleanups);
6206+
6207+ {
6208+ SmallVector<ManagedValue, 1 > yieldMVs;
6209+
6210+ // Prepare a destination for the unwind; use the current cleanup stack
6211+ // as the depth so that we branch right to it.
6212+ SILBasicBlock *unwindBB = thunkSGF.createBasicBlock (FunctionSection::Postmatter);
6213+ JumpDest unwindDest (unwindBB, thunkSGF.Cleanups .getCleanupsDepth (),
6214+ CleanupLocation (loc));
6215+
6216+ manageYields (thunkSGF, yields, substFnType.castTo <SILFunctionType>()->getYields (),
6217+ yieldMVs);
6218+
6219+ // Emit the yield.
6220+ thunkSGF.emitRawYield (loc, yieldMVs, unwindDest, /* unique*/ true );
6221+
6222+ // Emit the unwind block.
6223+ {
6224+ SILGenSavedInsertionPoint savedIP (thunkSGF, unwindBB,
6225+ FunctionSection::Postmatter);
6226+
6227+ // Emit all active cleanups.
6228+ thunkSGF.Cleanups .emitCleanupsForReturn (CleanupLocation (loc), IsForUnwind);
6229+ thunkSGF.B .createUnwind (loc);
6230+ }
6231+ }
6232+
6233+ // Kill the normal abort cleanup without emitting it.
6234+ thunkSGF.Cleanups .setCleanupState (abortCleanup, CleanupState::Dead);
6235+ if (allocation) {
6236+ thunkSGF.Cleanups .setCleanupState (deallocCleanup, CleanupState::Dead);
6237+ }
6238+
6239+ // End the inner coroutine normally.
6240+ auto resultTy =
6241+ thunk->mapTypeIntoContext (
6242+ fromType->getAllResultsSubstType (thunkSGF.getModule (),
6243+ thunkSGF.getTypeExpansionContext ()));
6244+ auto endApply =
6245+ thunkSGF.emitEndApplyWithRethrow (loc, token, allocation, resultTy);
6246+
6247+ extractAllElements (endApply, loc, thunkSGF.B , directResults);
6248+ fas = token->getParent <BeginApplyInst>();
6249+ } else {
6250+ auto *apply = thunkSGF.B .createApply (loc, linearMapArg, SubstitutionMap (),
6251+ arguments);
6252+ extractAllElements (apply, loc, thunkSGF.B , directResults);
6253+ fas = apply;
6254+ }
6255+
6256+ // Get return elements.
6257+ SmallVector<SILValue, 4 > results;
61966258
61976259 // Handle self reordering.
61986260 // For pullbacks: rotate direct results if self is direct.
@@ -6211,7 +6273,7 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap(
62116273 }
62126274
62136275 auto fromDirResultsIter = directResults.begin ();
6214- auto fromIndResultsIter = apply-> getIndirectSILResults ().begin ();
6276+ auto fromIndResultsIter = fas. getIndirectSILResults ().begin ();
62156277 auto toIndResultsIter = thunkIndirectResults.begin ();
62166278 // Reabstract results.
62176279 for (unsigned resIdx : range (toType->getNumResults ())) {
0 commit comments