Skip to content

Commit 83ba984

Browse files
committed
Emit custom derivative thunks for coroutines
1 parent 16fef90 commit 83ba984

File tree

1 file changed

+63
-18
lines changed

1 file changed

+63
-18
lines changed

lib/SILGen/SILGenPoly.cpp

Lines changed: 63 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2819,14 +2819,15 @@ namespace {
28192819
CanSILFunctionType loweredType, SubstitutionMap subs) {
28202820
LoweredInfos = loweredType->getUnsubstitutedType(SGM.M)->getYields();
28212821

2822-
auto accessor = cast<AccessorDecl>(function.getDecl());
2823-
auto storage = accessor->getStorage();
2822+
auto origFd = cast<FuncDecl>(function.getDecl());
2823+
auto sig = origFd->getGenericSignatureOfContext().getCanonicalSignature();
28242824

2825-
OrigTypes.push_back(
2826-
SGM.Types.getAbstractionPattern(storage, /*nonobjc*/ true));
2825+
auto origYieldType = origFd->getYieldsInterfaceType()->castTo<YieldResultType>();
2826+
auto reducedYieldType = sig.getReducedType(origYieldType->getResultType());
2827+
OrigTypes.emplace_back(sig, reducedYieldType);
28272828

28282829
SmallVector<AnyFunctionType::Yield, 1> yieldsBuffer;
2829-
auto yields = AnyFunctionRef(accessor).getYieldResults(yieldsBuffer);
2830+
auto yields = AnyFunctionRef(origFd).getYieldResults(yieldsBuffer);
28302831
assert(yields.size() == 1);
28312832
Yields.push_back(yields[0].getCanonical().subst(subs).asParam());
28322833
}
@@ -6359,10 +6360,6 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
63596360
SILType substFnType = fnRef->getType().substGenericArgs(
63606361
M, subs, thunk->getTypeExpansionContext());
63616362

6362-
// Apply function argument.
6363-
auto apply =
6364-
thunkSGF.emitApplyWithRethrow(loc, fnRef, substFnType, subs, arguments);
6365-
63666363
// Self reordering thunk is necessary if wrt at least two parameters,
63676364
// including self.
63686365
auto shouldReorderSelf = [&]() {
@@ -6399,18 +6396,66 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
63996396
thunkSGF.B.createReturn(loc, retValue);
64006397
};
64016398

6402-
if (!reorderSelf && linearMapFnType == targetLinearMapFnType) {
6403-
SmallVector<SILValue, 8> results;
6399+
SmallVector<SILValue, 8> results;
6400+
if (customDerivativeFnTy->isCoroutine()) {
6401+
assert(kind == AutoDiffDerivativeFunctionKind::VJP &&
6402+
"only support VJP custom coroutine derivatives");
6403+
6404+
SmallVector<SILValue, 1> yields;
6405+
// Start inner coroutine execution till the suspend point
6406+
auto tokenAndCleanups = thunkSGF.emitBeginApplyWithRethrow(
6407+
loc, fnRef, substFnType /*fnRef->getType()*/,
6408+
subs, arguments, yields);
6409+
auto token = std::get<0>(tokenAndCleanups);
6410+
auto abortCleanup = std::get<1>(tokenAndCleanups);
6411+
auto allocation = std::get<2>(tokenAndCleanups);
6412+
auto deallocCleanup = std::get<3>(tokenAndCleanups);
6413+
6414+
// Forward yields
6415+
auto *customDerivativeAFD =
6416+
cast<AbstractFunctionDecl>(customDerivativeFn->getDeclContext()->getAsDecl());
6417+
auto thunkTy = thunkSGF.F.getLoweredFunctionType();
6418+
YieldInfo innerYieldInfo(*this, SILDeclRef(customDerivativeAFD), fnRefType,
6419+
subs);
6420+
// FIXME: We do not have Decl for the thunk as it is generated entirely at SIL level.
6421+
// Fix the yield info in case when reabstraction of yields would be required
6422+
YieldInfo outerYieldInfo(*this, SILDeclRef(customDerivativeAFD), thunkTy,
6423+
thunk->getForwardingSubstitutionMap());
6424+
translateYields(thunkSGF, loc, yields, innerYieldInfo, outerYieldInfo);
6425+
6426+
// Kill the normal abort cleanup without emitting it. translateYields() will
6427+
// produce proper abort_apply & cleanups for inner coroutine call in the
6428+
// unwind block, and for normal return we're doing it manually below
6429+
thunkSGF.Cleanups.setCleanupState(abortCleanup, CleanupState::Dead);
6430+
if (allocation) {
6431+
thunkSGF.Cleanups.setCleanupState(deallocCleanup, CleanupState::Dead);
6432+
}
6433+
6434+
// End the inner coroutine normally.
6435+
auto resultTy =
6436+
thunk->mapTypeIntoContext(
6437+
fnRefType->getAllResultsSubstType(M,
6438+
thunkSGF.getTypeExpansionContext()));
6439+
auto endApply =
6440+
thunkSGF.emitEndApplyWithRethrow(loc, token, allocation, resultTy);
6441+
6442+
extractAllElements(endApply, loc, thunkSGF.B, results);
6443+
} else {
6444+
// Apply function argument.
6445+
auto apply =
6446+
thunkSGF.emitApplyWithRethrow(loc, fnRef, substFnType, subs, arguments);
6447+
64046448
extractAllElements(apply, loc, thunkSGF.B, results);
6405-
auto result = joinElements(results, thunkSGF.B, apply.getLoc());
6449+
}
6450+
6451+
if (!reorderSelf && linearMapFnType == targetLinearMapFnType) {
6452+
auto result = joinElements(results, thunkSGF.B, loc);
64066453
createReturn(result);
64076454
return thunk;
64086455
}
64096456

64106457
// Otherwise, apply reabstraction/self reordering thunk to linear map.
6411-
SmallVector<SILValue, 8> directResults;
6412-
extractAllElements(apply, loc, thunkSGF.B, directResults);
6413-
auto linearMap = thunkSGF.emitManagedRValueWithCleanup(directResults.back());
6458+
auto linearMap = thunkSGF.emitManagedRValueWithCleanup(results.back());
64146459
assert(linearMap.getType().castTo<SILFunctionType>() == linearMapFnType);
64156460
auto linearMapKind = kind.getLinearMapKind();
64166461
linearMap = thunkSGF.getThunkedAutoDiffLinearMap(
@@ -6440,10 +6485,10 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
64406485
}
64416486

64426487
// Return original results and thunked differential/pullback.
6443-
if (directResults.size() > 1) {
6444-
auto originalDirectResults = ArrayRef<SILValue>(directResults).drop_back(1);
6488+
if (results.size() > 1) {
6489+
auto originalDirectResults = ArrayRef<SILValue>(results).drop_back(1);
64456490
auto originalDirectResult =
6446-
joinElements(originalDirectResults, thunkSGF.B, apply.getLoc());
6491+
joinElements(originalDirectResults, thunkSGF.B, loc);
64476492
auto thunkResult = joinElements(
64486493
{originalDirectResult, linearMap.forward(thunkSGF)}, thunkSGF.B, loc);
64496494
createReturn(thunkResult);

0 commit comments

Comments
 (0)