@@ -15030,10 +15030,18 @@ bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
1503015030class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
1503115031 const DenseMap<const SCEV *, const SCEV *> ⤅
1503215032
15033+ SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
15034+
1503315035public:
1503415036 SCEVLoopGuardRewriter(ScalarEvolution &SE,
15035- DenseMap<const SCEV *, const SCEV *> &M)
15036- : SCEVRewriteVisitor(SE), Map(M) {}
15037+ DenseMap<const SCEV *, const SCEV *> &M,
15038+ bool PreserveNUW, bool PreserveNSW)
15039+ : SCEVRewriteVisitor(SE), Map(M) {
15040+ if (PreserveNUW)
15041+ FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
15042+ if (PreserveNSW)
15043+ FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
15044+ }
1503715045
1503815046 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
1503915047
@@ -15089,6 +15097,36 @@ class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
1508915097 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
1509015098 return I->second;
1509115099 }
15100+
15101+ const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
15102+ SmallVector<const SCEV *, 2> Operands;
15103+ bool Changed = false;
15104+ for (const auto *Op : Expr->operands()) {
15105+ Operands.push_back(SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
15106+ Changed |= Op != Operands.back();
15107+ }
15108+ // We are only replacing operands with equivalent values, so transfer the
15109+ // flags from the original expression.
15110+ return !Changed
15111+ ? Expr
15112+ : SE.getAddExpr(Operands, ScalarEvolution::maskFlags(
15113+ Expr->getNoWrapFlags(), FlagMask));
15114+ }
15115+
15116+ const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
15117+ SmallVector<const SCEV *, 2> Operands;
15118+ bool Changed = false;
15119+ for (const auto *Op : Expr->operands()) {
15120+ Operands.push_back(SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
15121+ Changed |= Op != Operands.back();
15122+ }
15123+ // We are only replacing operands with equivalent values, so transfer the
15124+ // flags from the original expression.
15125+ return !Changed
15126+ ? Expr
15127+ : SE.getMulExpr(Operands, ScalarEvolution::maskFlags(
15128+ Expr->getNoWrapFlags(), FlagMask));
15129+ }
1509215130};
1509315131
1509415132const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
@@ -15503,18 +15541,29 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
1550315541 if (RewriteMap.empty())
1550415542 return Expr;
1550515543
15544+ // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15545+ // the replacement expressions are contained in the ranges of the replaced
15546+ // expressions.
15547+ bool PreserveNUW = true;
15548+ bool PreserveNSW = true;
15549+ for (const SCEV *Expr : ExprsToRewrite) {
15550+ const SCEV *RewriteTo = RewriteMap[Expr];
15551+ PreserveNUW &= getUnsignedRange(Expr).contains(getUnsignedRange(RewriteTo));
15552+ PreserveNSW &= getSignedRange(Expr).contains(getSignedRange(RewriteTo));
15553+ }
15554+
1550615555 // Now that all rewrite information is collect, rewrite the collected
1550715556 // expressions with the information in the map. This applies information to
1550815557 // sub-expressions.
1550915558 if (ExprsToRewrite.size() > 1) {
1551015559 for (const SCEV *Expr : ExprsToRewrite) {
1551115560 const SCEV *RewriteTo = RewriteMap[Expr];
1551215561 RewriteMap.erase(Expr);
15513- SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
15562+ SCEVLoopGuardRewriter Rewriter(*this, RewriteMap, PreserveNUW,
15563+ PreserveNSW);
1551415564 RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
1551515565 }
1551615566 }
15517-
15518- SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
15567+ SCEVLoopGuardRewriter Rewriter(*this, RewriteMap, PreserveNUW, PreserveNSW);
1551915568 return Rewriter.visit(Expr);
1552015569}
0 commit comments