@@ -5412,20 +5412,15 @@ static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
54125412 if (SourceBits != NewBits)
54135413 return nullptr;
54145414
5415- const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
5416- const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
5417- if (!SExt && !ZExt)
5418- return nullptr;
5419- const SCEVTruncateExpr *Trunc =
5420- SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
5421- : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
5422- if (!Trunc)
5423- return nullptr;
5424- const SCEV *X = Trunc->getOperand();
5425- if (X != SymbolicPHI)
5426- return nullptr;
5427- Signed = SExt != nullptr;
5428- return Trunc->getType();
5415+ if (match(Op, m_scev_SExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5416+ Signed = true;
5417+ return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5418+ }
5419+ if (match(Op, m_scev_ZExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5420+ Signed = false;
5421+ return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5422+ }
5423+ return nullptr;
54295424}
54305425
54315426static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
@@ -15371,20 +15366,18 @@ bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
1537115366 // Try to match 'zext (trunc A to iB) to iY', which is used
1537215367 // for URem with constant power-of-2 second operands. Make sure the size of
1537315368 // the operand A matches the size of the whole expressions.
15374- if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
15375- if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
15376- LHS = Trunc->getOperand();
15377- // Bail out if the type of the LHS is larger than the type of the
15378- // expression for now.
15379- if (getTypeSizeInBits(LHS->getType()) >
15380- getTypeSizeInBits(Expr->getType()))
15381- return false;
15382- if (LHS->getType() != Expr->getType())
15383- LHS = getZeroExtendExpr(LHS, Expr->getType());
15384- RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
15385- << getTypeSizeInBits(Trunc->getType()));
15386- return true;
15387- }
15369+ if (match(Expr, m_scev_ZExt(m_scev_Trunc(m_SCEV(LHS))))) {
15370+ Type *TruncTy = cast<SCEVZeroExtendExpr>(Expr)->getOperand()->getType();
15371+ // Bail out if the type of the LHS is larger than the type of the
15372+ // expression for now.
15373+ if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(Expr->getType()))
15374+ return false;
15375+ if (LHS->getType() != Expr->getType())
15376+ LHS = getZeroExtendExpr(LHS, Expr->getType());
15377+ RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
15378+ << getTypeSizeInBits(TruncTy));
15379+ return true;
15380+ }
1538815381 const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
1538915382 if (Add == nullptr || Add->getNumOperands() != 2)
1539015383 return false;
0 commit comments