Skip to content

Commit 208a5b2

Browse files
committed
[SCEV] Improve handling of divisibility information from loop guards. (llvm#163021)
At the moment, the effectivness of guards that contain divisibility information (A % B == 0 ) depends on the order of the conditions. This patch makes using divisibility information independent of the order, by collecting and applying the divisibility information separately. We first collect all conditions in a vector, then collect the divisibility information from all guards. When processing other guards, we apply divisibility info collected earlier. After all guards have been processed, we add the divisibility info, rewriting the existing rewrite. This ensures we apply the divisibility info to the largest rewrite expression. This helps to improve results in a few cases, one in dtcxzyw/llvm-opt-benchmark#2921 and another one in a different large C/C++ based IR corpus. PR: llvm#163021 (cherry picked from commit d3fe1df)
1 parent a738a80 commit 208a5b2

File tree

2 files changed

+312
-81
lines changed

2 files changed

+312
-81
lines changed

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 114 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -15458,26 +15458,92 @@ static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
1545815458
return SE.getConstant(*ExprVal + DivisorVal - Rem);
1545915459
}
1546015460

15461+
static bool collectDivisibilityInformation(
15462+
ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15463+
DenseMap<const SCEV *, const SCEV *> &DivInfo,
15464+
DenseMap<const SCEV *, APInt> &Multiples, ScalarEvolution &SE) {
15465+
// If we have LHS == 0, check if LHS is computing a property of some unknown
15466+
// SCEV %v which we can rewrite %v to express explicitly.
15467+
if (Predicate != CmpInst::ICMP_EQ || !match(RHS, m_scev_Zero()))
15468+
return false;
15469+
// If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15470+
// explicitly express that.
15471+
const SCEVUnknown *URemLHS = nullptr;
15472+
const SCEV *URemRHS = nullptr;
15473+
if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15474+
return false;
15475+
15476+
const SCEV *Multiple =
15477+
SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15478+
DivInfo[URemLHS] = Multiple;
15479+
if (auto *C = dyn_cast<SCEVConstant>(URemRHS))
15480+
Multiples[URemLHS] = C->getAPInt();
15481+
return true;
15482+
}
15483+
15484+
// Check if the condition is a divisibility guard (A % B == 0).
15485+
static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15486+
ScalarEvolution &SE) {
15487+
const SCEV *X, *Y;
15488+
return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15489+
}
15490+
15491+
// Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15492+
// recursively. This is done by aligning up/down the constant value to the
15493+
// Divisor.
15494+
static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15495+
APInt Divisor,
15496+
ScalarEvolution &SE) {
15497+
// Return true if \p Expr is a MinMax SCEV expression with a non-negative
15498+
// constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15499+
// the non-constant operand and in \p LHS the constant operand.
15500+
auto IsMinMaxSCEVWithNonNegativeConstant =
15501+
[&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15502+
const SCEV *&RHS) {
15503+
if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15504+
if (MinMax->getNumOperands() != 2)
15505+
return false;
15506+
if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15507+
if (C->getAPInt().isNegative())
15508+
return false;
15509+
SCTy = MinMax->getSCEVType();
15510+
LHS = MinMax->getOperand(0);
15511+
RHS = MinMax->getOperand(1);
15512+
return true;
15513+
}
15514+
}
15515+
return false;
15516+
};
15517+
15518+
const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15519+
SCEVTypes SCTy;
15520+
if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15521+
MinMaxRHS))
15522+
return MinMaxExpr;
15523+
auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15524+
assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15525+
auto *DivisibleExpr =
15526+
IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15527+
: getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15528+
SmallVector<const SCEV *> Ops = {
15529+
applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15530+
return SE.getMinMaxExpr(SCTy, Ops);
15531+
}
15532+
1546115533
void ScalarEvolution::LoopGuards::collectFromBlock(
1546215534
ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
1546315535
const BasicBlock *Block, const BasicBlock *Pred,
1546415536
SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
1546515537
SmallVector<const SCEV *> ExprsToRewrite;
1546615538
auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
1546715539
const SCEV *RHS,
15468-
DenseMap<const SCEV *, const SCEV *>
15469-
&RewriteMap) {
15540+
DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15541+
const LoopGuards &DivGuards) {
1547015542
// WARNING: It is generally unsound to apply any wrap flags to the proposed
1547115543
// replacement SCEV which isn't directly implied by the structure of that
1547215544
// SCEV. In particular, using contextual facts to imply flags is *NOT*
1547315545
// legal. See the scoping rules for flags in the header to understand why.
1547415546

15475-
// If LHS is a constant, apply information to the other expression.
15476-
if (isa<SCEVConstant>(LHS)) {
15477-
std::swap(LHS, RHS);
15478-
Predicate = CmpInst::getSwappedPredicate(Predicate);
15479-
}
15480-
1548115547
// Check for a condition of the form (-C1 + X < C2). InstCombine will
1548215548
// create this form when combining two checks of the form (X u< C2 + C1) and
1548315549
// (X >=u C1).
@@ -15510,76 +15576,6 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1551015576
if (MatchRangeCheckIdiom())
1551115577
return;
1551215578

15513-
// Return true if \p Expr is a MinMax SCEV expression with a non-negative
15514-
// constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15515-
// the non-constant operand and in \p LHS the constant operand.
15516-
auto IsMinMaxSCEVWithNonNegativeConstant =
15517-
[&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15518-
const SCEV *&RHS) {
15519-
if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15520-
if (MinMax->getNumOperands() != 2)
15521-
return false;
15522-
if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15523-
if (C->getAPInt().isNegative())
15524-
return false;
15525-
SCTy = MinMax->getSCEVType();
15526-
LHS = MinMax->getOperand(0);
15527-
RHS = MinMax->getOperand(1);
15528-
return true;
15529-
}
15530-
}
15531-
return false;
15532-
};
15533-
15534-
// Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
15535-
// recursively. This is done by aligning up/down the constant value to the
15536-
// Divisor.
15537-
std::function<const SCEV *(const SCEV *, const SCEV *)>
15538-
ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15539-
const SCEV *Divisor) {
15540-
auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15541-
if (!ConstDivisor)
15542-
return MinMaxExpr;
15543-
const APInt &DivisorVal = ConstDivisor->getAPInt();
15544-
15545-
const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15546-
SCEVTypes SCTy;
15547-
if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15548-
MinMaxRHS))
15549-
return MinMaxExpr;
15550-
auto IsMin =
15551-
isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15552-
assert(SE.isKnownNonNegative(MinMaxLHS) &&
15553-
"Expected non-negative operand!");
15554-
auto *DivisibleExpr =
15555-
IsMin
15556-
? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE)
15557-
: getNextSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE);
15558-
SmallVector<const SCEV *> Ops = {
15559-
ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15560-
return SE.getMinMaxExpr(SCTy, Ops);
15561-
};
15562-
15563-
// If we have LHS == 0, check if LHS is computing a property of some unknown
15564-
// SCEV %v which we can rewrite %v to express explicitly.
15565-
if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
15566-
// If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15567-
// explicitly express that.
15568-
const SCEVUnknown *URemLHS = nullptr;
15569-
const SCEV *URemRHS = nullptr;
15570-
if (match(LHS,
15571-
m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE))) {
15572-
auto I = RewriteMap.find(URemLHS);
15573-
const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : URemLHS;
15574-
RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15575-
const auto *Multiple =
15576-
SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15577-
RewriteMap[URemLHS] = Multiple;
15578-
ExprsToRewrite.push_back(URemLHS);
15579-
return;
15580-
}
15581-
}
15582-
1558315579
// Do not apply information for constants or if RHS contains an AddRec.
1558415580
if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
1558515581
return;
@@ -15609,7 +15605,9 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1560915605
};
1561015606

1561115607
const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15612-
const APInt &DividesBy = SE.getConstantMultiple(RewrittenLHS);
15608+
// Apply divisibility information when computing the constant multiple.
15609+
const APInt &DividesBy =
15610+
SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
1561315611

1561415612
// Collect rewrites for LHS and its transitive operands based on the
1561515613
// condition.
@@ -15794,8 +15792,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1579415792

1579515793
// Now apply the information from the collected conditions to
1579615794
// Guards.RewriteMap. Conditions are processed in reverse order, so the
15797-
// earliest conditions is processed first. This ensures the SCEVs with the
15795+
// earliest conditions is processed first, except guards with divisibility
15796+
// information, which are moved to the back. This ensures the SCEVs with the
1579815797
// shortest dependency chains are constructed first.
15798+
SmallVector<std::tuple<CmpInst::Predicate, const SCEV *, const SCEV *>>
15799+
GuardsToProcess;
1579915800
for (auto [Term, EnterIfTrue] : reverse(Terms)) {
1580015801
SmallVector<Value *, 8> Worklist;
1580115802
SmallPtrSet<Value *, 8> Visited;
@@ -15810,7 +15811,14 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1581015811
EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
1581115812
const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
1581215813
const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15813-
CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15814+
// If LHS is a constant, apply information to the other expression.
15815+
// TODO: If LHS is not a constant, check if using CompareSCEVComplexity
15816+
// can improve results.
15817+
if (isa<SCEVConstant>(LHS)) {
15818+
std::swap(LHS, RHS);
15819+
Predicate = CmpInst::getSwappedPredicate(Predicate);
15820+
}
15821+
GuardsToProcess.emplace_back(Predicate, LHS, RHS);
1581415822
continue;
1581515823
}
1581615824

@@ -15823,6 +15831,31 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1582315831
}
1582415832
}
1582515833

15834+
// Process divisibility guards in reverse order to populate DivGuards early.
15835+
DenseMap<const SCEV *, APInt> Multiples;
15836+
LoopGuards DivGuards(SE);
15837+
for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
15838+
if (!isDivisibilityGuard(LHS, RHS, SE))
15839+
continue;
15840+
collectDivisibilityInformation(Predicate, LHS, RHS, DivGuards.RewriteMap,
15841+
Multiples, SE);
15842+
}
15843+
15844+
for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
15845+
CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivGuards);
15846+
15847+
// Apply divisibility information last. This ensures it is applied to the
15848+
// outermost expression after other rewrites for the given value.
15849+
for (const auto &[K, Divisor] : Multiples) {
15850+
const SCEV *DivisorSCEV = SE.getConstant(Divisor);
15851+
Guards.RewriteMap[K] =
15852+
SE.getMulExpr(SE.getUDivExpr(applyDivisibilityOnMinMaxExpr(
15853+
Guards.rewrite(K), Divisor, SE),
15854+
DivisorSCEV),
15855+
DivisorSCEV);
15856+
ExprsToRewrite.push_back(K);
15857+
}
15858+
1582615859
// Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
1582715860
// the replacement expressions are contained in the ranges of the replaced
1582815861
// expressions.

0 commit comments

Comments
 (0)