Skip to content

Commit f6eb061

Browse files
committed
Add support for coroutine linear map thunks
1 parent a268f83 commit f6eb061

File tree

1 file changed

+68
-6
lines changed

1 file changed

+68
-6
lines changed

lib/SILGen/SILGenPoly.cpp

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@
108108
#include "swift/AST/ProtocolConformance.h"
109109
#include "swift/AST/TypeCheckRequests.h"
110110
#include "swift/AST/Types.h"
111+
#include "swift/SIL/ApplySite.h"
111112
#include "swift/SIL/PrettyStackTrace.h"
112113
#include "swift/SIL/AbstractionPatternGenerators.h"
113114
#include "swift/SIL/SILArgument.h"
@@ -6474,14 +6475,75 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap(
64746475
}
64756476

64766477
auto *linearMapArg = thunk->getArgumentsWithoutIndirectResults().back();
6477-
auto *apply = thunkSGF.B.createApply(loc, linearMapArg, SubstitutionMap(),
6478-
arguments);
64796478

6480-
// Get return elements.
6481-
SmallVector<SILValue, 4> results;
64826479
// Extract all direct results.
64836480
SmallVector<SILValue, 4> directResults;
6484-
extractAllElements(apply, loc, thunkSGF.B, directResults);
6481+
FullApplySite fas;
6482+
if (fromType->isCoroutine()) {
6483+
SmallVector<SILValue, 1> yields;
6484+
// Start inner coroutine execution till the suspend point
6485+
SubstitutionMap subs = thunk->getForwardingSubstitutionMap();
6486+
SILType substFnType = linearMapArg->getType().substGenericArgs(
6487+
thunkSGF.getModule(), subs, thunk->getTypeExpansionContext());
6488+
auto tokenAndCleanups = thunkSGF.emitBeginApplyWithRethrow(
6489+
loc, linearMapArg, substFnType,
6490+
SubstitutionMap(), arguments, yields);
6491+
auto token = std::get<0>(tokenAndCleanups);
6492+
auto abortCleanup = std::get<1>(tokenAndCleanups);
6493+
auto allocation = std::get<2>(tokenAndCleanups);
6494+
auto deallocCleanup = std::get<3>(tokenAndCleanups);
6495+
6496+
{
6497+
SmallVector<ManagedValue, 1> yieldMVs;
6498+
6499+
// Prepare a destination for the unwind; use the current cleanup stack
6500+
// as the depth so that we branch right to it.
6501+
SILBasicBlock *unwindBB = thunkSGF.createBasicBlock(FunctionSection::Postmatter);
6502+
JumpDest unwindDest(unwindBB, thunkSGF.Cleanups.getCleanupsDepth(),
6503+
CleanupLocation(loc));
6504+
6505+
manageYields(thunkSGF, yields, substFnType.castTo<SILFunctionType>()->getYields(),
6506+
yieldMVs);
6507+
6508+
// Emit the yield.
6509+
thunkSGF.emitRawYield(loc, yieldMVs, unwindDest, /*unique*/ true);
6510+
6511+
// Emit the unwind block.
6512+
{
6513+
SILGenSavedInsertionPoint savedIP(thunkSGF, unwindBB,
6514+
FunctionSection::Postmatter);
6515+
6516+
// Emit all active cleanups.
6517+
thunkSGF.Cleanups.emitCleanupsForReturn(CleanupLocation(loc), IsForUnwind);
6518+
thunkSGF.B.createUnwind(loc);
6519+
}
6520+
}
6521+
6522+
// Kill the normal abort cleanup without emitting it.
6523+
thunkSGF.Cleanups.setCleanupState(abortCleanup, CleanupState::Dead);
6524+
if (allocation) {
6525+
thunkSGF.Cleanups.setCleanupState(deallocCleanup, CleanupState::Dead);
6526+
}
6527+
6528+
// End the inner coroutine normally.
6529+
auto resultTy =
6530+
thunk->mapTypeIntoContext(
6531+
fromType->getAllResultsSubstType(thunkSGF.getModule(),
6532+
thunkSGF.getTypeExpansionContext()));
6533+
auto endApply =
6534+
thunkSGF.emitEndApplyWithRethrow(loc, token, allocation, resultTy);
6535+
6536+
extractAllElements(endApply, loc, thunkSGF.B, directResults);
6537+
fas = token->getParent<BeginApplyInst>();
6538+
} else {
6539+
auto *apply = thunkSGF.B.createApply(loc, linearMapArg, SubstitutionMap(),
6540+
arguments);
6541+
extractAllElements(apply, loc, thunkSGF.B, directResults);
6542+
fas = apply;
6543+
}
6544+
6545+
// Get return elements.
6546+
SmallVector<SILValue, 4> results;
64856547

64866548
// Handle self reordering.
64876549
// For pullbacks: rotate direct results if self is direct.
@@ -6500,7 +6562,7 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap(
65006562
}
65016563

65026564
auto fromDirResultsIter = directResults.begin();
6503-
auto fromIndResultsIter = apply->getIndirectSILResults().begin();
6565+
auto fromIndResultsIter = fas.getIndirectSILResults().begin();
65046566
auto toIndResultsIter = thunkIndirectResults.begin();
65056567
// Reabstract results.
65066568
for (unsigned resIdx : range(toType->getNumResults())) {

0 commit comments

Comments
 (0)