@@ -1736,13 +1736,21 @@ class PullbackCloner::Implementation final
17361736 // / move_value, begin_borrow.
17371737 // / Original: y = copy_value x
17381738 // / Adjoint: adj[x] += adj[y]
1739- void visitValueOwnershipInst (SingleValueInstruction *svi) {
1739+ void visitValueOwnershipInst (SingleValueInstruction *svi,
1740+ bool needZeroResAdj = false ) {
17401741 assert (svi->getNumOperands () == 1 );
17411742 auto *bb = svi->getParent ();
17421743 switch (getTangentValueCategory (svi)) {
17431744 case SILValueCategory::Object: {
17441745 auto adj = getAdjointValue (bb, svi);
17451746 addAdjointValue (bb, svi->getOperand (0 ), adj, svi->getLoc ());
1747+ if (needZeroResAdj) {
1748+ assert (svi->getNumResults () == 1 );
1749+ SILValue val = svi->getResult (0 );
1750+ setAdjointValue (
1751+ bb, val,
1752+ makeZeroAdjointValue (getRemappedTangentType (val->getType ())));
1753+ }
17461754 break ;
17471755 }
17481756 case SILValueCategory::Address: {
@@ -1768,8 +1776,16 @@ class PullbackCloner::Implementation final
17681776
17691777 // / Handle `move_value` instruction.
17701778 // / Original: y = move_value x
1771- // / Adjoint: adj[x] += adj[y]
1772- void visitMoveValueInst (MoveValueInst *mvi) { visitValueOwnershipInst (mvi); }
1779+ // / Adjoint: adj[x] += adj[y]; adj[y] = 0
1780+ void visitMoveValueInst (MoveValueInst *mvi) {
1781+ switch (getTangentValueCategory (mvi)) {
1782+ case SILValueCategory::Address:
1783+ llvm::report_fatal_error (" AutoDiff does not support move_value with "
1784+ " SILValueCategory::Address" );
1785+ case SILValueCategory::Object:
1786+ visitValueOwnershipInst (mvi, /* needZeroResAdj=*/ true );
1787+ }
1788+ }
17731789
17741790 void visitEndInitLetRefInst (EndInitLetRefInst *eir) { visitValueOwnershipInst (eir); }
17751791
0 commit comments