@@ -15426,6 +15426,23 @@ void ScalarEvolution::LoopGuards::collectFromPHI(
1542615426 }
1542715427}
1542815428
15429+ // Return a new SCEV that modifies \p Expr to the closest number divides by
15430+ // \p Divisor and greater or equal than Expr. For now, only handle constant
15431+ // Expr.
15432+ static const SCEV *getNextSCEVDividesByDivisor(const SCEV *Expr,
15433+ const APInt &DivisorVal,
15434+ ScalarEvolution &SE) {
15435+ const APInt *ExprVal;
15436+ if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15437+ DivisorVal.isNonPositive())
15438+ return Expr;
15439+ APInt Rem = ExprVal->urem(DivisorVal);
15440+ if (Rem.isZero())
15441+ return Expr;
15442+ // return the SCEV: Expr + Divisor - Expr % Divisor
15443+ return SE.getConstant(*ExprVal + DivisorVal - Rem);
15444+ }
15445+
1542915446void ScalarEvolution::LoopGuards::collectFromBlock(
1543015447 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
1543115448 const BasicBlock *Block, const BasicBlock *Pred,
@@ -15499,22 +15516,6 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1549915516 return false;
1550015517 };
1550115518
15502- // Return a new SCEV that modifies \p Expr to the closest number divides by
15503- // \p Divisor and greater or equal than Expr. For now, only handle constant
15504- // Expr.
15505- auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15506- const APInt &DivisorVal) {
15507- const APInt *ExprVal;
15508- if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15509- DivisorVal.isNonPositive())
15510- return Expr;
15511- APInt Rem = ExprVal->urem(DivisorVal);
15512- if (Rem.isZero())
15513- return Expr;
15514- // return the SCEV: Expr + Divisor - Expr % Divisor
15515- return SE.getConstant(*ExprVal + DivisorVal - Rem);
15516- };
15517-
1551815519 // Return a new SCEV that modifies \p Expr to the closest number divides by
1551915520 // \p Divisor and less or equal than Expr. For now, only handle constant
1552015521 // Expr.
@@ -15551,7 +15552,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1555115552 "Expected non-negative operand!");
1555215553 auto *DivisibleExpr =
1555315554 IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, DivisorVal)
15554- : GetNextSCEVDividesByDivisor (MinMaxLHS, DivisorVal);
15555+ : getNextSCEVDividesByDivisor (MinMaxLHS, DivisorVal, SE );
1555515556 SmallVector<const SCEV *> Ops = {
1555615557 ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
1555715558 return SE.getMinMaxExpr(SCTy, Ops);
@@ -15634,15 +15635,15 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1563415635 case CmpInst::ICMP_UGT:
1563515636 case CmpInst::ICMP_SGT:
1563615637 RHS = SE.getAddExpr(RHS, One);
15637- RHS = GetNextSCEVDividesByDivisor (RHS, DividesBy);
15638+ RHS = getNextSCEVDividesByDivisor (RHS, DividesBy, SE );
1563815639 break;
1563915640 case CmpInst::ICMP_ULE:
1564015641 case CmpInst::ICMP_SLE:
1564115642 RHS = GetPreviousSCEVDividesByDivisor(RHS, DividesBy);
1564215643 break;
1564315644 case CmpInst::ICMP_UGE:
1564415645 case CmpInst::ICMP_SGE:
15645- RHS = GetNextSCEVDividesByDivisor (RHS, DividesBy);
15646+ RHS = getNextSCEVDividesByDivisor (RHS, DividesBy, SE );
1564615647 break;
1564715648 default:
1564815649 break;
@@ -15696,7 +15697,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1569615697 case CmpInst::ICMP_NE:
1569715698 if (match(RHS, m_scev_Zero())) {
1569815699 const SCEV *OneAlignedUp =
15699- GetNextSCEVDividesByDivisor (One, DividesBy);
15700+ getNextSCEVDividesByDivisor (One, DividesBy, SE );
1570015701 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
1570115702 } else {
1570215703 // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
@@ -15922,8 +15923,11 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
1592215923 if (MatchBinarySub(S, LHS, RHS)) {
1592315924 if (LHS > RHS)
1592415925 std::swap(LHS, RHS);
15925- if (NotEqual.contains({LHS, RHS}))
15926- return SE.getUMaxExpr(S, SE.getOne(S->getType()));
15926+ if (NotEqual.contains({LHS, RHS})) {
15927+ const SCEV *OneAlignedUp = getNextSCEVDividesByDivisor(
15928+ SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
15929+ return SE.getUMaxExpr(OneAlignedUp, S);
15930+ }
1592715931 }
1592815932 return nullptr;
1592915933 };
0 commit comments