@@ -2283,6 +2283,8 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
22832283 ICmpInst::Predicate PredL = LHS->getPredicate (), PredR = RHS->getPredicate ();
22842284 Value *LHS0 = LHS->getOperand (0 ), *RHS0 = RHS->getOperand (0 );
22852285 Value *LHS1 = LHS->getOperand (1 ), *RHS1 = RHS->getOperand (1 );
2286+ auto *LHSC = dyn_cast<ConstantInt>(LHS1);
2287+ auto *RHSC = dyn_cast<ConstantInt>(RHS1);
22862288
22872289 // Fold (icmp ult/ule (A + C1), C3) | (icmp ult/ule (A + C2), C3)
22882290 // --> (icmp ult/ule ((A & ~(C1 ^ C2)) + max(C1, C2)), C3)
@@ -2294,43 +2296,42 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
22942296 // 3) C1 ^ C2 is one-bit mask.
22952297 // 4) LowRange1 ^ LowRange2 and HighRange1 ^ HighRange2 are one-bit mask.
22962298 // This implies all values in the two ranges differ by exactly one bit.
2297- const APInt *LHSVal, *RHSVal;
22982299 if ((PredL == ICmpInst::ICMP_ULT || PredL == ICmpInst::ICMP_ULE) &&
2299- PredL == PredR && LHS->getType () == RHS->getType () &&
2300- LHS->getType ()->isIntOrIntVectorTy () && match (LHS1, m_APInt (LHSVal)) &&
2301- match (RHS1, m_APInt (RHSVal)) && *LHSVal == *RHSVal && LHS->hasOneUse () &&
2302- RHS->hasOneUse ()) {
2303- Value *AddOpnd;
2304- const APInt *LAddVal, *RAddVal;
2305- if (match (LHS0, m_Add (m_Value (AddOpnd), m_APInt (LAddVal))) &&
2306- match (RHS0, m_Add (m_Specific (AddOpnd), m_APInt (RAddVal))) &&
2307- LAddVal->ugt (*LHSVal) && RAddVal->ugt (*LHSVal)) {
2308-
2309- APInt DiffC = *LAddVal ^ *RAddVal;
2310- if (DiffC.isPowerOf2 ()) {
2311- const APInt *MaxAddC = nullptr ;
2312- if (LAddVal->ult (*RAddVal))
2313- MaxAddC = RAddVal;
2300+ PredL == PredR && LHSC && RHSC && LHS->hasOneUse () && RHS->hasOneUse () &&
2301+ LHSC->getType () == RHSC->getType () &&
2302+ LHSC->getValue () == (RHSC->getValue ())) {
2303+
2304+ Value *LAddOpnd, *RAddOpnd;
2305+ ConstantInt *LAddC, *RAddC;
2306+ if (match (LHS0, m_Add (m_Value (LAddOpnd), m_ConstantInt (LAddC))) &&
2307+ match (RHS0, m_Add (m_Value (RAddOpnd), m_ConstantInt (RAddC))) &&
2308+ LAddC->getValue ().ugt (LHSC->getValue ()) &&
2309+ RAddC->getValue ().ugt (LHSC->getValue ())) {
2310+
2311+ APInt DiffC = LAddC->getValue () ^ RAddC->getValue ();
2312+ if (LAddOpnd == RAddOpnd && DiffC.isPowerOf2 ()) {
2313+ ConstantInt *MaxAddC = nullptr ;
2314+ if (LAddC->getValue ().ult (RAddC->getValue ()))
2315+ MaxAddC = RAddC;
23142316 else
2315- MaxAddC = LAddVal ;
2317+ MaxAddC = LAddC ;
23162318
2317- APInt RRangeLow = -*RAddVal ;
2318- APInt RRangeHigh = RRangeLow + *LHSVal ;
2319- APInt LRangeLow = -*LAddVal ;
2320- APInt LRangeHigh = LRangeLow + *LHSVal ;
2319+ APInt RRangeLow = -RAddC-> getValue () ;
2320+ APInt RRangeHigh = RRangeLow + LHSC-> getValue () ;
2321+ APInt LRangeLow = -LAddC-> getValue () ;
2322+ APInt LRangeHigh = LRangeLow + LHSC-> getValue () ;
23212323 APInt LowRangeDiff = RRangeLow ^ LRangeLow;
23222324 APInt HighRangeDiff = RRangeHigh ^ LRangeHigh;
23232325 APInt RangeDiff = LRangeLow.sgt (RRangeLow) ? LRangeLow - RRangeLow
23242326 : RRangeLow - LRangeLow;
23252327
23262328 if (LowRangeDiff.isPowerOf2 () && LowRangeDiff == HighRangeDiff &&
2327- RangeDiff.ugt (*LHSVal)) {
2328- Value *NewAnd = Builder.CreateAnd (
2329- AddOpnd, ConstantInt::get (LHS0->getType (), ~DiffC));
2330- Value *NewAdd = Builder.CreateAdd (
2331- NewAnd, ConstantInt::get (LHS0->getType (), *MaxAddC));
2332- return Builder.CreateICmp (LHS->getPredicate (), NewAdd,
2333- ConstantInt::get (LHS0->getType (), *LHSVal));
2329+ RangeDiff.ugt (LHSC->getValue ())) {
2330+ Value *MaskC = ConstantInt::get (LAddC->getType (), ~DiffC);
2331+
2332+ Value *NewAnd = Builder.CreateAnd (LAddOpnd, MaskC);
2333+ Value *NewAdd = Builder.CreateAdd (NewAnd, MaxAddC);
2334+ return Builder.CreateICmp (LHS->getPredicate (), NewAdd, LHSC);
23342335 }
23352336 }
23362337 }
@@ -2416,8 +2417,6 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
24162417 }
24172418
24182419 // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2).
2419- auto *LHSC = dyn_cast<ConstantInt>(LHS1);
2420- auto *RHSC = dyn_cast<ConstantInt>(RHS1);
24212420 if (!LHSC || !RHSC)
24222421 return nullptr ;
24232422
0 commit comments