@@ -2819,14 +2819,15 @@ namespace {
28192819 CanSILFunctionType loweredType, SubstitutionMap subs) {
28202820 LoweredInfos = loweredType->getUnsubstitutedType (SGM.M )->getYields ();
28212821
2822- auto accessor = cast<AccessorDecl >(function.getDecl ());
2823- auto storage = accessor-> getStorage ();
2822+ auto origFd = cast<FuncDecl >(function.getDecl ());
2823+ auto sig = origFd-> getGenericSignatureOfContext (). getCanonicalSignature ();
28242824
2825- OrigTypes.push_back (
2826- SGM.Types .getAbstractionPattern (storage, /* nonobjc*/ true ));
2825+ auto origYieldType = origFd->getYieldsInterfaceType ()->castTo <YieldResultType>();
2826+ auto reducedYieldType = sig.getReducedType (origYieldType->getResultType ());
2827+ OrigTypes.emplace_back (sig, reducedYieldType);
28272828
28282829 SmallVector<AnyFunctionType::Yield, 1 > yieldsBuffer;
2829- auto yields = AnyFunctionRef (accessor ).getYieldResults (yieldsBuffer);
2830+ auto yields = AnyFunctionRef (origFd ).getYieldResults (yieldsBuffer);
28302831 assert (yields.size () == 1 );
28312832 Yields.push_back (yields[0 ].getCanonical ().subst (subs).asParam ());
28322833 }
@@ -6359,10 +6360,6 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
63596360 SILType substFnType = fnRef->getType ().substGenericArgs (
63606361 M, subs, thunk->getTypeExpansionContext ());
63616362
6362- // Apply function argument.
6363- auto apply =
6364- thunkSGF.emitApplyWithRethrow (loc, fnRef, substFnType, subs, arguments);
6365-
63666363 // Self reordering thunk is necessary if wrt at least two parameters,
63676364 // including self.
63686365 auto shouldReorderSelf = [&]() {
@@ -6399,18 +6396,66 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
63996396 thunkSGF.B .createReturn (loc, retValue);
64006397 };
64016398
6402- if (!reorderSelf && linearMapFnType == targetLinearMapFnType) {
6403- SmallVector<SILValue, 8 > results;
6399+ SmallVector<SILValue, 8 > results;
6400+ if (customDerivativeFnTy->isCoroutine ()) {
6401+ assert (kind == AutoDiffDerivativeFunctionKind::VJP &&
6402+ " only support VJP custom coroutine derivatives" );
6403+
6404+ SmallVector<SILValue, 1 > yields;
6405+ // Start inner coroutine execution till the suspend point
6406+ auto tokenAndCleanups = thunkSGF.emitBeginApplyWithRethrow (
6407+ loc, fnRef, substFnType /* fnRef->getType()*/ ,
6408+ subs, arguments, yields);
6409+ auto token = std::get<0 >(tokenAndCleanups);
6410+ auto abortCleanup = std::get<1 >(tokenAndCleanups);
6411+ auto allocation = std::get<2 >(tokenAndCleanups);
6412+ auto deallocCleanup = std::get<3 >(tokenAndCleanups);
6413+
6414+ // Forward yields
6415+ auto *customDerivativeAFD =
6416+ cast<AbstractFunctionDecl>(customDerivativeFn->getDeclContext ()->getAsDecl ());
6417+ auto thunkTy = thunkSGF.F .getLoweredFunctionType ();
6418+ YieldInfo innerYieldInfo (*this , SILDeclRef (customDerivativeAFD), fnRefType,
6419+ subs);
6420+ // FIXME: We do not have Decl for the thunk as it is generated entirely at SIL level.
6421+ // Fix the yield info in case when reabstraction of yields would be required
6422+ YieldInfo outerYieldInfo (*this , SILDeclRef (customDerivativeAFD), thunkTy,
6423+ thunk->getForwardingSubstitutionMap ());
6424+ translateYields (thunkSGF, loc, yields, innerYieldInfo, outerYieldInfo);
6425+
6426+ // Kill the normal abort cleanup without emitting it. translateYields() will
6427+ // produce proper abort_apply & cleanups for inner coroutine call in the
6428+ // unwind block, and for normal return we're doing it manually below
6429+ thunkSGF.Cleanups .setCleanupState (abortCleanup, CleanupState::Dead);
6430+ if (allocation) {
6431+ thunkSGF.Cleanups .setCleanupState (deallocCleanup, CleanupState::Dead);
6432+ }
6433+
6434+ // End the inner coroutine normally.
6435+ auto resultTy =
6436+ thunk->mapTypeIntoContext (
6437+ fnRefType->getAllResultsSubstType (M,
6438+ thunkSGF.getTypeExpansionContext ()));
6439+ auto endApply =
6440+ thunkSGF.emitEndApplyWithRethrow (loc, token, allocation, resultTy);
6441+
6442+ extractAllElements (endApply, loc, thunkSGF.B , results);
6443+ } else {
6444+ // Apply function argument.
6445+ auto apply =
6446+ thunkSGF.emitApplyWithRethrow (loc, fnRef, substFnType, subs, arguments);
6447+
64046448 extractAllElements (apply, loc, thunkSGF.B , results);
6405- auto result = joinElements (results, thunkSGF.B , apply.getLoc ());
6449+ }
6450+
6451+ if (!reorderSelf && linearMapFnType == targetLinearMapFnType) {
6452+ auto result = joinElements (results, thunkSGF.B , loc);
64066453 createReturn (result);
64076454 return thunk;
64086455 }
64096456
64106457 // Otherwise, apply reabstraction/self reordering thunk to linear map.
6411- SmallVector<SILValue, 8 > directResults;
6412- extractAllElements (apply, loc, thunkSGF.B , directResults);
6413- auto linearMap = thunkSGF.emitManagedRValueWithCleanup (directResults.back ());
6458+ auto linearMap = thunkSGF.emitManagedRValueWithCleanup (results.back ());
64146459 assert (linearMap.getType ().castTo <SILFunctionType>() == linearMapFnType);
64156460 auto linearMapKind = kind.getLinearMapKind ();
64166461 linearMap = thunkSGF.getThunkedAutoDiffLinearMap (
@@ -6440,10 +6485,10 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
64406485 }
64416486
64426487 // Return original results and thunked differential/pullback.
6443- if (directResults .size () > 1 ) {
6444- auto originalDirectResults = ArrayRef<SILValue>(directResults ).drop_back (1 );
6488+ if (results .size () > 1 ) {
6489+ auto originalDirectResults = ArrayRef<SILValue>(results ).drop_back (1 );
64456490 auto originalDirectResult =
6446- joinElements (originalDirectResults, thunkSGF.B , apply. getLoc () );
6491+ joinElements (originalDirectResults, thunkSGF.B , loc );
64476492 auto thunkResult = joinElements (
64486493 {originalDirectResult, linearMap.forward (thunkSGF)}, thunkSGF.B , loc);
64496494 createReturn (thunkResult);
0 commit comments