@@ -436,6 +436,18 @@ bool COWArrayOpt::checkSafeArrayAddressUses(UserList &AddressUsers) {
436436 return true ;
437437}
438438
439+ template <typename UserRange>
440+ ArraySemanticsCall getEndMutationCall (const UserRange &AddressUsers) {
441+ for (auto *UseInst : AddressUsers) {
442+ if (auto *AI = dyn_cast<ApplyInst>(UseInst)) {
443+ ArraySemanticsCall ASC (AI);
444+ if (ASC.getKind () == ArrayCallKind::kEndMutation )
445+ return ASC;
446+ }
447+ }
448+ return ArraySemanticsCall ();
449+ }
450+
439451// / Returns true if this instruction is a safe array use if all of its users are
440452// / also safe array users.
441453static SILValue isTransitiveSafeUser (SILInstruction *I) {
@@ -811,8 +823,14 @@ void COWArrayOpt::hoistAddressProjections(Operand &ArrayOp) {
811823 }
812824}
813825
814- // / Check if this call to "make_mutable" is hoistable, and move it, or delete it
815- // / if it's already hoisted.
826+ // / Check if this call to "make_mutable" is hoistable, and copy it, along with
827+ // / the corresponding end_mutation call, to the loop pre-header.
828+ // /
829+ // / The origial make_mutable/end_mutation calls remain in the loop, because
830+ // / removing them would violate the COW representation rules.
831+ // / Having those calls in the pre-header will then enable COWOpts (after
832+ // / inlining) to constant fold the uniqueness check of the begin_cow_mutation
833+ // / in the loop.
816834bool COWArrayOpt::hoistMakeMutable (ArraySemanticsCall MakeMutable,
817835 bool dominatesExits) {
818836 LLVM_DEBUG (llvm::dbgs () << " Checking mutable array: " <<CurrentArrayAddr);
@@ -872,6 +890,18 @@ bool COWArrayOpt::hoistMakeMutable(ArraySemanticsCall MakeMutable,
872890 return false ;
873891 }
874892
893+ auto ArrayUsers = llvm::map_range (MakeMutable.getSelf ()->getUses (),
894+ ValueBase::UseToUser ());
895+
896+ // There should be a call to end_mutation. Find it so that we can copy it to
897+ // the pre-header.
898+ ArraySemanticsCall EndMutation = getEndMutationCall (ArrayUsers);
899+ if (!EndMutation) {
900+ EndMutation = getEndMutationCall (StructUses.StructAddressUsers );
901+ if (!EndMutation)
902+ return false ;
903+ }
904+
875905 // Hoist the make_mutable.
876906 LLVM_DEBUG (llvm::dbgs () << " Hoisting make_mutable: " << *MakeMutable);
877907
@@ -880,12 +910,18 @@ bool COWArrayOpt::hoistMakeMutable(ArraySemanticsCall MakeMutable,
880910 assert (MakeMutable.canHoist (Preheader->getTerminator (), DomTree) &&
881911 " Should be able to hoist make_mutable" );
882912
883- MakeMutable.hoist (Preheader->getTerminator (), DomTree);
913+ // Copy the make_mutable and end_mutation calls to the pre-header.
914+ TermInst *insertionPoint = Preheader->getTerminator ();
915+ ApplyInst *hoistedMM = MakeMutable.copyTo (insertionPoint, DomTree);
916+ ApplyInst *EMInst = EndMutation;
917+ ApplyInst *hoistedEM = cast<ApplyInst>(EMInst->clone (insertionPoint));
918+ hoistedEM->setArgument (0 , hoistedMM->getArgument (0 ));
919+ placeFuncRef (hoistedEM, DomTree);
884920
885921 // Register array loads. This is needed for hoisting make_mutable calls of
886922 // inner arrays in the two-dimensional case.
887923 if (arrayContainerIsUnique &&
888- StructUses.hasSingleAddressUse ((ApplyInst *)MakeMutable)) {
924+ StructUses.hasOnlyAddressUses ((ApplyInst *)MakeMutable, EMInst )) {
889925 for (auto use : MakeMutable.getSelf ()->getUses ()) {
890926 if (auto *LI = dyn_cast<LoadInst>(use->getUser ()))
891927 HoistableLoads.insert (LI);
@@ -917,39 +953,33 @@ bool COWArrayOpt::run() {
917953 // is only mapped to a call once the analysis has determined that no
918954 // make_mutable calls are required within the loop body for that array.
919955 llvm::SmallDenseMap<SILValue, ApplyInst*> ArrayMakeMutableMap;
920-
956+
957+ llvm::SmallVector<ArraySemanticsCall, 8 > makeMutableCalls;
958+
921959 for (auto *BB : Loop->getBlocks ()) {
922960 if (ColdBlocks.isCold (BB))
923961 continue ;
924- bool dominatesExits = dominatesExitingBlocks (BB);
925- for ( auto II = BB-> begin (), IE = BB-> end (); II != IE;) {
926- // Inst may be moved by hoistMakeMutable .
927- SILInstruction *Inst = &*II;
928- ++II ;
929- ArraySemanticsCall MakeMutableCall (Inst, " array.make_mutable " );
930- if (! MakeMutableCall)
931- continue ;
962+
963+ // Instructions are getting moved around. To not mess with iterator
964+ // invalidation, first collect all calls, and then do the transformation .
965+ for ( SILInstruction &I : *BB) {
966+ ArraySemanticsCall MakeMutableCall (&I, " array.make_mutable " ) ;
967+ if (MakeMutableCall)
968+ makeMutableCalls. push_back ( MakeMutableCall);
969+ }
932970
971+ bool dominatesExits = dominatesExitingBlocks (BB);
972+ for (ArraySemanticsCall MakeMutableCall : makeMutableCalls) {
933973 CurrentArrayAddr = MakeMutableCall.getSelf ();
934974 auto HoistedCallEntry = ArrayMakeMutableMap.find (CurrentArrayAddr);
935975 if (HoistedCallEntry == ArrayMakeMutableMap.end ()) {
936- if (!hoistMakeMutable (MakeMutableCall, dominatesExits)) {
976+ if (hoistMakeMutable (MakeMutableCall, dominatesExits)) {
977+ ArrayMakeMutableMap[CurrentArrayAddr] = MakeMutableCall;
978+ HasChanged = true ;
979+ } else {
937980 ArrayMakeMutableMap[CurrentArrayAddr] = nullptr ;
938- continue ;
939981 }
940-
941- ArrayMakeMutableMap[CurrentArrayAddr] = MakeMutableCall;
942- HasChanged = true ;
943- continue ;
944982 }
945-
946- if (!HoistedCallEntry->second )
947- continue ;
948-
949- LLVM_DEBUG (llvm::dbgs () << " Removing make_mutable call: "
950- << *MakeMutableCall);
951- MakeMutableCall.removeCall ();
952- HasChanged = true ;
953983 }
954984 }
955985 return HasChanged;
0 commit comments