Skip to content

Commit 2af416b

Browse files
committed
[SCEV] Move URem matching to ScalarEvolutionPatternMatch.h (llvm#163170)
Move URem matching to ScalarEvolutionPatternMatch.h so it can be re-used together with other matchers. Depends on llvm#163169 PR: llvm#163170 (cherry picked from commit 7f04ee1)
1 parent 198f96b commit 2af416b

File tree

5 files changed

+98
-96
lines changed

5 files changed

+98
-96
lines changed

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2316,10 +2316,6 @@ class ScalarEvolution {
23162316
/// an add rec on said loop.
23172317
void getUsedLoops(const SCEV *S, SmallPtrSetImpl<const Loop *> &LoopsUsed);
23182318

2319-
/// Try to match the pattern generated by getURemExpr(A, B). If successful,
2320-
/// Assign A and B to LHS and RHS, respectively.
2321-
LLVM_ABI bool matchURem(const SCEV *Expr, const SCEV *&LHS, const SCEV *&RHS);
2322-
23232319
/// Look for a SCEV expression with type `SCEVType` and operands `Ops` in
23242320
/// `UniqueSCEVs`. Return if found, else nullptr.
23252321
SCEV *findExistingSCEVInCache(SCEVTypes SCEVType, ArrayRef<const SCEV *> Ops);

llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,80 @@ m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
215215
return m_scev_Binary<SCEVUDivExpr>(Op0, Op1);
216216
}
217217

218+
/// Match unsigned remainder pattern.
219+
/// Matches patterns generated by getURemExpr.
220+
template <typename Op0_t, typename Op1_t> struct SCEVURem_match {
221+
Op0_t Op0;
222+
Op1_t Op1;
223+
ScalarEvolution &SE;
224+
225+
SCEVURem_match(Op0_t Op0, Op1_t Op1, ScalarEvolution &SE)
226+
: Op0(Op0), Op1(Op1), SE(SE) {}
227+
228+
bool match(const SCEV *Expr) const {
229+
if (Expr->getType()->isPointerTy())
230+
return false;
231+
232+
// Try to match 'zext (trunc A to iB) to iY', which is used
233+
// for URem with constant power-of-2 second operands. Make sure the size of
234+
// the operand A matches the size of the whole expressions.
235+
const SCEV *LHS;
236+
if (SCEVPatternMatch::match(Expr, m_scev_ZExt(m_scev_Trunc(m_SCEV(LHS))))) {
237+
Type *TruncTy = cast<SCEVZeroExtendExpr>(Expr)->getOperand()->getType();
238+
// Bail out if the type of the LHS is larger than the type of the
239+
// expression for now.
240+
if (SE.getTypeSizeInBits(LHS->getType()) >
241+
SE.getTypeSizeInBits(Expr->getType()))
242+
return false;
243+
if (LHS->getType() != Expr->getType())
244+
LHS = SE.getZeroExtendExpr(LHS, Expr->getType());
245+
const SCEV *RHS =
246+
SE.getConstant(APInt(SE.getTypeSizeInBits(Expr->getType()), 1)
247+
<< SE.getTypeSizeInBits(TruncTy));
248+
return Op0.match(LHS) && Op1.match(RHS);
249+
}
250+
const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
251+
if (Add == nullptr || Add->getNumOperands() != 2)
252+
return false;
253+
254+
const SCEV *A = Add->getOperand(1);
255+
const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
256+
257+
if (Mul == nullptr)
258+
return false;
259+
260+
const auto MatchURemWithDivisor = [&](const SCEV *B) {
261+
// (SomeExpr + (-(SomeExpr / B) * B)).
262+
if (Expr == SE.getURemExpr(A, B))
263+
return Op0.match(A) && Op1.match(B);
264+
return false;
265+
};
266+
267+
// (SomeExpr + (-1 * (SomeExpr / B) * B)).
268+
if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
269+
return MatchURemWithDivisor(Mul->getOperand(1)) ||
270+
MatchURemWithDivisor(Mul->getOperand(2));
271+
272+
// (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
273+
if (Mul->getNumOperands() == 2)
274+
return MatchURemWithDivisor(Mul->getOperand(1)) ||
275+
MatchURemWithDivisor(Mul->getOperand(0)) ||
276+
MatchURemWithDivisor(SE.getNegativeSCEV(Mul->getOperand(1))) ||
277+
MatchURemWithDivisor(SE.getNegativeSCEV(Mul->getOperand(0)));
278+
return false;
279+
}
280+
};
281+
282+
/// Match the mathematical pattern A - (A / B) * B, where A and B can be
283+
/// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
284+
/// for URem with constant power-of-2 second operands. It's not always easy, as
285+
/// A and B can be folded (imagine A is X / 2, and B is 4, A / B becomes X / 8).
286+
template <typename Op0_t, typename Op1_t>
287+
inline SCEVURem_match<Op0_t, Op1_t> m_scev_URem(Op0_t LHS, Op1_t RHS,
288+
ScalarEvolution &SE) {
289+
return SCEVURem_match<Op0_t, Op1_t>(LHS, RHS, SE);
290+
}
291+
218292
inline class_match<const Loop> m_Loop() { return class_match<const Loop>(); }
219293

220294
/// Match an affine SCEVAddRecExpr.

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 18 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1774,7 +1774,7 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
17741774
{
17751775
const SCEV *LHS;
17761776
const SCEV *RHS;
1777-
if (matchURem(Op, LHS, RHS))
1777+
if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
17781778
return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
17791779
getZeroExtendExpr(RHS, Ty, Depth + 1));
17801780
}
@@ -2699,17 +2699,12 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
26992699
}
27002700

27012701
// Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2702-
if (Ops.size() == 2) {
2703-
const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2704-
if (Mul && Mul->getNumOperands() == 2 &&
2705-
Mul->getOperand(0)->isAllOnesValue()) {
2706-
const SCEV *X;
2707-
const SCEV *Y;
2708-
if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2709-
return getMulExpr(Y, getUDivExpr(X, Y));
2710-
}
2711-
}
2712-
}
2702+
const SCEV *Y;
2703+
if (Ops.size() == 2 &&
2704+
match(Ops[0],
2705+
m_scev_Mul(m_scev_AllOnes(),
2706+
m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
2707+
return getMulExpr(Y, getUDivExpr(Ops[1], Y));
27132708

27142709
// Skip past any other cast SCEVs.
27152710
while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
@@ -15353,65 +15348,6 @@ void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
1535315348
}
1535415349
}
1535515350

15356-
// Match the mathematical pattern A - (A / B) * B, where A and B can be
15357-
// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
15358-
// for URem with constant power-of-2 second operands.
15359-
// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
15360-
// 4, A / B becomes X / 8).
15361-
bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
15362-
const SCEV *&RHS) {
15363-
if (Expr->getType()->isPointerTy())
15364-
return false;
15365-
15366-
// Try to match 'zext (trunc A to iB) to iY', which is used
15367-
// for URem with constant power-of-2 second operands. Make sure the size of
15368-
// the operand A matches the size of the whole expressions.
15369-
if (match(Expr, m_scev_ZExt(m_scev_Trunc(m_SCEV(LHS))))) {
15370-
Type *TruncTy = cast<SCEVZeroExtendExpr>(Expr)->getOperand()->getType();
15371-
// Bail out if the type of the LHS is larger than the type of the
15372-
// expression for now.
15373-
if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(Expr->getType()))
15374-
return false;
15375-
if (LHS->getType() != Expr->getType())
15376-
LHS = getZeroExtendExpr(LHS, Expr->getType());
15377-
RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
15378-
<< getTypeSizeInBits(TruncTy));
15379-
return true;
15380-
}
15381-
const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
15382-
if (Add == nullptr || Add->getNumOperands() != 2)
15383-
return false;
15384-
15385-
const SCEV *A = Add->getOperand(1);
15386-
const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
15387-
15388-
if (Mul == nullptr)
15389-
return false;
15390-
15391-
const auto MatchURemWithDivisor = [&](const SCEV *B) {
15392-
// (SomeExpr + (-(SomeExpr / B) * B)).
15393-
if (Expr == getURemExpr(A, B)) {
15394-
LHS = A;
15395-
RHS = B;
15396-
return true;
15397-
}
15398-
return false;
15399-
};
15400-
15401-
// (SomeExpr + (-1 * (SomeExpr / B) * B)).
15402-
if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
15403-
return MatchURemWithDivisor(Mul->getOperand(1)) ||
15404-
MatchURemWithDivisor(Mul->getOperand(2));
15405-
15406-
// (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
15407-
if (Mul->getNumOperands() == 2)
15408-
return MatchURemWithDivisor(Mul->getOperand(1)) ||
15409-
MatchURemWithDivisor(Mul->getOperand(0)) ||
15410-
MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
15411-
MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
15412-
return false;
15413-
}
15414-
1541515351
ScalarEvolution::LoopGuards
1541615352
ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
1541715353
BasicBlock *Header = L->getHeader();
@@ -15623,20 +15559,18 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1562315559
if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
1562415560
// If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
1562515561
// explicitly express that.
15626-
const SCEV *URemLHS = nullptr;
15562+
const SCEVUnknown *URemLHS = nullptr;
1562715563
const SCEV *URemRHS = nullptr;
15628-
if (SE.matchURem(LHS, URemLHS, URemRHS)) {
15629-
if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15630-
auto I = RewriteMap.find(LHSUnknown);
15631-
const SCEV *RewrittenLHS =
15632-
I != RewriteMap.end() ? I->second : LHSUnknown;
15633-
RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15634-
const auto *Multiple =
15635-
SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15636-
RewriteMap[LHSUnknown] = Multiple;
15637-
ExprsToRewrite.push_back(LHSUnknown);
15638-
return;
15639-
}
15564+
if (match(LHS,
15565+
m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE))) {
15566+
auto I = RewriteMap.find(URemLHS);
15567+
const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : URemLHS;
15568+
RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15569+
const auto *Multiple =
15570+
SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15571+
RewriteMap[URemLHS] = Multiple;
15572+
ExprsToRewrite.push_back(URemLHS);
15573+
return;
1564015574
}
1564115575
}
1564215576

llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
506506
// Recognize the canonical representation of an unsimplifed urem.
507507
const SCEV *URemLHS = nullptr;
508508
const SCEV *URemRHS = nullptr;
509-
if (SE.matchURem(S, URemLHS, URemRHS)) {
509+
if (match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), SE))) {
510510
Value *LHS = expand(URemLHS);
511511
Value *RHS = expand(URemRHS);
512512
return InsertBinop(Instruction::URem, LHS, RHS, SCEV::FlagAnyWrap,

llvm/unittests/Analysis/ScalarEvolutionTest.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "llvm/Analysis/LoopInfo.h"
1212
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
1313
#include "llvm/Analysis/ScalarEvolutionNormalization.h"
14+
#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
1415
#include "llvm/Analysis/TargetLibraryInfo.h"
1516
#include "llvm/AsmParser/Parser.h"
1617
#include "llvm/IR/Constants.h"
@@ -26,6 +27,8 @@
2627

2728
namespace llvm {
2829

30+
using namespace SCEVPatternMatch;
31+
2932
// We use this fixture to ensure that we clean up ScalarEvolution before
3033
// deleting the PassManager.
3134
class ScalarEvolutionsTest : public testing::Test {
@@ -64,11 +67,6 @@ static std::optional<APInt> computeConstantDifference(ScalarEvolution &SE,
6467
return SE.computeConstantDifference(LHS, RHS);
6568
}
6669

67-
static bool matchURem(ScalarEvolution &SE, const SCEV *Expr, const SCEV *&LHS,
68-
const SCEV *&RHS) {
69-
return SE.matchURem(Expr, LHS, RHS);
70-
}
71-
7270
static bool isImpliedCond(
7371
ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS,
7472
const SCEV *RHS, ICmpInst::Predicate FoundPred, const SCEV *FoundLHS,
@@ -1524,7 +1522,7 @@ TEST_F(ScalarEvolutionsTest, MatchURem) {
15241522
auto *URemI = getInstructionByName(F, N);
15251523
auto *S = SE.getSCEV(URemI);
15261524
const SCEV *LHS, *RHS;
1527-
EXPECT_TRUE(matchURem(SE, S, LHS, RHS));
1525+
EXPECT_TRUE(match(S, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), SE)));
15281526
EXPECT_EQ(LHS, SE.getSCEV(URemI->getOperand(0)));
15291527
EXPECT_EQ(RHS, SE.getSCEV(URemI->getOperand(1)));
15301528
EXPECT_EQ(LHS->getType(), S->getType());
@@ -1537,7 +1535,7 @@ TEST_F(ScalarEvolutionsTest, MatchURem) {
15371535
auto *URem1 = getInstructionByName(F, "rem4");
15381536
auto *S = SE.getSCEV(Ext);
15391537
const SCEV *LHS, *RHS;
1540-
EXPECT_TRUE(matchURem(SE, S, LHS, RHS));
1538+
EXPECT_TRUE(match(S, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), SE)));
15411539
EXPECT_NE(LHS, SE.getSCEV(URem1->getOperand(0)));
15421540
// RHS and URem1->getOperand(1) have different widths, so compare the
15431541
// integer values.

0 commit comments

Comments
 (0)