@@ -765,14 +765,15 @@ class PullbackCloner::Implementation final
765765 SILValue wrappedAdjoint,
766766 SILType optionalTy);
767767
768- // / Accumulate optional buffer from `wrappedAdjoint`.
768+ // / Accumulate adjoint of `wrappedAdjoint` into optionalBuffer .
769769 void accumulateAdjointForOptionalBuffer (SILBasicBlock *bb,
770770 SILValue optionalBuffer,
771771 SILValue wrappedAdjoint);
772772
773- // / Set optional value from `wrappedAdjoint`.
774- void setAdjointValueForOptional (SILBasicBlock *bb, SILValue optionalValue,
775- SILValue wrappedAdjoint);
773+ // / Accumulate adjoint of `wrappedAdjoint` into optionalValue.
774+ void accumulateAdjointValueForOptional (SILBasicBlock *bb,
775+ SILValue optionalValue,
776+ SILValue wrappedAdjoint);
776777
777778 // --------------------------------------------------------------------------//
778779 // Array literal initialization differentiation
@@ -2732,8 +2733,8 @@ void PullbackCloner::Implementation::accumulateAdjointForOptionalBuffer(
27322733 builder.createDeallocStack (pbLoc, optTanAdjBuf);
27332734}
27342735
2735- // Set the adjoint value for the incoming `Optional` value.
2736- void PullbackCloner::Implementation::setAdjointValueForOptional (
2736+ // Accumulate adjoint for the incoming `Optional` value.
2737+ void PullbackCloner::Implementation::accumulateAdjointValueForOptional (
27372738 SILBasicBlock *bb, SILValue optionalValue, SILValue wrappedAdjoint) {
27382739 assert (getTangentValueCategory (optionalValue) == SILValueCategory::Object);
27392740 auto pbLoc = getPullback ().getLocation ();
@@ -2745,10 +2746,11 @@ void PullbackCloner::Implementation::setAdjointValueForOptional(
27452746
27462747 auto optTanAdjVal = builder.emitLoadValueOperation (
27472748 pbLoc, optTanAdjBuf, LoadOwnershipQualifier::Take);
2749+
27482750 recordTemporary (optTanAdjVal);
27492751 builder.createDeallocStack (pbLoc, optTanAdjBuf);
27502752
2751- setAdjointValue (bb, optionalValue, makeConcreteAdjointValue (optTanAdjVal));
2753+ addAdjointValue (bb, optionalValue, makeConcreteAdjointValue (optTanAdjVal), pbLoc );
27522754}
27532755
27542756SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor (
@@ -2959,12 +2961,12 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
29592961 // Handle `switch_enum` on `Optional`.
29602962 auto termInst = bbArg->getSingleTerminator ();
29612963 if (isSwitchEnumInstOnOptional (termInst)) {
2962- setAdjointValueForOptional (bb, incomingValue, concreteBBArgAdjCopy);
2964+ accumulateAdjointValueForOptional (bb, incomingValue, concreteBBArgAdjCopy);
29632965 } else {
29642966 blockTemporaries[getPullbackBlock (predBB)].insert (
29652967 concreteBBArgAdjCopy);
2966- setAdjointValue (predBB, incomingValue,
2967- makeConcreteAdjointValue (concreteBBArgAdjCopy));
2968+ addAdjointValue (predBB, incomingValue,
2969+ makeConcreteAdjointValue (concreteBBArgAdjCopy), pbLoc );
29682970 }
29692971 }
29702972 break ;
0 commit comments