@@ -905,6 +905,7 @@ class PullbackCloner::Implementation final
905905 bool runForSemanticMemberAccessor ();
906906 bool runForSemanticMemberGetter ();
907907 bool runForSemanticMemberSetter ();
908+ bool runForSemanticMemberModify ();
908909
909910 // / If original result is non-varied, it will always have a zero derivative.
910911 // / Skip full pullback generation and simply emit zero derivatives for wrt
@@ -2452,7 +2453,8 @@ bool PullbackCloner::Implementation::run() {
24522453
24532454 // If the original function is an accessor with special-case pullback
24542455 // generation logic, do special-case generation.
2455- if (isSemanticMemberAccessor (&original)) {
2456+ bool isSemanticMemberAcc = isSemanticMemberAccessor (&original);
2457+ if (isSemanticMemberAcc) {
24562458 if (runForSemanticMemberAccessor ())
24572459 return true ;
24582460 }
@@ -2730,7 +2732,8 @@ bool PullbackCloner::Implementation::run() {
27302732#endif
27312733
27322734 LLVM_DEBUG (getADDebugStream ()
2733- << " Generated pullback for " << original.getName () << " :\n "
2735+ << " Generated " << (isSemanticMemberAcc ? " semantic member accessor" : " normal" )
2736+ << " pullback for " << original.getName () << " :\n "
27342737 << pullback);
27352738 return errorOccurred;
27362739}
@@ -3205,7 +3208,8 @@ bool PullbackCloner::Implementation::runForSemanticMemberAccessor() {
32053208 return runForSemanticMemberGetter ();
32063209 case AccessorKind::Set:
32073210 return runForSemanticMemberSetter ();
3208- // TODO(https://github.com/apple/swift/issues/55084): Support `modify` accessors.
3211+ case AccessorKind::Modify:
3212+ return runForSemanticMemberModify ();
32093213 default :
32103214 llvm_unreachable (" Unsupported accessor kind; inconsistent with "
32113215 " `isSemanticMemberAccessor`?" );
@@ -3389,6 +3393,83 @@ bool PullbackCloner::Implementation::runForSemanticMemberSetter() {
33893393 return false ;
33903394}
33913395
3396+ bool PullbackCloner::Implementation::runForSemanticMemberModify () {
3397+ auto &original = getOriginal ();
3398+ auto &pullback = getPullback ();
3399+ auto pbLoc = getPullback ().getLocation ();
3400+
3401+ auto *accessor = cast<AccessorDecl>(original.getDeclContext ()->getAsDecl ());
3402+ assert (accessor->getAccessorKind () == AccessorKind::Modify);
3403+
3404+ auto *origEntry = original.getEntryBlock ();
3405+ // We assume that the accessor has a simple 3-BB structure with yield in the entry BB
3406+ // plus resume and unwind BBs
3407+ auto *yi = cast<YieldInst>(origEntry->getTerminator ());
3408+ auto *origResumeBB = yi->getResumeBB ();
3409+
3410+ auto *pbEntry = pullback.getEntryBlock ();
3411+ builder.setCurrentDebugScope (
3412+ remapScope (origEntry->getScopeOfFirstNonMetaInstruction ()));
3413+ builder.setInsertionPoint (pbEntry);
3414+
3415+ // Get _modify accessor argument values.
3416+ // Accessor type : $(inout Self) -> @yields @inout Argument
3417+ // Pullback type : $(inout Self', linear map tuple) -> @yields @inout Argument'
3418+ // Normally pullbacks for semantic member accessors are single BB and
3419+ // therefore have empty linear map tuple, however, coroutines have a branching
3420+ // control flow due to possible coroutine abort, so we need to accommodate for
3421+ // this. We keep branch tracing enums in order not to special case in many
3422+ // other places. As there is no way to return to coroutine via abort exit, we
3423+ // essentially "linearize" a coroutine.
3424+ auto loweredFnTy = original.getLoweredFunctionType ();
3425+ auto pullbackLoweredFnTy = pullback.getLoweredFunctionType ();
3426+
3427+ assert (loweredFnTy->getNumParameters () == 1 &&
3428+ loweredFnTy->getNumYields () == 1 );
3429+ assert (pullbackLoweredFnTy->getNumParameters () == 2 );
3430+ assert (pullbackLoweredFnTy->getNumYields () == 1 );
3431+
3432+ SILValue origSelf = original.getArgumentsWithoutIndirectResults ().front ();
3433+
3434+ SmallVector<SILValue, 8 > origFormalResults;
3435+ collectAllFormalResultsInTypeOrder (original, origFormalResults);
3436+
3437+ assert (getConfig ().resultIndices ->getNumIndices () == 2 &&
3438+ " Modify accessor should have two semantic results" );
3439+
3440+ auto origYield = origFormalResults[*std::next (getConfig ().resultIndices ->begin ())];
3441+
3442+ // Look up the corresponding field in the tangent space.
3443+ auto *origField = cast<VarDecl>(accessor->getStorage ());
3444+ auto baseType = remapType (origSelf->getType ()).getASTType ();
3445+ auto *tanField = getTangentStoredProperty (getContext (), origField, baseType,
3446+ pbLoc, getInvoker ());
3447+ if (!tanField) {
3448+ errorOccurred = true ;
3449+ return true ;
3450+ }
3451+
3452+ auto adjSelf = getAdjointBuffer (origResumeBB, origSelf);
3453+ auto *adjSelfElt = builder.createStructElementAddr (pbLoc, adjSelf, tanField);
3454+ // Modify accessors have inout yields and therefore should yield addresses.
3455+ assert (getTangentValueCategory (origYield) == SILValueCategory::Address &&
3456+ " Modify accessors should yield indirect" );
3457+
3458+ // Yield the adjoint buffer and do everything else in the resume
3459+ // destination. Unwind destination is unreachable as the coroutine can never
3460+ // be aborted.
3461+ auto *unwindBB = getPullback ().createBasicBlock ();
3462+ auto *resumeBB = getPullbackBlock (origEntry);
3463+ builder.createYield (yi->getLoc (), {adjSelfElt}, resumeBB, unwindBB);
3464+ builder.setInsertionPoint (unwindBB);
3465+ builder.createUnreachable (SILLocation::invalid ());
3466+
3467+ builder.setInsertionPoint (resumeBB);
3468+ addToAdjointBuffer (origEntry, origSelf, adjSelf, pbLoc);
3469+
3470+ return false ;
3471+ }
3472+
33923473// --------------------------------------------------------------------------//
33933474// Adjoint buffer mapping
33943475// --------------------------------------------------------------------------//
0 commit comments