@@ -105,9 +105,9 @@ import SILBridging
105105
106106private let verbose = false
107107
108- private func log( _ message: @autoclosure ( ) -> String ) {
108+ private func log( prefix : Bool = true , _ message: @autoclosure ( ) -> String ) {
109109 if verbose {
110- print ( " ### \( message ( ) ) " )
110+ debugLog ( prefix : prefix , message ( ) )
111111 }
112112}
113113
@@ -128,47 +128,48 @@ let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-special
128128 }
129129
130130 var remainingSpecializationRounds = 5
131- var callerModified = false
132131
133132 repeat {
133+ // TODO: Names here are pretty misleading. We are looking for a place where
134+ // the pullback closure is created (so for `partial_apply` instruction).
134135 var callSites = gatherCallSites ( in: function, context)
136+ guard !callSites. isEmpty else {
137+ return
138+ }
135139
136- if !callSites. isEmpty {
137- for callSite in callSites {
138- var ( specializedFunction, alreadyExists) = getOrCreateSpecializedFunction ( basedOn: callSite, context)
139-
140- if !alreadyExists {
141- context. notifyNewFunction ( function: specializedFunction, derivedFrom: callSite. applyCallee)
142- }
140+ for callSite in callSites {
141+ var ( specializedFunction, alreadyExists) = getOrCreateSpecializedFunction ( basedOn: callSite, context)
143142
144- rewriteApplyInstruction ( using: specializedFunction, callSite: callSite, context)
143+ if !alreadyExists {
144+ context. notifyNewFunction ( function: specializedFunction, derivedFrom: callSite. applyCallee)
145145 }
146146
147- var deadClosures : InstructionWorklist = callSites. reduce ( into: InstructionWorklist ( context) ) { deadClosures, callSite in
148- callSite. closureArgDescriptors
149- . map { $0. closure }
150- . forEach { deadClosures. pushIfNotVisited ( $0) }
151- }
147+ rewriteApplyInstruction ( using: specializedFunction, callSite: callSite, context)
148+ }
152149
153- defer {
154- deadClosures. deinitialize ( )
155- }
150+ var deadClosures : InstructionWorklist = callSites. reduce ( into: InstructionWorklist ( context) ) { deadClosures, callSite in
151+ callSite. closureArgDescriptors
152+ . map { $0. closure }
153+ . forEach { deadClosures. pushIfNotVisited ( $0) }
154+ }
156155
157- while let deadClosure = deadClosures. pop ( ) {
158- let isDeleted = context. tryDeleteDeadClosure ( closure: deadClosure as! SingleValueInstruction )
159- if isDeleted {
160- context. notifyInvalidatedStackNesting ( )
161- }
162- }
156+ defer {
157+ deadClosures. deinitialize ( )
158+ }
163159
164- if context. needFixStackNesting {
165- function. fixStackNesting ( context)
160+ while let deadClosure = deadClosures. pop ( ) {
161+ let isDeleted = context. tryDeleteDeadClosure ( closure: deadClosure as! SingleValueInstruction )
162+ if isDeleted {
163+ context. notifyInvalidatedStackNesting ( )
166164 }
167165 }
168166
169- callerModified = callSites. count > 0
167+ if context. needFixStackNesting {
168+ function. fixStackNesting ( context)
169+ }
170+
170171 remainingSpecializationRounds -= 1
171- } while callerModified && remainingSpecializationRounds > 0
172+ } while remainingSpecializationRounds > 0
172173}
173174
174175// =========== Top-level functions ========== //
@@ -503,12 +504,6 @@ private func handleApplies(for rootClosure: SingleValueInstruction, callSiteMap:
503504 continue
504505 }
505506
506- // Workaround for a problem with OSSA: https://github.com/swiftlang/swift/issues/78847
507- // TODO: remove this if-statement once the underlying problem is fixed.
508- if callee. hasOwnership {
509- continue
510- }
511-
512507 if callee. isDefinedExternally {
513508 continue
514509 }
@@ -779,13 +774,13 @@ private extension SpecializationCloner {
779774
780775 let clonedRootClosure = builder. cloneRootClosure ( representedBy: closureArgDesc, capturedArguments: clonedClosureArgs)
781776
782- let ( finalClonedReabstractedClosure, releasableClonedReabstractedClosures ) =
777+ let finalClonedReabstractedClosure =
783778 builder. cloneRootClosureReabstractions ( rootClosure: closureArgDesc. closure, clonedRootClosure: clonedRootClosure,
784779 reabstractedClosure: callSite. appliedArgForClosure ( at: closureArgDesc. closureArgIndex) !,
785780 origToClonedValueMap: origToClonedValueMap,
786781 self . context)
787782
788- let allClonedReleasableClosures = [ clonedRootClosure ] + releasableClonedReabstractedClosures
783+ let allClonedReleasableClosures = [ finalClonedReabstractedClosure ] ;
789784 return ( finalClonedReabstractedClosure, allClonedReleasableClosures)
790785 }
791786
@@ -935,10 +930,9 @@ private extension Builder {
935930
936931 func cloneRootClosureReabstractions( rootClosure: Value , clonedRootClosure: Value , reabstractedClosure: Value ,
937932 origToClonedValueMap: [ HashableValue : Value ] , _ context: FunctionPassContext )
938- -> ( finalClonedReabstractedClosure : SingleValueInstruction , releasableClonedReabstractedClosures : [ PartialApplyInst ] )
933+ -> SingleValueInstruction
939934 {
940935 func inner( _ rootClosure: Value , _ clonedRootClosure: Value , _ reabstractedClosure: Value ,
941- _ releasableClonedReabstractedClosures: inout [ PartialApplyInst ] ,
942936 _ origToClonedValueMap: inout [ HashableValue : Value ] ) -> Value {
943937 switch reabstractedClosure {
944938 case let reabstractedClosure where reabstractedClosure == rootClosure:
@@ -947,23 +941,23 @@ private extension Builder {
947941
948942 case let cvt as ConvertFunctionInst :
949943 let toBeReabstracted = inner ( rootClosure, clonedRootClosure, cvt. fromFunction,
950- & releasableClonedReabstractedClosures , & origToClonedValueMap)
944+ & origToClonedValueMap)
951945 let reabstracted = self . createConvertFunction ( originalFunction: toBeReabstracted, resultType: cvt. type,
952946 withoutActuallyEscaping: cvt. withoutActuallyEscaping)
953947 origToClonedValueMap [ cvt] = reabstracted
954948 return reabstracted
955949
956950 case let cvt as ConvertEscapeToNoEscapeInst :
957951 let toBeReabstracted = inner ( rootClosure, clonedRootClosure, cvt. fromFunction,
958- & releasableClonedReabstractedClosures , & origToClonedValueMap)
952+ & origToClonedValueMap)
959953 let reabstracted = self . createConvertEscapeToNoEscape ( originalFunction: toBeReabstracted, resultType: cvt. type,
960954 isLifetimeGuaranteed: true )
961955 origToClonedValueMap [ cvt] = reabstracted
962956 return reabstracted
963957
964958 case let pai as PartialApplyInst :
965959 let toBeReabstracted = inner ( rootClosure, clonedRootClosure, pai. arguments [ 0 ] ,
966- & releasableClonedReabstractedClosures , & origToClonedValueMap)
960+ & origToClonedValueMap)
967961
968962 guard let function = pai. referencedFunction else {
969963 log ( " Parent function of callSite: \( rootClosure. parentFunction) " )
@@ -978,13 +972,11 @@ private extension Builder {
978972 calleeConvention: pai. calleeConvention,
979973 hasUnknownResultIsolation: pai. hasUnknownResultIsolation,
980974 isOnStack: pai. isOnStack)
981- releasableClonedReabstractedClosures. append ( reabstracted)
982975 origToClonedValueMap [ pai] = reabstracted
983976 return reabstracted
984977
985978 case let mdi as MarkDependenceInst :
986- let toBeReabstracted = inner ( rootClosure, clonedRootClosure, mdi. value, & releasableClonedReabstractedClosures,
987- & origToClonedValueMap)
979+ let toBeReabstracted = inner ( rootClosure, clonedRootClosure, mdi. value, & origToClonedValueMap)
988980 let base = origToClonedValueMap [ mdi. base] !
989981 let reabstracted = self . createMarkDependence ( value: toBeReabstracted, base: base, kind: . Escaping)
990982 origToClonedValueMap [ mdi] = reabstracted
@@ -998,11 +990,10 @@ private extension Builder {
998990 }
999991 }
1000992
1001- var releasableClonedReabstractedClosures : [ PartialApplyInst ] = [ ]
1002993 var origToClonedValueMap = origToClonedValueMap
1003994 let finalClonedReabstractedClosure = inner ( rootClosure, clonedRootClosure, reabstractedClosure,
1004- & releasableClonedReabstractedClosures , & origToClonedValueMap)
1005- return ( finalClonedReabstractedClosure as! SingleValueInstruction , releasableClonedReabstractedClosures )
995+ & origToClonedValueMap)
996+ return ( finalClonedReabstractedClosure as! SingleValueInstruction )
1006997 }
1007998
1008999 func destroyPartialApply( pai: PartialApplyInst , _ context: FunctionPassContext ) {
0 commit comments