@@ -560,7 +560,7 @@ getOrCreateSubsetParametersThunkForLinearMap(
560560 SILOptFunctionBuilder &fb, SILFunction *parentThunk,
561561 CanSILFunctionType origFnType, CanSILFunctionType linearMapType,
562562 CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
563- AutoDiffConfig desiredConfig, AutoDiffConfig actualConfig,
563+ const AutoDiffConfig & desiredConfig, const AutoDiffConfig & actualConfig,
564564 ADContext &adContext) {
565565 LLVM_DEBUG (getADDebugStream ()
566566 << " Getting a subset parameters thunk for " << linearMapType
@@ -592,8 +592,6 @@ getOrCreateSubsetParametersThunkForLinearMap(
592592 if (!thunk->empty ())
593593 return {thunk, interfaceSubs};
594594
595- // TODO(TF-1206): Enable ownership in all differentiation thunks.
596- thunk->setOwnershipEliminated ();
597595 thunk->setGenericEnvironment (genericEnv);
598596 auto *entry = thunk->createBasicBlock ();
599597 TangentBuilder builder (entry, adContext);
@@ -602,6 +600,14 @@ getOrCreateSubsetParametersThunkForLinearMap(
602600 // Get arguments.
603601 SmallVector<SILValue, 4 > arguments;
604602 SmallVector<AllocStackInst *, 4 > localAllocations;
603+ SmallVector<SILValue, 4 > valuesToCleanup;
604+ auto cleanupValues = [&]() {
605+ for (auto value : llvm::reverse (valuesToCleanup))
606+ builder.emitDestroyOperation (loc, value);
607+
608+ for (auto *alloc : llvm::reverse (localAllocations))
609+ builder.createDeallocStack (loc, alloc);
610+ };
605611
606612 // Build a `.zero` argument for the given `Differentiable`-conforming type.
607613 auto buildZeroArgument = [&](SILType zeroSILType) {
@@ -617,10 +623,12 @@ getOrCreateSubsetParametersThunkForLinearMap(
617623 localAllocations.push_back (buf);
618624 builder.emitZeroIntoBuffer (loc, buf, IsInitialization);
619625 if (zeroSILType.isAddress ()) {
626+ valuesToCleanup.push_back (buf);
620627 arguments.push_back (buf);
621628 } else {
622629 auto arg = builder.emitLoadValueOperation (loc, buf,
623630 LoadOwnershipQualifier::Take);
631+ valuesToCleanup.push_back (arg);
624632 arguments.push_back (arg);
625633 }
626634 break ;
@@ -739,8 +747,7 @@ getOrCreateSubsetParametersThunkForLinearMap(
739747 // If differential thunk, deallocate local allocations and directly return
740748 // `apply` result.
741749 if (kind == AutoDiffDerivativeFunctionKind::JVP) {
742- for (auto *alloc : llvm::reverse (localAllocations))
743- builder.createDeallocStack (loc, alloc);
750+ cleanupValues ();
744751 builder.createReturn (loc, ai);
745752 return {thunk, interfaceSubs};
746753 }
@@ -787,8 +794,7 @@ getOrCreateSubsetParametersThunkForLinearMap(
787794 }
788795 }
789796 // Deallocate local allocations and return final direct result.
790- for (auto *alloc : llvm::reverse (localAllocations))
791- builder.createDeallocStack (loc, alloc);
797+ cleanupValues ();
792798 auto result = joinElements (results, builder, loc);
793799 builder.createReturn (loc, result);
794800
0 commit comments