Skip to content

Commit 98e17c6

Browse files
committed
[SCEV] Rewrite A - B = UMax(1, A - B) lazily for A != B loop guards. (llvm#163787)
Follow-up to 2d02726 (llvm#160500). Creating the SCEV subtraction eagerly is very expensive. To soften the blow, just collect a set of inequalities and check if we can apply the subtract rewrite when rewriting SCEVAddExpr. Restores most of the regression: http://llvm-compile-time-tracker.com/compare.php?from=0792478e4e133be96650444f3264e89d002fc058&to=7fca35db60fe6f423ea6051b45226046c067c252&stat=instructions:u stage1-O3: -0.10% stage1-ReleaseThinLTO: -0.09% stage1-ReleaseLTO-g: -0.10% stage1-O0-g: +0.02% stage1-aarch64-O3: -0.09% stage1-aarch64-O0-g: +0.00% stage2-O3: -0.17% stage2-O0-g: -0.05% stage2-clang: -0.07% There is still some negative impact compared to before 2d02726, but there's probably not much we could do to reduce this even more. Compile-time improvement with 2d02726 reverted on top of the current PR: http://llvm-compile-time-tracker.com/compare.php?from=7fca35db60fe6f423ea6051b45226046c067c252&to=98dd152bdfc76b30d00190d3850d89406ca3c21f&stat=instructions:u stage1-O3: 60628M (-0.03%) stage1-ReleaseThinLTO: 76388M (-0.04%) stage1-ReleaseLTO-g: 89228M (-0.02%) stage1-O0-g: 18523M (-0.03%) stage1-aarch64-O3: 67623M (-0.03%) stage1-aarch64-O0-g: 22595M (+0.01%) stage2-O3: 52336M (+0.01%) stage2-O0-g: 16174M (+0.00%) stage2-clang: 34890032M (-0.03%) PR: llvm#163787 (cherry picked from commit a5d3522)
1 parent 96dbb37 commit 98e17c6

File tree

3 files changed

+189
-11
lines changed

3 files changed

+189
-11
lines changed

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,6 +1343,7 @@ class ScalarEvolution {
13431343

13441344
class LoopGuards {
13451345
DenseMap<const SCEV *, const SCEV *> RewriteMap;
1346+
SmallDenseSet<std::pair<const SCEV *, const SCEV *>> NotEqual;
13461347
bool PreserveNUW = false;
13471348
bool PreserveNSW = false;
13481349
ScalarEvolution &SE;

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15699,19 +15699,26 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1569915699
GetNextSCEVDividesByDivisor(One, DividesBy);
1570015700
To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
1570115701
} else {
15702+
// LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
15703+
// but creating the subtraction eagerly is expensive. Track the
15704+
// inequalities in a separate map, and materialize the rewrite lazily
15705+
// when encountering a suitable subtraction while re-writing.
1570215706
if (LHS->getType()->isPointerTy()) {
1570315707
LHS = SE.getLosslessPtrToIntExpr(LHS);
1570415708
RHS = SE.getLosslessPtrToIntExpr(RHS);
1570515709
if (isa<SCEVCouldNotCompute>(LHS) || isa<SCEVCouldNotCompute>(RHS))
1570615710
break;
1570715711
}
15708-
auto AddSubRewrite = [&](const SCEV *A, const SCEV *B) {
15709-
const SCEV *Sub = SE.getMinusSCEV(A, B);
15710-
AddRewrite(Sub, Sub,
15711-
SE.getUMaxExpr(Sub, SE.getOne(From->getType())));
15712-
};
15713-
AddSubRewrite(LHS, RHS);
15714-
AddSubRewrite(RHS, LHS);
15712+
const SCEVConstant *C;
15713+
const SCEV *A, *B;
15714+
if (match(RHS, m_scev_Add(m_SCEVConstant(C), m_SCEV(A))) &&
15715+
match(LHS, m_scev_Add(m_scev_Specific(C), m_SCEV(B)))) {
15716+
RHS = A;
15717+
LHS = B;
15718+
}
15719+
if (LHS > RHS)
15720+
std::swap(LHS, RHS);
15721+
Guards.NotEqual.insert({LHS, RHS});
1571515722
continue;
1571615723
}
1571715724
break;
@@ -15845,13 +15852,15 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
1584515852
class SCEVLoopGuardRewriter
1584615853
: public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
1584715854
const DenseMap<const SCEV *, const SCEV *> &Map;
15855+
const SmallDenseSet<std::pair<const SCEV *, const SCEV *>> &NotEqual;
1584815856

1584915857
SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
1585015858

1585115859
public:
1585215860
SCEVLoopGuardRewriter(ScalarEvolution &SE,
1585315861
const ScalarEvolution::LoopGuards &Guards)
15854-
: SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
15862+
: SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
15863+
NotEqual(Guards.NotEqual) {
1585515864
if (Guards.PreserveNUW)
1585615865
FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
1585715866
if (Guards.PreserveNSW)
@@ -15906,14 +15915,36 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
1590615915
}
1590715916

1590815917
const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
15918+
// Helper to check if S is a subtraction (A - B) where A != B, and if so,
15919+
// return UMax(S, 1).
15920+
auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
15921+
const SCEV *LHS, *RHS;
15922+
if (MatchBinarySub(S, LHS, RHS)) {
15923+
if (LHS > RHS)
15924+
std::swap(LHS, RHS);
15925+
if (NotEqual.contains({LHS, RHS}))
15926+
return SE.getUMaxExpr(S, SE.getOne(S->getType()));
15927+
}
15928+
return nullptr;
15929+
};
15930+
15931+
// Check if Expr itself is a subtraction pattern with guard info.
15932+
if (const SCEV *Rewritten = RewriteSubtraction(Expr))
15933+
return Rewritten;
15934+
1590915935
// Trip count expressions sometimes consist of adding 3 operands, i.e.
1591015936
// (Const + A + B). There may be guard info for A + B, and if so, apply
1591115937
// it.
1591215938
// TODO: Could more generally apply guards to Add sub-expressions.
1591315939
if (isa<SCEVConstant>(Expr->getOperand(0)) &&
1591415940
Expr->getNumOperands() == 3) {
15915-
if (const SCEV *S = Map.lookup(
15916-
SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2))))
15941+
const SCEV *Add =
15942+
SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
15943+
if (const SCEV *Rewritten = RewriteSubtraction(Add))
15944+
return SE.getAddExpr(
15945+
Expr->getOperand(0), Rewritten,
15946+
ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
15947+
if (const SCEV *S = Map.lookup(Add))
1591715948
return SE.getAddExpr(Expr->getOperand(0), S);
1591815949
}
1591915950
SmallVector<const SCEV *, 2> Operands;
@@ -15948,7 +15979,7 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
1594815979
}
1594915980
};
1595015981

15951-
if (RewriteMap.empty())
15982+
if (RewriteMap.empty() && NotEqual.empty())
1595215983
return Expr;
1595315984

1595415985
SCEVLoopGuardRewriter Rewriter(SE, *this);

llvm/test/Transforms/IndVarSimplify/pointer-loop-guards.ll

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,149 @@ exit:
4848
%res = phi i64 [ 0, %entry ], [ %i64.iv, %loop.latch ], [ 0, %loop.header ]
4949
ret i64 %res
5050
}
51+
52+
define void @test_sub_cmp(ptr align 8 %start, ptr %end) {
53+
; CHECK-LABEL: define void @test_sub_cmp(
54+
; CHECK-SAME: ptr align 8 [[START:%.*]], ptr [[END:%.*]]) {
55+
; CHECK-NEXT: [[ENTRY:.*:]]
56+
; CHECK-NEXT: [[START_INT:%.*]] = ptrtoint ptr [[START]] to i64
57+
; CHECK-NEXT: [[END_INT:%.*]] = ptrtoint ptr [[END]] to i64
58+
; CHECK-NEXT: [[PTR_DIFF:%.*]] = sub i64 [[START_INT]], [[END_INT]]
59+
; CHECK-NEXT: [[CMP_ENTRY:%.*]] = icmp eq ptr [[START]], [[END]]
60+
; CHECK-NEXT: br i1 [[CMP_ENTRY]], label %[[EXIT:.*]], label %[[LOOP_HEADER_PREHEADER:.*]]
61+
; CHECK: [[LOOP_HEADER_PREHEADER]]:
62+
; CHECK-NEXT: br label %[[LOOP_HEADER:.*]]
63+
; CHECK: [[LOOP_HEADER]]:
64+
; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[IV_NEXT:%.*]], %[[LOOP_LATCH:.*]] ], [ 0, %[[LOOP_HEADER_PREHEADER]] ]
65+
; CHECK-NEXT: [[C_1:%.*]] = call i1 @cond()
66+
; CHECK-NEXT: br i1 [[C_1]], label %[[EXIT_EARLY:.*]], label %[[LOOP_LATCH]]
67+
; CHECK: [[LOOP_LATCH]]:
68+
; CHECK-NEXT: [[IV_NEXT]] = add nuw i64 [[IV]], 1
69+
; CHECK-NEXT: [[CMP_LATCH:%.*]] = icmp ult i64 [[IV_NEXT]], [[PTR_DIFF]]
70+
; CHECK-NEXT: br i1 [[CMP_LATCH]], label %[[LOOP_HEADER]], label %[[EXIT_LOOPEXIT:.*]]
71+
; CHECK: [[EXIT_EARLY]]:
72+
; CHECK-NEXT: br label %[[EXIT]]
73+
; CHECK: [[EXIT_LOOPEXIT]]:
74+
; CHECK-NEXT: br label %[[EXIT]]
75+
; CHECK: [[EXIT]]:
76+
; CHECK-NEXT: ret void
77+
;
78+
; N32-LABEL: define void @test_sub_cmp(
79+
; N32-SAME: ptr align 8 [[START:%.*]], ptr [[END:%.*]]) {
80+
; N32-NEXT: [[ENTRY:.*:]]
81+
; N32-NEXT: [[START_INT:%.*]] = ptrtoint ptr [[START]] to i64
82+
; N32-NEXT: [[END_INT:%.*]] = ptrtoint ptr [[END]] to i64
83+
; N32-NEXT: [[PTR_DIFF:%.*]] = sub i64 [[START_INT]], [[END_INT]]
84+
; N32-NEXT: [[CMP_ENTRY:%.*]] = icmp eq ptr [[START]], [[END]]
85+
; N32-NEXT: br i1 [[CMP_ENTRY]], label %[[EXIT:.*]], label %[[LOOP_HEADER_PREHEADER:.*]]
86+
; N32: [[LOOP_HEADER_PREHEADER]]:
87+
; N32-NEXT: br label %[[LOOP_HEADER:.*]]
88+
; N32: [[LOOP_HEADER]]:
89+
; N32-NEXT: [[IV:%.*]] = phi i64 [ [[IV_NEXT:%.*]], %[[LOOP_LATCH:.*]] ], [ 0, %[[LOOP_HEADER_PREHEADER]] ]
90+
; N32-NEXT: [[C_1:%.*]] = call i1 @cond()
91+
; N32-NEXT: br i1 [[C_1]], label %[[EXIT_EARLY:.*]], label %[[LOOP_LATCH]]
92+
; N32: [[LOOP_LATCH]]:
93+
; N32-NEXT: [[IV_NEXT]] = add nuw i64 [[IV]], 1
94+
; N32-NEXT: [[EXITCOND:%.*]] = icmp ne i64 [[IV_NEXT]], [[PTR_DIFF]]
95+
; N32-NEXT: br i1 [[EXITCOND]], label %[[LOOP_HEADER]], label %[[EXIT_LOOPEXIT:.*]]
96+
; N32: [[EXIT_EARLY]]:
97+
; N32-NEXT: br label %[[EXIT]]
98+
; N32: [[EXIT_LOOPEXIT]]:
99+
; N32-NEXT: br label %[[EXIT]]
100+
; N32: [[EXIT]]:
101+
; N32-NEXT: ret void
102+
;
103+
entry:
104+
%start.int = ptrtoint ptr %start to i64
105+
%end.int = ptrtoint ptr %end to i64
106+
%ptr.diff = sub i64 %start.int, %end.int
107+
%cmp.entry = icmp eq ptr %start, %end
108+
br i1 %cmp.entry, label %exit, label %loop.header
109+
110+
loop.header:
111+
%iv = phi i64 [ 0, %entry ], [ %iv.next, %loop.latch ]
112+
%c.1 = call i1 @cond()
113+
br i1 %c.1, label %exit.early, label %loop.latch
114+
115+
loop.latch:
116+
%iv.next = add i64 %iv, 1
117+
%cmp.latch = icmp ult i64 %iv.next, %ptr.diff
118+
br i1 %cmp.latch, label %loop.header, label %exit
119+
120+
exit.early:
121+
br label %exit
122+
123+
exit:
124+
ret void
125+
}
126+
127+
128+
define void @test_ptr_diff_with_assume(ptr align 8 %start, ptr align 8 %end, ptr %P) {
129+
; CHECK-LABEL: define void @test_ptr_diff_with_assume(
130+
; CHECK-SAME: ptr align 8 [[START:%.*]], ptr align 8 [[END:%.*]], ptr [[P:%.*]]) {
131+
; CHECK-NEXT: [[ENTRY:.*:]]
132+
; CHECK-NEXT: [[START_INT:%.*]] = ptrtoint ptr [[START]] to i64
133+
; CHECK-NEXT: [[END_INT:%.*]] = ptrtoint ptr [[END]] to i64
134+
; CHECK-NEXT: [[PTR_DIFF:%.*]] = sub i64 [[START_INT]], [[END_INT]]
135+
; CHECK-NEXT: [[DIFF_CMP:%.*]] = icmp ult i64 [[PTR_DIFF]], 2
136+
; CHECK-NEXT: call void @llvm.assume(i1 [[DIFF_CMP]])
137+
; CHECK-NEXT: [[COMPUTED_END:%.*]] = getelementptr i8, ptr [[START]], i64 [[PTR_DIFF]]
138+
; CHECK-NEXT: [[ENTRY_CMP:%.*]] = icmp eq ptr [[START]], [[END]]
139+
; CHECK-NEXT: br i1 [[ENTRY_CMP]], label %[[EXIT:.*]], label %[[LOOP_BODY_PREHEADER:.*]]
140+
; CHECK: [[LOOP_BODY_PREHEADER]]:
141+
; CHECK-NEXT: br label %[[LOOP_BODY:.*]]
142+
; CHECK: [[LOOP_BODY]]:
143+
; CHECK-NEXT: [[IV:%.*]] = phi ptr [ [[IV_NEXT:%.*]], %[[LOOP_BODY]] ], [ [[START]], %[[LOOP_BODY_PREHEADER]] ]
144+
; CHECK-NEXT: [[TMP0:%.*]] = call i1 @cond()
145+
; CHECK-NEXT: [[IV_NEXT]] = getelementptr i8, ptr [[IV]], i64 1
146+
; CHECK-NEXT: [[LOOP_CMP:%.*]] = icmp eq ptr [[IV_NEXT]], [[COMPUTED_END]]
147+
; CHECK-NEXT: br i1 [[LOOP_CMP]], label %[[EXIT_LOOPEXIT:.*]], label %[[LOOP_BODY]]
148+
; CHECK: [[EXIT_LOOPEXIT]]:
149+
; CHECK-NEXT: br label %[[EXIT]]
150+
; CHECK: [[EXIT]]:
151+
; CHECK-NEXT: ret void
152+
;
153+
; N32-LABEL: define void @test_ptr_diff_with_assume(
154+
; N32-SAME: ptr align 8 [[START:%.*]], ptr align 8 [[END:%.*]], ptr [[P:%.*]]) {
155+
; N32-NEXT: [[ENTRY:.*:]]
156+
; N32-NEXT: [[START_INT:%.*]] = ptrtoint ptr [[START]] to i64
157+
; N32-NEXT: [[END_INT:%.*]] = ptrtoint ptr [[END]] to i64
158+
; N32-NEXT: [[PTR_DIFF:%.*]] = sub i64 [[START_INT]], [[END_INT]]
159+
; N32-NEXT: [[DIFF_CMP:%.*]] = icmp ult i64 [[PTR_DIFF]], 2
160+
; N32-NEXT: call void @llvm.assume(i1 [[DIFF_CMP]])
161+
; N32-NEXT: [[COMPUTED_END:%.*]] = getelementptr i8, ptr [[START]], i64 [[PTR_DIFF]]
162+
; N32-NEXT: [[ENTRY_CMP:%.*]] = icmp eq ptr [[START]], [[END]]
163+
; N32-NEXT: br i1 [[ENTRY_CMP]], label %[[EXIT:.*]], label %[[LOOP_BODY_PREHEADER:.*]]
164+
; N32: [[LOOP_BODY_PREHEADER]]:
165+
; N32-NEXT: br label %[[LOOP_BODY:.*]]
166+
; N32: [[LOOP_BODY]]:
167+
; N32-NEXT: [[IV:%.*]] = phi ptr [ [[IV_NEXT:%.*]], %[[LOOP_BODY]] ], [ [[START]], %[[LOOP_BODY_PREHEADER]] ]
168+
; N32-NEXT: [[TMP0:%.*]] = call i1 @cond()
169+
; N32-NEXT: [[IV_NEXT]] = getelementptr i8, ptr [[IV]], i64 1
170+
; N32-NEXT: [[LOOP_CMP:%.*]] = icmp eq ptr [[IV_NEXT]], [[COMPUTED_END]]
171+
; N32-NEXT: br i1 [[LOOP_CMP]], label %[[EXIT_LOOPEXIT:.*]], label %[[LOOP_BODY]]
172+
; N32: [[EXIT_LOOPEXIT]]:
173+
; N32-NEXT: br label %[[EXIT]]
174+
; N32: [[EXIT]]:
175+
; N32-NEXT: ret void
176+
;
177+
entry:
178+
%start.int = ptrtoint ptr %start to i64
179+
%end.int = ptrtoint ptr %end to i64
180+
%ptr.diff = sub i64 %start.int, %end.int
181+
%diff.cmp = icmp ult i64 %ptr.diff, 2
182+
call void @llvm.assume(i1 %diff.cmp)
183+
%computed.end = getelementptr i8, ptr %start, i64 %ptr.diff
184+
%entry.cmp = icmp eq ptr %start, %end
185+
br i1 %entry.cmp, label %exit, label %loop.body
186+
187+
loop.body:
188+
%iv = phi ptr [ %start, %entry ], [ %iv.next, %loop.body ]
189+
call i1 @cond()
190+
%iv.next = getelementptr i8, ptr %iv, i64 1
191+
%loop.cmp = icmp eq ptr %iv.next, %computed.end
192+
br i1 %loop.cmp, label %exit, label %loop.body
193+
194+
exit:
195+
ret void
196+
}

0 commit comments

Comments
 (0)