1515//
1616// ===----------------------------------------------------------------------===//
1717
18- #include " swift/Basic/STLExtras.h"
1918#define DEBUG_TYPE " differentiation"
2019
2120#include " swift/SILOptimizer/Differentiation/PullbackCloner.h"
3130#include " swift/AST/PropertyWrappers.h"
3231#include " swift/AST/TypeCheckRequests.h"
3332#include " swift/Basic/Assertions.h"
33+ #include " swift/Basic/STLExtras.h"
3434#include " swift/SIL/ApplySite.h"
3535#include " swift/SIL/InstructionUtils.h"
3636#include " swift/SIL/Projection.h"
@@ -131,6 +131,10 @@ class PullbackCloner::Implementation final
131131 // / Stack buffers allocated for storing local adjoint values.
132132 SmallVector<AllocStackInst *, 64 > functionLocalAllocations;
133133
134+ // / Copies created to deal with destructive enum operations
135+ // / (unchecked_take_enum_addr)
136+ llvm::SmallDenseMap<InitEnumDataAddrInst*, SILValue> enumDataAdjCopies;
137+
134138 // / A set used to remember local allocations that were destroyed.
135139 llvm::SmallDenseSet<SILValue> destroyedLocalAllocations;
136140
@@ -1858,7 +1862,7 @@ class PullbackCloner::Implementation final
18581862 // / Handle a sequence of `init_enum_data_addr` and `inject_enum_addr`
18591863 // / instructions.
18601864 // /
1861- // / Original: y = init_enum_data_addr x
1865+ // / Original: x = init_enum_data_addr y : $*Enum, #Enum.Case
18621866 // / inject_enum_addr y
18631867 // /
18641868 // / Adjoint: adj[x] += unchecked_take_enum_data_addr adj[y]
@@ -1879,6 +1883,10 @@ class PullbackCloner::Implementation final
18791883 return ;
18801884 }
18811885
1886+ // No associated value => no adjoint to propagate
1887+ if (!inject->getElement ()->hasAssociatedValues ())
1888+ return ;
1889+
18821890 InitEnumDataAddrInst *origData = nullptr ;
18831891 for (auto use : origEnum->getUses ()) {
18841892 if (auto *init = dyn_cast<InitEnumDataAddrInst>(use->getUser ())) {
@@ -1900,9 +1908,9 @@ class PullbackCloner::Implementation final
19001908 }
19011909 }
19021910
1903- SILValue adjStruct = getAdjointBuffer (bb, origEnum);
1911+ SILValue adjDest = getAdjointBuffer (bb, origEnum);
19041912 StructDecl *adjStructDecl =
1905- adjStruct ->getType ().getStructOrBoundGenericStruct ();
1913+ adjDest ->getType ().getStructOrBoundGenericStruct ();
19061914
19071915 VarDecl *adjOptVar = nullptr ;
19081916 if (adjStructDecl) {
@@ -1922,35 +1930,35 @@ class PullbackCloner::Implementation final
19221930
19231931 SILLocation loc = origData->getLoc ();
19241932 StructElementAddrInst *adjOpt =
1925- builder.createStructElementAddr (loc, adjStruct , adjOptVar);
1933+ builder.createStructElementAddr (loc, adjDest , adjOptVar);
19261934
19271935 // unchecked_take_enum_data_addr is destructive, so copy
19281936 // Optional<T.TangentVector> to a new alloca.
19291937 AllocStackInst *adjOptCopy =
19301938 createFunctionLocalAllocation (adjOpt->getType (), loc);
19311939 builder.createCopyAddr (loc, adjOpt, adjOptCopy, IsNotTake,
19321940 IsInitialization);
1941+ // The Optional copy is invalidated, do not attempt to destroy it at the end
1942+ // of the pullback. The value returned from unchecked_take_enum_data_addr is
1943+ // destroyed in visitInitEnumDataAddrInst.
1944+ auto [_, inserted] = enumDataAdjCopies.try_emplace (origData, adjOptCopy);
1945+ assert (inserted && " expected single buffer" );
19331946
19341947 EnumElementDecl *someElemDecl = getASTContext ().getOptionalSomeDecl ();
19351948 UncheckedTakeEnumDataAddrInst *adjData =
19361949 builder.createUncheckedTakeEnumDataAddr (loc, adjOptCopy, someElemDecl);
19371950
1938- setAdjointBuffer (bb, origData, adjData);
1939-
1940- // The Optional copy is invalidated, do not attempt to destroy it at the end
1941- // of the pullback. The value returned from unchecked_take_enum_data_addr is
1942- // destroyed in visitInitEnumDataAddrInst.
1943- destroyedLocalAllocations.insert (adjOptCopy);
1951+ addToAdjointBuffer (bb, origData, adjData, loc);
19441952 }
19451953
19461954 // / Handle `init_enum_data_addr` instruction.
19471955 // / Destroy the value returned from `unchecked_take_enum_data_addr`.
19481956 void visitInitEnumDataAddrInst (InitEnumDataAddrInst *init) {
1949- auto bufIt = bufferMap. find ({ init-> getParent (), SILValue (init)} );
1950- if (bufIt == bufferMap. end ())
1951- return ;
1952- SILValue adjData = bufIt-> second ;
1953- builder. emitDestroyAddr (init-> getLoc (), adjData );
1957+ SILValue adjOptCopy = enumDataAdjCopies. at ( init);
1958+
1959+ builder. emitDestroyAddr (init-> getLoc (), adjOptCopy) ;
1960+ destroyedLocalAllocations. insert (adjOptCopy) ;
1961+ enumDataAdjCopies. erase (init);
19541962 }
19551963
19561964 // / Handle `unchecked_ref_cast` instruction.
@@ -2567,6 +2575,12 @@ bool PullbackCloner::Implementation::run() {
25672575 }
25682576 }
25692577 }
2578+ // Ensure all enum adjoint copeis have been cleaned up
2579+ for (const auto &enumData : enumDataAdjCopies) {
2580+ leakFound = true ;
2581+ getADDebugStream () << " Found leaked temporary:\n " << enumData.second ;
2582+ }
2583+
25702584 // Ensure all local allocations have been cleaned up.
25712585 for (auto localAlloc : functionLocalAllocations) {
25722586 if (!destroyedLocalAllocations.count (localAlloc)) {
0 commit comments