@@ -2996,14 +2996,15 @@ namespace {
29962996 CanSILFunctionType loweredType, SubstitutionMap subs) {
29972997 LoweredInfos = loweredType->getUnsubstitutedType (SGM.M )->getYields ();
29982998
2999- auto accessor = cast<AccessorDecl >(function.getDecl ());
3000- auto storage = accessor-> getStorage ();
2999+ auto origFd = cast<FuncDecl >(function.getDecl ());
3000+ auto sig = origFd-> getGenericSignatureOfContext (). getCanonicalSignature ();
30013001
3002- OrigTypes.push_back (
3003- SGM.Types .getAbstractionPattern (storage, /* nonobjc*/ true ));
3002+ auto origYieldType = origFd->getYieldsInterfaceType ()->castTo <YieldResultType>();
3003+ auto reducedYieldType = sig.getReducedType (origYieldType->getResultType ());
3004+ OrigTypes.emplace_back (sig, reducedYieldType);
30043005
30053006 SmallVector<AnyFunctionType::Yield, 1 > yieldsBuffer;
3006- auto yields = AnyFunctionRef (accessor ).getYieldResults (yieldsBuffer);
3007+ auto yields = AnyFunctionRef (origFd ).getYieldResults (yieldsBuffer);
30073008 assert (yields.size () == 1 );
30083009 Yields.push_back (yields[0 ].getCanonical ().subst (subs).asParam ());
30093010 }
@@ -6655,10 +6656,6 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
66556656 SILType substFnType = fnRef->getType ().substGenericArgs (
66566657 M, subs, thunk->getTypeExpansionContext ());
66576658
6658- // Apply function argument.
6659- auto apply =
6660- thunkSGF.emitApplyWithRethrow (loc, fnRef, substFnType, subs, arguments);
6661-
66626659 // Self reordering thunk is necessary if wrt at least two parameters,
66636660 // including self.
66646661 auto shouldReorderSelf = [&]() {
@@ -6695,18 +6692,66 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
66956692 thunkSGF.B .createReturn (loc, retValue);
66966693 };
66976694
6698- if (!reorderSelf && linearMapFnType == targetLinearMapFnType) {
6699- SmallVector<SILValue, 8 > results;
6695+ SmallVector<SILValue, 8 > results;
6696+ if (customDerivativeFnTy->isCoroutine ()) {
6697+ assert (kind == AutoDiffDerivativeFunctionKind::VJP &&
6698+ " only support VJP custom coroutine derivatives" );
6699+
6700+ SmallVector<SILValue, 1 > yields;
6701+ // Start inner coroutine execution till the suspend point
6702+ auto tokenAndCleanups = thunkSGF.emitBeginApplyWithRethrow (
6703+ loc, fnRef, substFnType /* fnRef->getType()*/ ,
6704+ subs, arguments, yields);
6705+ auto token = std::get<0 >(tokenAndCleanups);
6706+ auto abortCleanup = std::get<1 >(tokenAndCleanups);
6707+ auto allocation = std::get<2 >(tokenAndCleanups);
6708+ auto deallocCleanup = std::get<3 >(tokenAndCleanups);
6709+
6710+ // Forward yields
6711+ auto *customDerivativeAFD =
6712+ cast<AbstractFunctionDecl>(customDerivativeFn->getDeclContext ()->getAsDecl ());
6713+ auto thunkTy = thunkSGF.F .getLoweredFunctionType ();
6714+ YieldInfo innerYieldInfo (*this , SILDeclRef (customDerivativeAFD), fnRefType,
6715+ subs);
6716+ // FIXME: We do not have Decl for the thunk as it is generated entirely at SIL level.
6717+ // Fix the yield info in case when reabstraction of yields would be required
6718+ YieldInfo outerYieldInfo (*this , SILDeclRef (customDerivativeAFD), thunkTy,
6719+ thunk->getForwardingSubstitutionMap ());
6720+ translateYields (thunkSGF, loc, yields, innerYieldInfo, outerYieldInfo);
6721+
6722+ // Kill the normal abort cleanup without emitting it. translateYields() will
6723+ // produce proper abort_apply & cleanups for inner coroutine call in the
6724+ // unwind block, and for normal return we're doing it manually below
6725+ thunkSGF.Cleanups .setCleanupState (abortCleanup, CleanupState::Dead);
6726+ if (allocation) {
6727+ thunkSGF.Cleanups .setCleanupState (deallocCleanup, CleanupState::Dead);
6728+ }
6729+
6730+ // End the inner coroutine normally.
6731+ auto resultTy =
6732+ thunk->mapTypeIntoContext (
6733+ fnRefType->getAllResultsSubstType (M,
6734+ thunkSGF.getTypeExpansionContext ()));
6735+ auto endApply =
6736+ thunkSGF.emitEndApplyWithRethrow (loc, token, allocation, resultTy);
6737+
6738+ extractAllElements (endApply, loc, thunkSGF.B , results);
6739+ } else {
6740+ // Apply function argument.
6741+ auto apply =
6742+ thunkSGF.emitApplyWithRethrow (loc, fnRef, substFnType, subs, arguments);
6743+
67006744 extractAllElements (apply, loc, thunkSGF.B , results);
6701- auto result = joinElements (results, thunkSGF.B , apply.getLoc ());
6745+ }
6746+
6747+ if (!reorderSelf && linearMapFnType == targetLinearMapFnType) {
6748+ auto result = joinElements (results, thunkSGF.B , loc);
67026749 createReturn (result);
67036750 return thunk;
67046751 }
67056752
67066753 // Otherwise, apply reabstraction/self reordering thunk to linear map.
6707- SmallVector<SILValue, 8 > directResults;
6708- extractAllElements (apply, loc, thunkSGF.B , directResults);
6709- auto linearMap = thunkSGF.emitManagedRValueWithCleanup (directResults.back ());
6754+ auto linearMap = thunkSGF.emitManagedRValueWithCleanup (results.back ());
67106755 assert (linearMap.getType ().castTo <SILFunctionType>() == linearMapFnType);
67116756 auto linearMapKind = kind.getLinearMapKind ();
67126757 linearMap = thunkSGF.getThunkedAutoDiffLinearMap (
@@ -6736,10 +6781,10 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
67366781 }
67376782
67386783 // Return original results and thunked differential/pullback.
6739- if (directResults .size () > 1 ) {
6740- auto originalDirectResults = ArrayRef<SILValue>(directResults ).drop_back (1 );
6784+ if (results .size () > 1 ) {
6785+ auto originalDirectResults = ArrayRef<SILValue>(results ).drop_back (1 );
67416786 auto originalDirectResult =
6742- joinElements (originalDirectResults, thunkSGF.B , apply. getLoc () );
6787+ joinElements (originalDirectResults, thunkSGF.B , loc );
67436788 auto thunkResult = joinElements (
67446789 {originalDirectResult, linearMap.forward (thunkSGF)}, thunkSGF.B , loc);
67456790 createReturn (thunkResult);
0 commit comments