@@ -1774,7 +1774,7 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
17741774 {
17751775 const SCEV *LHS;
17761776 const SCEV *RHS;
1777- if (matchURem (Op, LHS, RHS))
1777+ if (match (Op, m_scev_URem(m_SCEV( LHS), m_SCEV( RHS), *this) ))
17781778 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
17791779 getZeroExtendExpr(RHS, Ty, Depth + 1));
17801780 }
@@ -2699,17 +2699,12 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
26992699 }
27002700
27012701 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2702- if (Ops.size() == 2) {
2703- const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2704- if (Mul && Mul->getNumOperands() == 2 &&
2705- Mul->getOperand(0)->isAllOnesValue()) {
2706- const SCEV *X;
2707- const SCEV *Y;
2708- if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2709- return getMulExpr(Y, getUDivExpr(X, Y));
2710- }
2711- }
2712- }
2702+ const SCEV *Y;
2703+ if (Ops.size() == 2 &&
2704+ match(Ops[0],
2705+ m_scev_Mul(m_scev_AllOnes(),
2706+ m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
2707+ return getMulExpr(Y, getUDivExpr(Ops[1], Y));
27132708
27142709 // Skip past any other cast SCEVs.
27152710 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
@@ -15353,65 +15348,6 @@ void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
1535315348 }
1535415349}
1535515350
15356- // Match the mathematical pattern A - (A / B) * B, where A and B can be
15357- // arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
15358- // for URem with constant power-of-2 second operands.
15359- // It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
15360- // 4, A / B becomes X / 8).
15361- bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
15362- const SCEV *&RHS) {
15363- if (Expr->getType()->isPointerTy())
15364- return false;
15365-
15366- // Try to match 'zext (trunc A to iB) to iY', which is used
15367- // for URem with constant power-of-2 second operands. Make sure the size of
15368- // the operand A matches the size of the whole expressions.
15369- if (match(Expr, m_scev_ZExt(m_scev_Trunc(m_SCEV(LHS))))) {
15370- Type *TruncTy = cast<SCEVZeroExtendExpr>(Expr)->getOperand()->getType();
15371- // Bail out if the type of the LHS is larger than the type of the
15372- // expression for now.
15373- if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(Expr->getType()))
15374- return false;
15375- if (LHS->getType() != Expr->getType())
15376- LHS = getZeroExtendExpr(LHS, Expr->getType());
15377- RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
15378- << getTypeSizeInBits(TruncTy));
15379- return true;
15380- }
15381- const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
15382- if (Add == nullptr || Add->getNumOperands() != 2)
15383- return false;
15384-
15385- const SCEV *A = Add->getOperand(1);
15386- const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
15387-
15388- if (Mul == nullptr)
15389- return false;
15390-
15391- const auto MatchURemWithDivisor = [&](const SCEV *B) {
15392- // (SomeExpr + (-(SomeExpr / B) * B)).
15393- if (Expr == getURemExpr(A, B)) {
15394- LHS = A;
15395- RHS = B;
15396- return true;
15397- }
15398- return false;
15399- };
15400-
15401- // (SomeExpr + (-1 * (SomeExpr / B) * B)).
15402- if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
15403- return MatchURemWithDivisor(Mul->getOperand(1)) ||
15404- MatchURemWithDivisor(Mul->getOperand(2));
15405-
15406- // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
15407- if (Mul->getNumOperands() == 2)
15408- return MatchURemWithDivisor(Mul->getOperand(1)) ||
15409- MatchURemWithDivisor(Mul->getOperand(0)) ||
15410- MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
15411- MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
15412- return false;
15413- }
15414-
1541515351ScalarEvolution::LoopGuards
1541615352ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
1541715353 BasicBlock *Header = L->getHeader();
@@ -15623,20 +15559,18 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1562315559 if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
1562415560 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
1562515561 // explicitly express that.
15626- const SCEV *URemLHS = nullptr;
15562+ const SCEVUnknown *URemLHS = nullptr;
1562715563 const SCEV *URemRHS = nullptr;
15628- if (SE.matchURem(LHS, URemLHS, URemRHS)) {
15629- if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15630- auto I = RewriteMap.find(LHSUnknown);
15631- const SCEV *RewrittenLHS =
15632- I != RewriteMap.end() ? I->second : LHSUnknown;
15633- RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15634- const auto *Multiple =
15635- SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15636- RewriteMap[LHSUnknown] = Multiple;
15637- ExprsToRewrite.push_back(LHSUnknown);
15638- return;
15639- }
15564+ if (match(LHS,
15565+ m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE))) {
15566+ auto I = RewriteMap.find(URemLHS);
15567+ const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : URemLHS;
15568+ RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15569+ const auto *Multiple =
15570+ SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15571+ RewriteMap[URemLHS] = Multiple;
15572+ ExprsToRewrite.push_back(URemLHS);
15573+ return;
1564015574 }
1564115575 }
1564215576
0 commit comments