Skip to content

Commit 7cef5be

Browse files
committed
Add support for coroutine linear map thunks
1 parent 83ba984 commit 7cef5be

File tree

1 file changed

+68
-6
lines changed

1 file changed

+68
-6
lines changed

lib/SILGen/SILGenPoly.cpp

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
#include "swift/AST/ProtocolConformance.h"
106106
#include "swift/AST/TypeCheckRequests.h"
107107
#include "swift/AST/Types.h"
108+
#include "swift/SIL/ApplySite.h"
108109
#include "swift/SIL/PrettyStackTrace.h"
109110
#include "swift/SIL/AbstractionPatternGenerators.h"
110111
#include "swift/SIL/SILArgument.h"
@@ -6185,14 +6186,75 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap(
61856186
}
61866187

61876188
auto *linearMapArg = thunk->getArgumentsWithoutIndirectResults().back();
6188-
auto *apply = thunkSGF.B.createApply(loc, linearMapArg, SubstitutionMap(),
6189-
arguments);
61906189

6191-
// Get return elements.
6192-
SmallVector<SILValue, 4> results;
61936190
// Extract all direct results.
61946191
SmallVector<SILValue, 4> directResults;
6195-
extractAllElements(apply, loc, thunkSGF.B, directResults);
6192+
FullApplySite fas;
6193+
if (fromType->isCoroutine()) {
6194+
SmallVector<SILValue, 1> yields;
6195+
// Start inner coroutine execution till the suspend point
6196+
SubstitutionMap subs = thunk->getForwardingSubstitutionMap();
6197+
SILType substFnType = linearMapArg->getType().substGenericArgs(
6198+
thunkSGF.getModule(), subs, thunk->getTypeExpansionContext());
6199+
auto tokenAndCleanups = thunkSGF.emitBeginApplyWithRethrow(
6200+
loc, linearMapArg, substFnType,
6201+
SubstitutionMap(), arguments, yields);
6202+
auto token = std::get<0>(tokenAndCleanups);
6203+
auto abortCleanup = std::get<1>(tokenAndCleanups);
6204+
auto allocation = std::get<2>(tokenAndCleanups);
6205+
auto deallocCleanup = std::get<3>(tokenAndCleanups);
6206+
6207+
{
6208+
SmallVector<ManagedValue, 1> yieldMVs;
6209+
6210+
// Prepare a destination for the unwind; use the current cleanup stack
6211+
// as the depth so that we branch right to it.
6212+
SILBasicBlock *unwindBB = thunkSGF.createBasicBlock(FunctionSection::Postmatter);
6213+
JumpDest unwindDest(unwindBB, thunkSGF.Cleanups.getCleanupsDepth(),
6214+
CleanupLocation(loc));
6215+
6216+
manageYields(thunkSGF, yields, substFnType.castTo<SILFunctionType>()->getYields(),
6217+
yieldMVs);
6218+
6219+
// Emit the yield.
6220+
thunkSGF.emitRawYield(loc, yieldMVs, unwindDest, /*unique*/ true);
6221+
6222+
// Emit the unwind block.
6223+
{
6224+
SILGenSavedInsertionPoint savedIP(thunkSGF, unwindBB,
6225+
FunctionSection::Postmatter);
6226+
6227+
// Emit all active cleanups.
6228+
thunkSGF.Cleanups.emitCleanupsForReturn(CleanupLocation(loc), IsForUnwind);
6229+
thunkSGF.B.createUnwind(loc);
6230+
}
6231+
}
6232+
6233+
// Kill the normal abort cleanup without emitting it.
6234+
thunkSGF.Cleanups.setCleanupState(abortCleanup, CleanupState::Dead);
6235+
if (allocation) {
6236+
thunkSGF.Cleanups.setCleanupState(deallocCleanup, CleanupState::Dead);
6237+
}
6238+
6239+
// End the inner coroutine normally.
6240+
auto resultTy =
6241+
thunk->mapTypeIntoContext(
6242+
fromType->getAllResultsSubstType(thunkSGF.getModule(),
6243+
thunkSGF.getTypeExpansionContext()));
6244+
auto endApply =
6245+
thunkSGF.emitEndApplyWithRethrow(loc, token, allocation, resultTy);
6246+
6247+
extractAllElements(endApply, loc, thunkSGF.B, directResults);
6248+
fas = token->getParent<BeginApplyInst>();
6249+
} else {
6250+
auto *apply = thunkSGF.B.createApply(loc, linearMapArg, SubstitutionMap(),
6251+
arguments);
6252+
extractAllElements(apply, loc, thunkSGF.B, directResults);
6253+
fas = apply;
6254+
}
6255+
6256+
// Get return elements.
6257+
SmallVector<SILValue, 4> results;
61966258

61976259
// Handle self reordering.
61986260
// For pullbacks: rotate direct results if self is direct.
@@ -6211,7 +6273,7 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap(
62116273
}
62126274

62136275
auto fromDirResultsIter = directResults.begin();
6214-
auto fromIndResultsIter = apply->getIndirectSILResults().begin();
6276+
auto fromIndResultsIter = fas.getIndirectSILResults().begin();
62156277
auto toIndResultsIter = thunkIndirectResults.begin();
62166278
// Reabstract results.
62176279
for (unsigned resIdx : range(toType->getNumResults())) {

0 commit comments

Comments
 (0)