Skip to content

Commit a268f83

Browse files
committed
Emit custom derivative thunks for coroutines
1 parent d2289c5 commit a268f83

File tree

1 file changed

+63
-18
lines changed

1 file changed

+63
-18
lines changed

lib/SILGen/SILGenPoly.cpp

Lines changed: 63 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2996,14 +2996,15 @@ namespace {
29962996
CanSILFunctionType loweredType, SubstitutionMap subs) {
29972997
LoweredInfos = loweredType->getUnsubstitutedType(SGM.M)->getYields();
29982998

2999-
auto accessor = cast<AccessorDecl>(function.getDecl());
3000-
auto storage = accessor->getStorage();
2999+
auto origFd = cast<FuncDecl>(function.getDecl());
3000+
auto sig = origFd->getGenericSignatureOfContext().getCanonicalSignature();
30013001

3002-
OrigTypes.push_back(
3003-
SGM.Types.getAbstractionPattern(storage, /*nonobjc*/ true));
3002+
auto origYieldType = origFd->getYieldsInterfaceType()->castTo<YieldResultType>();
3003+
auto reducedYieldType = sig.getReducedType(origYieldType->getResultType());
3004+
OrigTypes.emplace_back(sig, reducedYieldType);
30043005

30053006
SmallVector<AnyFunctionType::Yield, 1> yieldsBuffer;
3006-
auto yields = AnyFunctionRef(accessor).getYieldResults(yieldsBuffer);
3007+
auto yields = AnyFunctionRef(origFd).getYieldResults(yieldsBuffer);
30073008
assert(yields.size() == 1);
30083009
Yields.push_back(yields[0].getCanonical().subst(subs).asParam());
30093010
}
@@ -6655,10 +6656,6 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
66556656
SILType substFnType = fnRef->getType().substGenericArgs(
66566657
M, subs, thunk->getTypeExpansionContext());
66576658

6658-
// Apply function argument.
6659-
auto apply =
6660-
thunkSGF.emitApplyWithRethrow(loc, fnRef, substFnType, subs, arguments);
6661-
66626659
// Self reordering thunk is necessary if wrt at least two parameters,
66636660
// including self.
66646661
auto shouldReorderSelf = [&]() {
@@ -6695,18 +6692,66 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
66956692
thunkSGF.B.createReturn(loc, retValue);
66966693
};
66976694

6698-
if (!reorderSelf && linearMapFnType == targetLinearMapFnType) {
6699-
SmallVector<SILValue, 8> results;
6695+
SmallVector<SILValue, 8> results;
6696+
if (customDerivativeFnTy->isCoroutine()) {
6697+
assert(kind == AutoDiffDerivativeFunctionKind::VJP &&
6698+
"only support VJP custom coroutine derivatives");
6699+
6700+
SmallVector<SILValue, 1> yields;
6701+
// Start inner coroutine execution till the suspend point
6702+
auto tokenAndCleanups = thunkSGF.emitBeginApplyWithRethrow(
6703+
loc, fnRef, substFnType /*fnRef->getType()*/,
6704+
subs, arguments, yields);
6705+
auto token = std::get<0>(tokenAndCleanups);
6706+
auto abortCleanup = std::get<1>(tokenAndCleanups);
6707+
auto allocation = std::get<2>(tokenAndCleanups);
6708+
auto deallocCleanup = std::get<3>(tokenAndCleanups);
6709+
6710+
// Forward yields
6711+
auto *customDerivativeAFD =
6712+
cast<AbstractFunctionDecl>(customDerivativeFn->getDeclContext()->getAsDecl());
6713+
auto thunkTy = thunkSGF.F.getLoweredFunctionType();
6714+
YieldInfo innerYieldInfo(*this, SILDeclRef(customDerivativeAFD), fnRefType,
6715+
subs);
6716+
// FIXME: We do not have Decl for the thunk as it is generated entirely at SIL level.
6717+
// Fix the yield info in case when reabstraction of yields would be required
6718+
YieldInfo outerYieldInfo(*this, SILDeclRef(customDerivativeAFD), thunkTy,
6719+
thunk->getForwardingSubstitutionMap());
6720+
translateYields(thunkSGF, loc, yields, innerYieldInfo, outerYieldInfo);
6721+
6722+
// Kill the normal abort cleanup without emitting it. translateYields() will
6723+
// produce proper abort_apply & cleanups for inner coroutine call in the
6724+
// unwind block, and for normal return we're doing it manually below
6725+
thunkSGF.Cleanups.setCleanupState(abortCleanup, CleanupState::Dead);
6726+
if (allocation) {
6727+
thunkSGF.Cleanups.setCleanupState(deallocCleanup, CleanupState::Dead);
6728+
}
6729+
6730+
// End the inner coroutine normally.
6731+
auto resultTy =
6732+
thunk->mapTypeIntoContext(
6733+
fnRefType->getAllResultsSubstType(M,
6734+
thunkSGF.getTypeExpansionContext()));
6735+
auto endApply =
6736+
thunkSGF.emitEndApplyWithRethrow(loc, token, allocation, resultTy);
6737+
6738+
extractAllElements(endApply, loc, thunkSGF.B, results);
6739+
} else {
6740+
// Apply function argument.
6741+
auto apply =
6742+
thunkSGF.emitApplyWithRethrow(loc, fnRef, substFnType, subs, arguments);
6743+
67006744
extractAllElements(apply, loc, thunkSGF.B, results);
6701-
auto result = joinElements(results, thunkSGF.B, apply.getLoc());
6745+
}
6746+
6747+
if (!reorderSelf && linearMapFnType == targetLinearMapFnType) {
6748+
auto result = joinElements(results, thunkSGF.B, loc);
67026749
createReturn(result);
67036750
return thunk;
67046751
}
67056752

67066753
// Otherwise, apply reabstraction/self reordering thunk to linear map.
6707-
SmallVector<SILValue, 8> directResults;
6708-
extractAllElements(apply, loc, thunkSGF.B, directResults);
6709-
auto linearMap = thunkSGF.emitManagedRValueWithCleanup(directResults.back());
6754+
auto linearMap = thunkSGF.emitManagedRValueWithCleanup(results.back());
67106755
assert(linearMap.getType().castTo<SILFunctionType>() == linearMapFnType);
67116756
auto linearMapKind = kind.getLinearMapKind();
67126757
linearMap = thunkSGF.getThunkedAutoDiffLinearMap(
@@ -6736,10 +6781,10 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
67366781
}
67376782

67386783
// Return original results and thunked differential/pullback.
6739-
if (directResults.size() > 1) {
6740-
auto originalDirectResults = ArrayRef<SILValue>(directResults).drop_back(1);
6784+
if (results.size() > 1) {
6785+
auto originalDirectResults = ArrayRef<SILValue>(results).drop_back(1);
67416786
auto originalDirectResult =
6742-
joinElements(originalDirectResults, thunkSGF.B, apply.getLoc());
6787+
joinElements(originalDirectResults, thunkSGF.B, loc);
67436788
auto thunkResult = joinElements(
67446789
{originalDirectResult, linearMap.forward(thunkSGF)}, thunkSGF.B, loc);
67456790
createReturn(thunkResult);

0 commit comments

Comments
 (0)