4343// ===----------------------------------------------------------------------===//
4444
4545#define DEBUG_TYPE " sil-capture-promotion"
46+
4647#include " swift/AST/GenericEnvironment.h"
4748#include " swift/SIL/SILCloner.h"
49+ #include " swift/SIL/SILInstruction.h"
4850#include " swift/SIL/TypeSubstCloner.h"
4951#include " swift/SILOptimizer/PassManager/Passes.h"
5052#include " swift/SILOptimizer/PassManager/Transforms.h"
@@ -756,8 +758,12 @@ void ClosureCloner::visitLoadInst(LoadInst *li) {
756758namespace {
757759
758760struct EscapeMutationScanningState {
759- // / The list of mutations that we found while checking for escapes.
760- llvm::SmallVector<SILInstruction *, 8 > foundMutations;
761+ // / The list of mutations in the partial_apply caller that we found.
762+ SmallVector<Operand *, 8 > accumulatedMutations;
763+
764+ // / The list of escapes in the partial_apply caller/callee of the box that we
765+ // / found.
766+ SmallVector<Operand *, 8 > accumulatedEscapes;
761767
762768 // / A flag that we use to ensure that we only ever see 1 project_box on an
763769 // / alloc_box.
@@ -786,16 +792,20 @@ static bool isNonMutatingLoad(SILInstruction *inst) {
786792 return li->getOwnershipQualifier () != LoadOwnershipQualifier::Take;
787793}
788794
789- // / Given a partial_apply instruction and the argument index into its
790- // / callee's argument list of a box argument (which is followed by an argument
791- // / for the address of the box's contents), return true if the closure is known
792- // / not to mutate the captured variable.
793- static bool isNonMutatingCapture (SILArgument *boxArg) {
795+ // / Given a partial_apply instruction and the argument index into its callee's
796+ // / argument list of a box argument (which is followed by an argument for the
797+ // / address of the box's contents), return true if this box has mutating
798+ // / captures. Return false otherwise. All of the mutating captures that we find
799+ // / are placed into \p accumulatedMutatingUses.
800+ static bool getPartialApplyArgMutationsAndEscapes (
801+ SILArgument *boxArg, SmallVectorImpl<Operand *> &accumulatedMutatingUses,
802+ SmallVectorImpl<Operand *> &accumulatedEscapes) {
794803 SmallVector<ProjectBoxInst *, 2 > projectBoxInsts;
795804
796805 // Conservatively do not allow any use of the box argument other than a
797806 // strong_release or projection, since this is the pattern expected from
798807 // SILGen.
808+ SmallVector<Operand *, 32 > incrementalEscapes;
799809 for (auto *use : boxArg->getUses ()) {
800810 if (isa<StrongReleaseInst>(use->getUser ()) ||
801811 isa<DestroyValueInst>(use->getUser ()))
@@ -806,7 +816,7 @@ static bool isNonMutatingCapture(SILArgument *boxArg) {
806816 continue ;
807817 }
808818
809- return false ;
819+ incrementalEscapes. push_back (use) ;
810820 }
811821
812822 // Only allow loads of projections, either directly or via
@@ -815,33 +825,44 @@ static bool isNonMutatingCapture(SILArgument *boxArg) {
815825 // TODO: This seems overly limited. Why not projections of tuples and other
816826 // stuff? Also, why not recursive struct elements? This should be a helper
817827 // function that mirrors isNonEscapingUse.
818- auto isAddrUseMutating = [](SILInstruction *addrInst) {
828+ auto checkIfAddrUseMutating = [&](Operand *addrUse) -> bool {
829+ unsigned initSize = incrementalEscapes.size ();
830+ auto *addrInst = addrUse->getUser ();
819831 if (auto *seai = dyn_cast<StructElementAddrInst>(addrInst)) {
820- return all_of (seai->getUses (), [](Operand *op) -> bool {
821- return isNonMutatingLoad (op->getUser ());
822- });
832+ for (auto *seaiUse : seai->getUses ()) {
833+ if (!isNonMutatingLoad (seaiUse->getUser ())) {
834+ incrementalEscapes.push_back (seaiUse);
835+ }
836+ }
837+ return incrementalEscapes.size () != initSize;
823838 }
824839
825- return isNonMutatingLoad (addrInst) || isa<DebugValueAddrInst>(addrInst) ||
826- isa<MarkFunctionEscapeInst>(addrInst) ||
827- isa<EndAccessInst>(addrInst);
840+ if (isNonMutatingLoad (addrInst) || isa<DebugValueAddrInst>(addrInst) ||
841+ isa<MarkFunctionEscapeInst>(addrInst) || isa<EndAccessInst>(addrInst)) {
842+ return false ;
843+ }
844+
845+ incrementalEscapes.push_back (addrUse);
846+ return true ;
828847 };
829848
830849 for (auto *pbi : projectBoxInsts) {
831850 for (auto *use : pbi->getUses ()) {
832851 if (auto *bai = dyn_cast<BeginAccessInst>(use->getUser ())) {
833852 for (auto *accessUseOper : bai->getUses ()) {
834- if (!isAddrUseMutating (accessUseOper->getUser ()))
835- return false ;
853+ checkIfAddrUseMutating (accessUseOper);
836854 }
837855 continue ;
838856 }
839857
840- if (!isAddrUseMutating (use->getUser ()))
841- return false ;
858+ checkIfAddrUseMutating (use);
842859 }
843860 }
844861
862+ if (incrementalEscapes.empty ())
863+ return false ;
864+ while (!incrementalEscapes.empty ())
865+ accumulatedEscapes.push_back (incrementalEscapes.pop_back_val ());
845866 return true ;
846867}
847868
@@ -852,11 +873,14 @@ bool isPartialApplyNonEscapingUser(Operand *currentOp, PartialApplyInst *pai,
852873 unsigned opNo = currentOp->getOperandNumber ();
853874 assert (opNo != 0 && " Alloc box used as callee of partial apply?" );
854875
855- // If we've already seen this partial apply, then it means the same alloc
856- // box is being captured twice by the same closure, which is odd and
857- // unexpected: bail instead of trying to handle this case.
876+ // If we've already seen this partial apply, then it means the same alloc box
877+ // is being captured twice by the same closure, which is odd and unexpected:
878+ // bail instead of trying to handle this case.
858879 if (state.globalIndexMap .count (pai)) {
880+ // TODO: Is it correct to treat this like an escape? We are just currently
881+ // flagging all failures as warnings.
859882 LLVM_DEBUG (llvm::dbgs () << " FAIL! Already seen.\n " );
883+ state.accumulatedEscapes .push_back (currentOp);
860884 return false ;
861885 }
862886
@@ -877,6 +901,7 @@ bool isPartialApplyNonEscapingUser(Operand *currentOp, PartialApplyInst *pai,
877901 if (!fn || !fn->isDefinition () || fn->isDynamicallyReplaceable ()) {
878902 LLVM_DEBUG (llvm::dbgs () << " FAIL! Not a direct function definition "
879903 " reference.\n " );
904+ state.accumulatedEscapes .push_back (currentOp);
880905 return false ;
881906 }
882907
@@ -893,14 +918,17 @@ bool isPartialApplyNonEscapingUser(Operand *currentOp, PartialApplyInst *pai,
893918 .isAddressOnly (*f)) {
894919 LLVM_DEBUG (llvm::dbgs () << " FAIL! Box is an address only "
895920 " argument!\n " );
921+ state.accumulatedEscapes .push_back (currentOp);
896922 return false ;
897923 }
898924
899925 // Verify that this closure is known not to mutate the captured value; if
900926 // it does, then conservatively refuse to promote any captures of this
901927 // value.
902- if (!isNonMutatingCapture (boxArg)) {
903- LLVM_DEBUG (llvm::dbgs () << " FAIL: Have a mutating capture!\n " );
928+ if (getPartialApplyArgMutationsAndEscapes (boxArg, state.accumulatedMutations ,
929+ state.accumulatedEscapes )) {
930+ LLVM_DEBUG (llvm::dbgs () << " FAIL: Have a mutation or escape of a "
931+ " partial apply arg?!\n " );
904932 return false ;
905933 }
906934
@@ -920,15 +948,17 @@ namespace {
920948
921949class NonEscapingUserVisitor
922950 : public SILInstructionVisitor<NonEscapingUserVisitor, bool > {
923- llvm::SmallVector<Operand *, 32 > worklist;
924- llvm::SmallVectorImpl<SILInstruction *> &foundMutations;
951+ SmallVector<Operand *, 32 > worklist;
952+ SmallVectorImpl<Operand *> &accumulatedMutations;
953+ SmallVectorImpl<Operand *> &accumulatedEscapes;
925954 NullablePtr<Operand> currentOp;
926955
927956public:
928- NonEscapingUserVisitor (
929- Operand *initialOperand,
930- llvm::SmallVectorImpl<SILInstruction *> &foundMutations)
931- : worklist(), foundMutations(foundMutations), currentOp() {
957+ NonEscapingUserVisitor (Operand *initialOperand,
958+ SmallVectorImpl<Operand *> &accumulatedMutations,
959+ SmallVectorImpl<Operand *> &accumulatedEscapes)
960+ : worklist(), accumulatedMutations(accumulatedMutations),
961+ accumulatedEscapes (accumulatedEscapes), currentOp() {
932962 worklist.push_back (initialOperand);
933963 }
934964
@@ -937,6 +967,15 @@ class NonEscapingUserVisitor
937967 NonEscapingUserVisitor (NonEscapingUserVisitor &&) = delete;
938968 NonEscapingUserVisitor &operator =(NonEscapingUserVisitor &&) = delete ;
939969
970+ private:
971+ void markCurrentOpAsMutation () {
972+ accumulatedMutations.push_back (currentOp.get ());
973+ }
974+ void markCurrentOpAsEscape () {
975+ accumulatedEscapes.push_back (currentOp.get ());
976+ }
977+
978+ public:
940979 bool compute () {
941980 while (!worklist.empty ()) {
942981 currentOp = worklist.pop_back_val ();
@@ -964,6 +1003,7 @@ class NonEscapingUserVisitor
9641003 bool visitSILInstruction (SILInstruction *inst) {
9651004 LLVM_DEBUG (llvm::dbgs ()
9661005 << " FAIL! Have unknown escaping user: " << *inst);
1006+ markCurrentOpAsEscape ();
9671007 return false ;
9681008 }
9691009
@@ -979,7 +1019,7 @@ class NonEscapingUserVisitor
9791019#undef ALWAYS_NON_ESCAPING_INST
9801020
9811021 bool visitDeallocBoxInst (DeallocBoxInst *dbi) {
982- foundMutations. push_back (dbi );
1022+ markCurrentOpAsMutation ( );
9831023 return true ;
9841024 }
9851025
@@ -992,9 +1032,10 @@ class NonEscapingUserVisitor
9921032 if (!convention.isIndirectConvention ()) {
9931033 LLVM_DEBUG (llvm::dbgs ()
9941034 << " FAIL! Found non indirect apply user: " << *ai);
1035+ markCurrentOpAsEscape ();
9951036 return false ;
9961037 }
997- foundMutations. push_back (ai );
1038+ markCurrentOpAsMutation ( );
9981039 return true ;
9991040 }
10001041
@@ -1017,7 +1058,7 @@ class NonEscapingUserVisitor
10171058#define RECURSIVE_INST_VISITOR (MUTATING, INST ) \
10181059 bool visit##INST##Inst(INST##Inst *i) { \
10191060 if (bool (detail::MUTATING)) { \
1020- foundMutations. push_back (i); \
1061+ markCurrentOpAsMutation (); \
10211062 } \
10221063 addUsesToWorklist (i); \
10231064 return true ; \
@@ -1044,25 +1085,27 @@ class NonEscapingUserVisitor
10441085
10451086 bool visitCopyAddrInst (CopyAddrInst *cai) {
10461087 if (currentOp.get ()->getOperandNumber () == 1 || cai->isTakeOfSrc ())
1047- foundMutations. push_back (cai );
1088+ markCurrentOpAsMutation ( );
10481089 return true ;
10491090 }
10501091
10511092 bool visitStoreInst (StoreInst *si) {
10521093 if (currentOp.get ()->getOperandNumber () != 1 ) {
10531094 LLVM_DEBUG (llvm::dbgs () << " FAIL! Found store of pointer: " << *si);
1095+ markCurrentOpAsEscape ();
10541096 return false ;
10551097 }
1056- foundMutations. push_back (si );
1098+ markCurrentOpAsMutation ( );
10571099 return true ;
10581100 }
10591101
10601102 bool visitAssignInst (AssignInst *ai) {
10611103 if (currentOp.get ()->getOperandNumber () != 1 ) {
10621104 LLVM_DEBUG (llvm::dbgs () << " FAIL! Found store of pointer: " << *ai);
1105+ markCurrentOpAsEscape ();
10631106 return false ;
10641107 }
1065- foundMutations. push_back (ai );
1108+ markCurrentOpAsMutation ( );
10661109 return true ;
10671110 }
10681111};
@@ -1075,7 +1118,9 @@ class NonEscapingUserVisitor
10751118// / the Mutations vector.
10761119static bool isNonEscapingUse (Operand *initialOp,
10771120 EscapeMutationScanningState &state) {
1078- return NonEscapingUserVisitor (initialOp, state.foundMutations ).compute ();
1121+ return NonEscapingUserVisitor (initialOp, state.accumulatedMutations ,
1122+ state.accumulatedEscapes )
1123+ .compute ();
10791124}
10801125
10811126static bool isProjectBoxNonEscapingUse (ProjectBoxInst *pbi,
@@ -1097,12 +1142,12 @@ static bool isProjectBoxNonEscapingUse(ProjectBoxInst *pbi,
10971142// Top Level AllocBox Escape/Mutation Analysis
10981143// ===----------------------------------------------------------------------===//
10991144
1100- static bool scanUsesForEscapesAndMutations (Operand *op,
1101- EscapeMutationScanningState &state) {
1145+ static bool findEscapeOrMutationUses (Operand *op,
1146+ EscapeMutationScanningState &state) {
11021147 SILInstruction *user = op->getUser ();
11031148
11041149 if (auto *pai = dyn_cast<PartialApplyInst>(user)) {
1105- return isPartialApplyNonEscapingUser (op, pai, state);
1150+ return ! isPartialApplyNonEscapingUser (op, pai, state);
11061151 }
11071152
11081153 // A mark_dependence user on a partial_apply is safe.
@@ -1112,7 +1157,11 @@ static bool scanUsesForEscapesAndMutations(Operand *op,
11121157 while ((mdi = dyn_cast<MarkDependenceInst>(parent))) {
11131158 parent = mdi->getValue ();
11141159 }
1115- return isa<PartialApplyInst>(parent);
1160+ if (isa<PartialApplyInst>(parent))
1161+ return false ;
1162+ state.accumulatedEscapes .push_back (
1163+ &mdi->getOperandRef (MarkDependenceInst::Value));
1164+ return true ;
11161165 }
11171166 }
11181167
@@ -1121,9 +1170,9 @@ static bool scanUsesForEscapesAndMutations(Operand *op,
11211170 // can be seen since there is no code for reasoning about multiple
11221171 // boxes. Just put in the restriction so we are consistent.
11231172 if (state.sawProjectBoxInst )
1124- return false ;
1173+ return true ;
11251174 state.sawProjectBoxInst = true ;
1126- return isProjectBoxNonEscapingUse (pbi, state);
1175+ return ! isProjectBoxNonEscapingUse (pbi, state);
11271176 }
11281177
11291178 // Given a top level copy value use or mark_uninitialized, check all of its
@@ -1134,10 +1183,11 @@ static bool scanUsesForEscapesAndMutations(Operand *op,
11341183 // derived from a projection like instruction). In fact such a thing may not
11351184 // even make any sense!
11361185 if (isa<CopyValueInst>(user) || isa<MarkUninitializedInst>(user)) {
1137- return all_of (cast<SingleValueInstruction>(user)->getUses (),
1138- [&state](Operand *userOp) -> bool {
1139- return scanUsesForEscapesAndMutations (userOp, state);
1140- });
1186+ bool foundSomeMutations = false ;
1187+ for (auto *use : cast<SingleValueInstruction>(user)->getUses ()) {
1188+ foundSomeMutations |= findEscapeOrMutationUses (use, state);
1189+ }
1190+ return foundSomeMutations;
11411191 }
11421192
11431193 // Verify that this use does not otherwise allow the alloc_box to
@@ -1153,14 +1203,20 @@ static bool
11531203examineAllocBoxInst (AllocBoxInst *abi, ReachabilityInfo &ri,
11541204 llvm::DenseMap<PartialApplyInst *, unsigned > &im) {
11551205 LLVM_DEBUG (llvm::dbgs () << " Visiting alloc box: " << *abi);
1156- EscapeMutationScanningState state{{}, false , im};
1206+ EscapeMutationScanningState state{{}, {}, false , im};
11571207
1158- // Scan the box for interesting uses.
1159- if (any_of (abi->getUses (), [&state](Operand *op) {
1160- return !scanUsesForEscapesAndMutations (op, state);
1161- })) {
1208+ // Scan the box for escaping or mutating uses.
1209+ for (auto *use : abi->getUses ()) {
1210+ findEscapeOrMutationUses (use, state);
1211+ }
1212+
1213+ if (!state.accumulatedEscapes .empty ()) {
11621214 LLVM_DEBUG (llvm::dbgs ()
1163- << " Found an escaping use! Can not optimize this alloc box?!\n " );
1215+ << " Found escaping uses! Can not optimize this alloc box?!\n " );
1216+ while (!state.accumulatedEscapes .empty ()) {
1217+ auto *escapingUse = state.accumulatedEscapes .pop_back_val ();
1218+ LLVM_DEBUG (llvm::dbgs () << " Escaping use: " << *escapingUse->getUser ());
1219+ }
11641220 return false ;
11651221 }
11661222
@@ -1183,17 +1239,18 @@ examineAllocBoxInst(AllocBoxInst *abi, ReachabilityInfo &ri,
11831239 LLVM_DEBUG (llvm::dbgs ()
11841240 << " Checking for any mutations that invalidate captures...\n " );
11851241 // Loop over all mutations to possibly invalidate captures.
1186- for (auto *inst : state.foundMutations ) {
1242+ for (auto *use : state.accumulatedMutations ) {
11871243 auto iter = im.begin ();
11881244 while (iter != im.end ()) {
1245+ auto *user = use->getUser ();
11891246 auto *pai = iter->first ;
11901247 // The mutation invalidates a capture if it occurs in a block reachable
11911248 // from the block the partial_apply is in, or if it is in the same
11921249 // block is after the partial_apply.
1193- if (ri.isReachable (pai->getParent (), inst ->getParent ()) ||
1194- (pai->getParent () == inst ->getParent () && isAfter (pai, inst ))) {
1250+ if (ri.isReachable (pai->getParent (), user ->getParent ()) ||
1251+ (pai->getParent () == user ->getParent () && isAfter (pai, user ))) {
11951252 LLVM_DEBUG (llvm::dbgs () << " Invalidating: " << *pai);
1196- LLVM_DEBUG (llvm::dbgs () << " Because of user: " << *inst );
1253+ LLVM_DEBUG (llvm::dbgs () << " Because of user: " << *user );
11971254 auto prev = iter++;
11981255 im.erase (prev);
11991256 continue ;
0 commit comments