108108#include " swift/AST/ProtocolConformance.h"
109109#include " swift/AST/TypeCheckRequests.h"
110110#include " swift/AST/Types.h"
111+ #include " swift/SIL/ApplySite.h"
111112#include " swift/SIL/PrettyStackTrace.h"
112113#include " swift/SIL/AbstractionPatternGenerators.h"
113114#include " swift/SIL/SILArgument.h"
@@ -6474,14 +6475,75 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap(
64746475 }
64756476
64766477 auto *linearMapArg = thunk->getArgumentsWithoutIndirectResults ().back ();
6477- auto *apply = thunkSGF.B .createApply (loc, linearMapArg, SubstitutionMap (),
6478- arguments);
64796478
6480- // Get return elements.
6481- SmallVector<SILValue, 4 > results;
64826479 // Extract all direct results.
64836480 SmallVector<SILValue, 4 > directResults;
6484- extractAllElements (apply, loc, thunkSGF.B , directResults);
6481+ FullApplySite fas;
6482+ if (fromType->isCoroutine ()) {
6483+ SmallVector<SILValue, 1 > yields;
6484+ // Start inner coroutine execution till the suspend point
6485+ SubstitutionMap subs = thunk->getForwardingSubstitutionMap ();
6486+ SILType substFnType = linearMapArg->getType ().substGenericArgs (
6487+ thunkSGF.getModule (), subs, thunk->getTypeExpansionContext ());
6488+ auto tokenAndCleanups = thunkSGF.emitBeginApplyWithRethrow (
6489+ loc, linearMapArg, substFnType,
6490+ SubstitutionMap (), arguments, yields);
6491+ auto token = std::get<0 >(tokenAndCleanups);
6492+ auto abortCleanup = std::get<1 >(tokenAndCleanups);
6493+ auto allocation = std::get<2 >(tokenAndCleanups);
6494+ auto deallocCleanup = std::get<3 >(tokenAndCleanups);
6495+
6496+ {
6497+ SmallVector<ManagedValue, 1 > yieldMVs;
6498+
6499+ // Prepare a destination for the unwind; use the current cleanup stack
6500+ // as the depth so that we branch right to it.
6501+ SILBasicBlock *unwindBB = thunkSGF.createBasicBlock (FunctionSection::Postmatter);
6502+ JumpDest unwindDest (unwindBB, thunkSGF.Cleanups .getCleanupsDepth (),
6503+ CleanupLocation (loc));
6504+
6505+ manageYields (thunkSGF, yields, substFnType.castTo <SILFunctionType>()->getYields (),
6506+ yieldMVs);
6507+
6508+ // Emit the yield.
6509+ thunkSGF.emitRawYield (loc, yieldMVs, unwindDest, /* unique*/ true );
6510+
6511+ // Emit the unwind block.
6512+ {
6513+ SILGenSavedInsertionPoint savedIP (thunkSGF, unwindBB,
6514+ FunctionSection::Postmatter);
6515+
6516+ // Emit all active cleanups.
6517+ thunkSGF.Cleanups .emitCleanupsForReturn (CleanupLocation (loc), IsForUnwind);
6518+ thunkSGF.B .createUnwind (loc);
6519+ }
6520+ }
6521+
6522+ // Kill the normal abort cleanup without emitting it.
6523+ thunkSGF.Cleanups .setCleanupState (abortCleanup, CleanupState::Dead);
6524+ if (allocation) {
6525+ thunkSGF.Cleanups .setCleanupState (deallocCleanup, CleanupState::Dead);
6526+ }
6527+
6528+ // End the inner coroutine normally.
6529+ auto resultTy =
6530+ thunk->mapTypeIntoContext (
6531+ fromType->getAllResultsSubstType (thunkSGF.getModule (),
6532+ thunkSGF.getTypeExpansionContext ()));
6533+ auto endApply =
6534+ thunkSGF.emitEndApplyWithRethrow (loc, token, allocation, resultTy);
6535+
6536+ extractAllElements (endApply, loc, thunkSGF.B , directResults);
6537+ fas = token->getParent <BeginApplyInst>();
6538+ } else {
6539+ auto *apply = thunkSGF.B .createApply (loc, linearMapArg, SubstitutionMap (),
6540+ arguments);
6541+ extractAllElements (apply, loc, thunkSGF.B , directResults);
6542+ fas = apply;
6543+ }
6544+
6545+ // Get return elements.
6546+ SmallVector<SILValue, 4 > results;
64856547
64866548 // Handle self reordering.
64876549 // For pullbacks: rotate direct results if self is direct.
@@ -6500,7 +6562,7 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap(
65006562 }
65016563
65026564 auto fromDirResultsIter = directResults.begin ();
6503- auto fromIndResultsIter = apply-> getIndirectSILResults ().begin ();
6565+ auto fromIndResultsIter = fas. getIndirectSILResults ().begin ();
65046566 auto toIndResultsIter = thunkIndirectResults.begin ();
65056567 // Reabstract results.
65066568 for (unsigned resIdx : range (toType->getNumResults ())) {
0 commit comments