@@ -15458,26 +15458,92 @@ static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
1545815458 return SE.getConstant(*ExprVal + DivisorVal - Rem);
1545915459}
1546015460
15461+ static bool collectDivisibilityInformation(
15462+ ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15463+ DenseMap<const SCEV *, const SCEV *> &DivInfo,
15464+ DenseMap<const SCEV *, APInt> &Multiples, ScalarEvolution &SE) {
15465+ // If we have LHS == 0, check if LHS is computing a property of some unknown
15466+ // SCEV %v which we can rewrite %v to express explicitly.
15467+ if (Predicate != CmpInst::ICMP_EQ || !match(RHS, m_scev_Zero()))
15468+ return false;
15469+ // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15470+ // explicitly express that.
15471+ const SCEVUnknown *URemLHS = nullptr;
15472+ const SCEV *URemRHS = nullptr;
15473+ if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15474+ return false;
15475+
15476+ const SCEV *Multiple =
15477+ SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15478+ DivInfo[URemLHS] = Multiple;
15479+ if (auto *C = dyn_cast<SCEVConstant>(URemRHS))
15480+ Multiples[URemLHS] = C->getAPInt();
15481+ return true;
15482+ }
15483+
15484+ // Check if the condition is a divisibility guard (A % B == 0).
15485+ static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15486+ ScalarEvolution &SE) {
15487+ const SCEV *X, *Y;
15488+ return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15489+ }
15490+
15491+ // Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15492+ // recursively. This is done by aligning up/down the constant value to the
15493+ // Divisor.
15494+ static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15495+ APInt Divisor,
15496+ ScalarEvolution &SE) {
15497+ // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15498+ // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15499+ // the non-constant operand and in \p LHS the constant operand.
15500+ auto IsMinMaxSCEVWithNonNegativeConstant =
15501+ [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15502+ const SCEV *&RHS) {
15503+ if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15504+ if (MinMax->getNumOperands() != 2)
15505+ return false;
15506+ if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15507+ if (C->getAPInt().isNegative())
15508+ return false;
15509+ SCTy = MinMax->getSCEVType();
15510+ LHS = MinMax->getOperand(0);
15511+ RHS = MinMax->getOperand(1);
15512+ return true;
15513+ }
15514+ }
15515+ return false;
15516+ };
15517+
15518+ const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15519+ SCEVTypes SCTy;
15520+ if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15521+ MinMaxRHS))
15522+ return MinMaxExpr;
15523+ auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15524+ assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15525+ auto *DivisibleExpr =
15526+ IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15527+ : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15528+ SmallVector<const SCEV *> Ops = {
15529+ applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15530+ return SE.getMinMaxExpr(SCTy, Ops);
15531+ }
15532+
1546115533void ScalarEvolution::LoopGuards::collectFromBlock(
1546215534 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
1546315535 const BasicBlock *Block, const BasicBlock *Pred,
1546415536 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
1546515537 SmallVector<const SCEV *> ExprsToRewrite;
1546615538 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
1546715539 const SCEV *RHS,
15468- DenseMap<const SCEV *, const SCEV *>
15469- &RewriteMap ) {
15540+ DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15541+ const LoopGuards &DivGuards ) {
1547015542 // WARNING: It is generally unsound to apply any wrap flags to the proposed
1547115543 // replacement SCEV which isn't directly implied by the structure of that
1547215544 // SCEV. In particular, using contextual facts to imply flags is *NOT*
1547315545 // legal. See the scoping rules for flags in the header to understand why.
1547415546
15475- // If LHS is a constant, apply information to the other expression.
15476- if (isa<SCEVConstant>(LHS)) {
15477- std::swap(LHS, RHS);
15478- Predicate = CmpInst::getSwappedPredicate(Predicate);
15479- }
15480-
1548115547 // Check for a condition of the form (-C1 + X < C2). InstCombine will
1548215548 // create this form when combining two checks of the form (X u< C2 + C1) and
1548315549 // (X >=u C1).
@@ -15510,76 +15576,6 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1551015576 if (MatchRangeCheckIdiom())
1551115577 return;
1551215578
15513- // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15514- // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15515- // the non-constant operand and in \p LHS the constant operand.
15516- auto IsMinMaxSCEVWithNonNegativeConstant =
15517- [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15518- const SCEV *&RHS) {
15519- if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15520- if (MinMax->getNumOperands() != 2)
15521- return false;
15522- if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15523- if (C->getAPInt().isNegative())
15524- return false;
15525- SCTy = MinMax->getSCEVType();
15526- LHS = MinMax->getOperand(0);
15527- RHS = MinMax->getOperand(1);
15528- return true;
15529- }
15530- }
15531- return false;
15532- };
15533-
15534- // Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
15535- // recursively. This is done by aligning up/down the constant value to the
15536- // Divisor.
15537- std::function<const SCEV *(const SCEV *, const SCEV *)>
15538- ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15539- const SCEV *Divisor) {
15540- auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15541- if (!ConstDivisor)
15542- return MinMaxExpr;
15543- const APInt &DivisorVal = ConstDivisor->getAPInt();
15544-
15545- const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15546- SCEVTypes SCTy;
15547- if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15548- MinMaxRHS))
15549- return MinMaxExpr;
15550- auto IsMin =
15551- isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15552- assert(SE.isKnownNonNegative(MinMaxLHS) &&
15553- "Expected non-negative operand!");
15554- auto *DivisibleExpr =
15555- IsMin
15556- ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE)
15557- : getNextSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE);
15558- SmallVector<const SCEV *> Ops = {
15559- ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15560- return SE.getMinMaxExpr(SCTy, Ops);
15561- };
15562-
15563- // If we have LHS == 0, check if LHS is computing a property of some unknown
15564- // SCEV %v which we can rewrite %v to express explicitly.
15565- if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
15566- // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15567- // explicitly express that.
15568- const SCEVUnknown *URemLHS = nullptr;
15569- const SCEV *URemRHS = nullptr;
15570- if (match(LHS,
15571- m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE))) {
15572- auto I = RewriteMap.find(URemLHS);
15573- const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : URemLHS;
15574- RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15575- const auto *Multiple =
15576- SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15577- RewriteMap[URemLHS] = Multiple;
15578- ExprsToRewrite.push_back(URemLHS);
15579- return;
15580- }
15581- }
15582-
1558315579 // Do not apply information for constants or if RHS contains an AddRec.
1558415580 if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
1558515581 return;
@@ -15609,7 +15605,9 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1560915605 };
1561015606
1561115607 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15612- const APInt &DividesBy = SE.getConstantMultiple(RewrittenLHS);
15608+ // Apply divisibility information when computing the constant multiple.
15609+ const APInt &DividesBy =
15610+ SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
1561315611
1561415612 // Collect rewrites for LHS and its transitive operands based on the
1561515613 // condition.
@@ -15794,8 +15792,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1579415792
1579515793 // Now apply the information from the collected conditions to
1579615794 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15797- // earliest conditions is processed first. This ensures the SCEVs with the
15795+ // earliest conditions is processed first, except guards with divisibility
15796+ // information, which are moved to the back. This ensures the SCEVs with the
1579815797 // shortest dependency chains are constructed first.
15798+ SmallVector<std::tuple<CmpInst::Predicate, const SCEV *, const SCEV *>>
15799+ GuardsToProcess;
1579915800 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
1580015801 SmallVector<Value *, 8> Worklist;
1580115802 SmallPtrSet<Value *, 8> Visited;
@@ -15810,7 +15811,14 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1581015811 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
1581115812 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
1581215813 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15813- CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15814+ // If LHS is a constant, apply information to the other expression.
15815+ // TODO: If LHS is not a constant, check if using CompareSCEVComplexity
15816+ // can improve results.
15817+ if (isa<SCEVConstant>(LHS)) {
15818+ std::swap(LHS, RHS);
15819+ Predicate = CmpInst::getSwappedPredicate(Predicate);
15820+ }
15821+ GuardsToProcess.emplace_back(Predicate, LHS, RHS);
1581415822 continue;
1581515823 }
1581615824
@@ -15823,6 +15831,31 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1582315831 }
1582415832 }
1582515833
15834+ // Process divisibility guards in reverse order to populate DivGuards early.
15835+ DenseMap<const SCEV *, APInt> Multiples;
15836+ LoopGuards DivGuards(SE);
15837+ for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
15838+ if (!isDivisibilityGuard(LHS, RHS, SE))
15839+ continue;
15840+ collectDivisibilityInformation(Predicate, LHS, RHS, DivGuards.RewriteMap,
15841+ Multiples, SE);
15842+ }
15843+
15844+ for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
15845+ CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivGuards);
15846+
15847+ // Apply divisibility information last. This ensures it is applied to the
15848+ // outermost expression after other rewrites for the given value.
15849+ for (const auto &[K, Divisor] : Multiples) {
15850+ const SCEV *DivisorSCEV = SE.getConstant(Divisor);
15851+ Guards.RewriteMap[K] =
15852+ SE.getMulExpr(SE.getUDivExpr(applyDivisibilityOnMinMaxExpr(
15853+ Guards.rewrite(K), Divisor, SE),
15854+ DivisorSCEV),
15855+ DivisorSCEV);
15856+ ExprsToRewrite.push_back(K);
15857+ }
15858+
1582615859 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
1582715860 // the replacement expressions are contained in the ranges of the replaced
1582815861 // expressions.
0 commit comments