@@ -549,7 +549,7 @@ class PullbackCloner::Implementation final
549549 if (auto adjProj = getAdjointProjection (origBB, originalValue))
550550 return (bufferMap[{origBB, originalValue}] = adjProj);
551551
552- LLVM_DEBUG (getADDebugStream () << " Creating new adjoint buffer for"
552+ LLVM_DEBUG (getADDebugStream () << " Creating new adjoint buffer for "
553553 << originalValue
554554 << " in bb" << origBB->getDebugID () << ' \n ' );
555555
@@ -589,7 +589,8 @@ class PullbackCloner::Implementation final
589589 auto adjointBuffer = getAdjointBuffer (origBB, originalValue);
590590
591591 LLVM_DEBUG (getADDebugStream () << " Adding"
592- << rhsAddress << " to adjoint of "
592+ << rhsAddress << " to adjoint ("
593+ << adjointBuffer << " ) of "
593594 << originalValue
594595 << " in bb" << origBB->getDebugID () << ' \n ' );
595596
@@ -811,7 +812,8 @@ class PullbackCloner::Implementation final
811812#endif
812813 SILInstructionVisitor::visit (inst);
813814 LLVM_DEBUG ({
814- auto &s = llvm::dbgs () << " [ADJ] Emitted in pullback:\n " ;
815+ auto &s = llvm::dbgs () << " [ADJ] Emitted in pullback (pb bb" <<
816+ builder.getInsertionBB ()->getDebugID () << " ):\n " ;
815817 auto afterInsertion = builder.getInsertionPoint ();
816818 for (auto it = ++beforeInsertion; it != afterInsertion; ++it)
817819 s << *it;
@@ -1645,7 +1647,7 @@ class PullbackCloner::Implementation final
16451647 void
16461648 visitUncheckedTakeEnumDataAddrInst (UncheckedTakeEnumDataAddrInst *utedai) {
16471649 auto *bb = utedai->getParent ();
1648- auto adjBuf = getAdjointBuffer (bb, utedai);
1650+ auto adjDest = getAdjointBuffer (bb, utedai);
16491651 auto enumTy = utedai->getOperand ()->getType ();
16501652 auto *optionalEnumDecl = getASTContext ().getOptionalDecl ();
16511653 // Only `Optional`-typed operands are supported for now. Diagnose all other
@@ -1659,7 +1661,8 @@ class PullbackCloner::Implementation final
16591661 errorOccurred = true ;
16601662 return ;
16611663 }
1662- accumulateAdjointForOptional (bb, utedai->getOperand (), adjBuf);
1664+ accumulateAdjointForOptional (bb, utedai->getOperand (), adjDest);
1665+ builder.emitZeroIntoBuffer (utedai->getLoc (), adjDest, IsNotInitialization);
16631666 }
16641667
16651668#define NOT_DIFFERENTIABLE (INST, DIAG ) void visit##INST##Inst(INST##Inst *inst);
@@ -2473,6 +2476,10 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
24732476 for (auto *bbArg : bb->getArguments ()) {
24742477 if (!getActivityInfo ().isActive (bbArg, getConfig ()))
24752478 continue ;
2479+ LLVM_DEBUG (getADDebugStream () << " Propagating adjoint value for active bb"
2480+ << bb->getDebugID () << " argument: "
2481+ << *bbArg);
2482+
24762483 // Get predecessor terminator operands.
24772484 SmallVector<std::pair<SILBasicBlock *, SILValue>, 4 > incomingValues;
24782485 bbArg->getSingleTerminatorOperands (incomingValues);
0 commit comments