diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index f144f17d5a8f2..867f30985ad4e 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -900,6 +900,41 @@ namespace { ISD::NodeType ExtType); }; +/// Generic remainder optimization : Folds a remainder operation (A % B) by reusing the computed quotient (A / B). +static SDValue PerformREMCombineGeneric(SDNode *N, DAGCombiner &DC, + CodeGenOptLevel OptLevel) { + assert(N->getOpcode() == ISD::SREM || N->getOpcode() == ISD::UREM); + + // Don't do anything at less than -O2. + if (OptLevel < CodeGenOptLevel::Default) + return SDValue(); + + SelectionDAG &DAG = DC.getDAG(); + SDLoc DL(N); + EVT VT = N->getValueType(0); + bool IsSigned = N->getOpcode() == ISD::SREM; + unsigned DivOpc = IsSigned ? ISD::SDIV : ISD::UDIV; + + const SDValue &Num = N->getOperand(0); + const SDValue &Den = N->getOperand(1); + + AttributeList Attr = DC.getDAG().getMachineFunction().getFunction().getAttributes(); + if (DC.getDAG().getTargetLoweringInfo().isIntDivCheap(N->getValueType(0), Attr)) + return SDValue(); + + for (const SDNode *U : Num->users()) { + if (U->getOpcode() == DivOpc && U->getOperand(0) == Num && + U->getOperand(1) == Den) { + // Num % Den -> Num - (Num / Den) * Den + return DAG.getNode(ISD::SUB, DL, VT, Num, + DAG.getNode(ISD::MUL, DL, VT, + DAG.getNode(DivOpc, DL, VT, Num, Den), + Den)); + } + } + return SDValue(); +} + /// This class is a DAGUpdateListener that removes any deleted /// nodes from the worklist. class WorklistRemover : public SelectionDAG::DAGUpdateListener { @@ -5400,6 +5435,9 @@ SDValue DAGCombiner::visitREM(SDNode *N) { if (SDValue NewSel = foldBinOpIntoSelect(N)) return NewSel; + if (SDValue V = PerformREMCombineGeneric(N, *this, OptLevel)) + return V; + if (isSigned) { // If we know the sign bits of both operands are zero, strength reduce to a // urem instead. Handles (X & 0x0FFFFFFF) %s 16 -> X&15 diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index a3deb36074e68..a3cbb09297f24 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -5726,37 +5726,6 @@ static SDValue PerformFMinMaxCombine(SDNode *N, return SDValue(); } -static SDValue PerformREMCombine(SDNode *N, - TargetLowering::DAGCombinerInfo &DCI, - CodeGenOptLevel OptLevel) { - assert(N->getOpcode() == ISD::SREM || N->getOpcode() == ISD::UREM); - - // Don't do anything at less than -O2. - if (OptLevel < CodeGenOptLevel::Default) - return SDValue(); - - SelectionDAG &DAG = DCI.DAG; - SDLoc DL(N); - EVT VT = N->getValueType(0); - bool IsSigned = N->getOpcode() == ISD::SREM; - unsigned DivOpc = IsSigned ? ISD::SDIV : ISD::UDIV; - - const SDValue &Num = N->getOperand(0); - const SDValue &Den = N->getOperand(1); - - for (const SDNode *U : Num->users()) { - if (U->getOpcode() == DivOpc && U->getOperand(0) == Num && - U->getOperand(1) == Den) { - // Num % Den -> Num - (Num / Den) * Den - return DAG.getNode(ISD::SUB, DL, VT, Num, - DAG.getNode(ISD::MUL, DL, VT, - DAG.getNode(DivOpc, DL, VT, Num, Den), - Den)); - } - } - return SDValue(); -} - // (sign_extend|zero_extend (mul|shl) x, y) -> (mul.wide x, y) static SDValue combineMulWide(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, CodeGenOptLevel OptLevel) { @@ -6428,9 +6397,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, return PerformSETCCCombine(N, DCI, STI.getSmVersion()); case ISD::SHL: return PerformSHLCombine(N, DCI, OptLevel); - case ISD::SREM: - case ISD::UREM: - return PerformREMCombine(N, DCI, OptLevel); case ISD::STORE: case NVPTXISD::StoreV2: case NVPTXISD::StoreV4: