Skip to content

Commit df20d36

Browse files
committed
ClosureSpecialization: support for OSSA and a big overhaul
Beside supporting OSSA, this change significantly simplifies the pass. The main change is that instead of starting at a closure (e.g. `partial_apply`) and finding all call sites, we now start at a call site and look for closures for all arguments. This makes a lot of things much simpler, e.g. not so many intermediate data structures are required to track all the states. I needed to remove the 3 unit tests because the things those tests were testing are not there anymore. However, the pass is tested with a lot of sil tests (and I added quite a few), which should give good test coverage. The old ClosureSpecializer pass is still kept in place, because at that point in the pipeline we don't have OSSA, yet. Once we have that, we can replace the old pass withe the new one. However, the autodiff closure specializer already runs in the OSSA pipeline and there the new changes take effect.
1 parent 89bba66 commit df20d36

File tree

23 files changed

+1690
-2787
lines changed

23 files changed

+1690
-2787
lines changed

SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift

Lines changed: 512 additions & 1196 deletions
Large diffs are not rendered by default.

SwiftCompilerSources/Sources/Optimizer/PassManager/PassRegistration.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ private func registerSwiftPasses() {
105105
registerPass(tempRValueElimination, { tempRValueElimination.run($0) })
106106
registerPass(mandatoryTempRValueElimination, { mandatoryTempRValueElimination.run($0) })
107107
registerPass(tempLValueElimination, { tempLValueElimination.run($0) })
108-
registerPass(generalClosureSpecialization, { generalClosureSpecialization.run($0) })
108+
registerPass(closureSpecialization, { closureSpecialization.run($0) })
109109
registerPass(autodiffClosureSpecialization, { autodiffClosureSpecialization.run($0) })
110110
registerPass(loopInvariantCodeMotionPass, { loopInvariantCodeMotionPass.run($0) })
111111

SwiftCompilerSources/Sources/Optimizer/Utilities/FunctionTest.swift

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ public func registerOptimizerTests() {
4343
registerFunctionTests(
4444
addressOwnershipLiveRangeTest,
4545
argumentConventionsTest,
46-
getPullbackClosureInfoTest,
4746
interiorLivenessTest,
4847
lifetimeDependenceRootTest,
4948
lifetimeDependenceScopeTest,
@@ -52,8 +51,6 @@ public func registerOptimizerTests() {
5251
localVariableReachableUsesTest,
5352
localVariableReachingAssignmentsTest,
5453
rangeOverlapsPathTest,
55-
rewrittenCallerBodyTest,
56-
specializedFunctionSignatureAndBodyTest,
5754
variableIntroducerTest
5855
)
5956

include/swift/SILOptimizer/PassManager/Passes.def

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,8 @@ PASS(TempLValueElimination, "temp-lvalue-elimination",
153153
PASS(LoopInvariantCodeMotion, "loop-invariant-code-motion",
154154
"New Loop Invariant Code Motion")
155155

156-
// NOTE - ExperimentalSwiftBasedClosureSpecialization and AutodiffClosureSpecialization are a WIP
157-
PASS(ExperimentalSwiftBasedClosureSpecialization, "experimental-swift-based-closure-specialization",
158-
"General closure-specialization pass written in Swift")
156+
PASS(ClosureSpecialization, "closure-specialization",
157+
"Specialize functions with closure arguments")
159158
PASS(AutodiffClosureSpecialization, "autodiff-closure-specialization",
160159
"Autodiff specific closure-specialization pass")
161160

lib/SILOptimizer/PassManager/PassPipeline.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,10 +1034,10 @@ SILPassPipelinePlan::getPerformancePassPipeline(const SILOptions &Options) {
10341034
if (SILPrintFinalOSSAModule) {
10351035
addModulePrinterPipeline(P, "SIL Print Final OSSA Module");
10361036
}
1037-
P.addOwnershipModelEliminator();
1038-
10391037
P.addAutodiffClosureSpecialization();
10401038

1039+
P.addOwnershipModelEliminator();
1040+
10411041
// After serialization run the function pass pipeline to iteratively lower
10421042
// high-level constructs like @_semantics calls.
10431043
addMidLevelFunctionPipeline(P);

test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_bte.sil

Lines changed: 65 additions & 161 deletions
Large diffs are not rendered by default.

test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_no_bte1.sil

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
/// Multi basic block VJP, pullback not accepting branch tracing enum argument.
22

3-
// RUN: %target-sil-opt -sil-print-types -test-runner %s -o /dev/null 2>&1 | %FileCheck %s --check-prefixes=TRUNNER,CHECK
43
// RUN: %target-sil-opt -sil-print-types -autodiff-closure-specialization -sil-combine %s -o - | %FileCheck %s --check-prefixes=COMBINE,CHECK
54

65
// REQUIRES: swift_in_compiler
@@ -82,51 +81,30 @@ sil @$s4test5ClassV6stored8optionalACSf_SfSgtcfCTJpSSUpSr : $@convention(thin) (
8281
sil @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
8382

8483
// pullback of methodWrapper(_:)
85-
sil private [signature_optimized_thunk] [heuristic_always_inline] @$s4test13methodWrapperySfAA5ClassVFTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector {
86-
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Class.TangentVector):
84+
sil private [signature_optimized_thunk] [heuristic_always_inline] [ossa] @$s4test13methodWrapperySfAA5ClassVFTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector {
85+
bb0(%0 : $Float, %1 : @owned $@callee_guaranteed (Float) -> Class.TangentVector):
8786
%2 = apply %1(%0) : $@callee_guaranteed (Float) -> Class.TangentVector
88-
strong_release %1
87+
destroy_value %1
8988
return %2
9089
} // end sil function '$s4test13methodWrapperySfAA5ClassVFTJpSpSr'
9190

9291
// reverse-mode derivative of methodWrapper(_:)
93-
sil hidden @$s4test13methodWrapperySfAA5ClassVFTJrSpSr : $@convention(thin) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) {
92+
sil hidden [ossa] @$s4test13methodWrapperySfAA5ClassVFTJrSpSr : $@convention(thin) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) {
9493
bb0(%0 : $Class):
95-
//=========== Test callsite and closure gathering logic ===========//
96-
specify_test "autodiff_closure_specialize_get_pullback_closure_info"
97-
// TRUNNER-LABEL: Specializing closures in function: $s4test13methodWrapperySfAA5ClassVFTJrSpSr
98-
// TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#A36:]]) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector
99-
// TRUNNER-NEXT: Passed in closures:
100-
// TRUNNER-NEXT: 1. %[[#A36]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
101-
// TRUNNER-EMPTY:
102-
103-
//=========== Test specialized function signature and body ===========//
104-
specify_test "autodiff_closure_specialize_specialized_function_signature_and_body"
105-
// TRUNNER-LABEL: Generated specialized function: $s4test13methodWrapperySfAA5ClassVFTJpSpSr08$s4test5D19V6methodSfyFTJpSpSr4main05_AD__edfG24F_bb3__Pred__src_0_wrt_0OTf1nc_n
106-
// CHECK: sil private [signature_optimized_thunk] [heuristic_always_inline] @$s4test13methodWrapperySfAA5ClassVFTJpSpSr08$s4test5D19V6methodSfyFTJpSpSr4main05_AD__edfG24F_bb3__Pred__src_0_wrt_0OTf1nc_n : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector {
107-
// CHECK: bb0(%0 : $Float, %1 : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0):
94+
// CHECK: sil private [signature_optimized_thunk] [heuristic_always_inline] [ossa] @$s4test13methodWrapperySfAA5ClassVFTJpSpSr08$s4test5D19V6methodSfyFTJpSpSr4main05_AD__edfG24F_bb3__Pred__src_0_wrt_0OTf1nc_n : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector {
95+
// CHECK: bb0(%0 : $Float, %1 : @owned $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0):
10896
// CHECK: %[[#B2:]] = function_ref @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
109-
// TRUNNER: %[[#B3:]] = partial_apply [callee_guaranteed] %[[#B2]](%1) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
110-
// TRUNNER: %[[#B4:]] = apply %[[#B3]](%0) : $@callee_guaranteed (Float) -> Class.TangentVector
11197
// COMBINE-NOT: partial_apply
11298
// COMBINE: %[[#B4:]] = apply %[[#B2]](%0, %1) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
113-
// TRUNNER: strong_release %[[#B3]] : $@callee_guaranteed (Float) -> Class.TangentVector
11499
// CHECK: return %[[#B4]]
115100

116-
//=========== Test rewritten body ===========//
117-
specify_test "autodiff_closure_specialize_rewritten_caller_body"
118-
// TRUNNER-LABEL: Rewritten caller body for: $s4test13methodWrapperySfAA5ClassVFTJrSpSr:
119-
// CHECK: sil hidden @$s4test13methodWrapperySfAA5ClassVFTJrSpSr : $@convention(thin) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) {
120-
// CHECK: bb3(%[[#C33:]] : $Float, %[[#C34:]] : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0):
121-
// TRUNNER: %[[#C35:]] = function_ref @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
122-
// TRUNNER: %[[#C37:]] = partial_apply [callee_guaranteed] %[[#C35]](%[[#C34]]) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
123-
// TRUNNER: %[[#C38:]] = function_ref @$s4test13methodWrapperySfAA5ClassVFTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector
101+
// CHECK: sil hidden [ossa] @$s4test13methodWrapperySfAA5ClassVFTJrSpSr : $@convention(thin) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) {
102+
// CHECK: bb3(%[[#C33:]] : $Float, %[[#C34:]] : @owned $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0):
124103
// COMBINE-NOT: function_ref @$s4test5ClassV6methodSfyFTJpSpSr
125104
// COMBINE-NOT: partial_apply
126105
// COMBINE-NOT: function_ref @$s4test13methodWrapperySfAA5ClassVFTJpSpSr
127106
// CHECK: %[[#C39:]] = function_ref @$s4test13methodWrapperySfAA5ClassVFTJpSpSr08$s4test5D19V6methodSfyFTJpSpSr4main05_AD__edfG24F_bb3__Pred__src_0_wrt_0OTf1nc_n : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
128107
// CHECK: %[[#C40:]] = partial_apply [callee_guaranteed] %[[#C39]](%[[#C34]]) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
129-
// TRUNNER: release_value %[[#C37]] : $@callee_guaranteed (Float) -> Class.TangentVector
130108
// CHECK: %[[#C42:]] = tuple (%[[#C33]] : $Float, %[[#C40]] : $@callee_guaranteed (Float) -> Class.TangentVector)
131109
// CHECK: return %[[#C42]]
132110

@@ -169,7 +147,7 @@ bb2:
169147
%46 = enum $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0.bb2!enumelt, %45
170148
br bb3(%42, %46)
171149

172-
bb3(%48 : $Float, %49 : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0):
150+
bb3(%48 : $Float, %49 : @owned $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0):
173151
// function_ref pullback of Class.method()
174152
%50 = function_ref @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
175153
%51 = partial_apply [callee_guaranteed] %50(%49) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector

0 commit comments

Comments
 (0)