@@ -50823,10 +50823,83 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
5082350823 return SDValue();
5082450824}
5082550825
50826+ static SDValue combineConstantPoolLoads(SDNode *N, const SDLoc &dl,
50827+ SelectionDAG &DAG,
50828+ TargetLowering::DAGCombinerInfo &DCI,
50829+ const X86Subtarget &Subtarget) {
50830+ auto *Ld = cast<LoadSDNode>(N);
50831+ EVT RegVT = Ld->getValueType(0);
50832+ EVT MemVT = Ld->getMemoryVT();
50833+ SDValue Ptr = Ld->getBasePtr();
50834+ SDValue Chain = Ld->getChain();
50835+ ISD::LoadExtType Ext = Ld->getExtensionType();
50836+
50837+ if (Ext != ISD::NON_EXTLOAD || !Subtarget.hasAVX() || !Ld->isSimple())
50838+ return SDValue();
50839+
50840+ if (!(RegVT.is128BitVector() || RegVT.is256BitVector()))
50841+ return SDValue();
50842+
50843+ auto MatchingBits = [](const APInt &Undefs, const APInt &UserUndefs,
50844+ ArrayRef<APInt> Bits, ArrayRef<APInt> UserBits) {
50845+ for (unsigned I = 0, E = Undefs.getBitWidth(); I != E; ++I) {
50846+ if (Undefs[I])
50847+ continue;
50848+ if (UserUndefs[I] || Bits[I] != UserBits[I])
50849+ return false;
50850+ }
50851+ return true;
50852+ };
50853+
50854+ // Look through all other loads/broadcasts in the chain for another constant
50855+ // pool entry.
50856+ for (SDNode *User : Chain->uses()) {
50857+ auto *UserLd = dyn_cast<MemSDNode>(User);
50858+ if (User != N && UserLd &&
50859+ (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD ||
50860+ User->getOpcode() == X86ISD::VBROADCAST_LOAD ||
50861+ ISD::isNormalLoad(User)) &&
50862+ UserLd->getChain() == Chain && !User->hasAnyUseOfValue(1) &&
50863+ User->getValueSizeInBits(0).getFixedValue() >
50864+ RegVT.getFixedSizeInBits()) {
50865+ EVT UserVT = User->getValueType(0);
50866+ SDValue UserPtr = UserLd->getBasePtr();
50867+ const Constant *LdC = getTargetConstantFromBasePtr(Ptr);
50868+ const Constant *UserC = getTargetConstantFromBasePtr(UserPtr);
50869+
50870+ // See if we are loading a constant that matches in the lower
50871+ // bits of a longer constant (but from a different constant pool ptr).
50872+ if (LdC && UserC && UserPtr != Ptr) {
50873+ unsigned LdSize = LdC->getType()->getPrimitiveSizeInBits();
50874+ unsigned UserSize = UserC->getType()->getPrimitiveSizeInBits();
50875+ if (LdSize < UserSize || !ISD::isNormalLoad(User)) {
50876+ APInt Undefs, UserUndefs;
50877+ SmallVector<APInt> Bits, UserBits;
50878+ unsigned NumBits = std::min(RegVT.getScalarSizeInBits(),
50879+ UserVT.getScalarSizeInBits());
50880+ if (getTargetConstantBitsFromNode(SDValue(N, 0), NumBits, Undefs,
50881+ Bits) &&
50882+ getTargetConstantBitsFromNode(SDValue(User, 0), NumBits,
50883+ UserUndefs, UserBits)) {
50884+ if (MatchingBits(Undefs, UserUndefs, Bits, UserBits)) {
50885+ SDValue Extract = extractSubVector(
50886+ SDValue(User, 0), 0, DAG, SDLoc(N), RegVT.getSizeInBits());
50887+ Extract = DAG.getBitcast(RegVT, Extract);
50888+ return DCI.CombineTo(N, Extract, SDValue(User, 1));
50889+ }
50890+ }
50891+ }
50892+ }
50893+ }
50894+ }
50895+
50896+ return SDValue();
50897+ }
50898+
5082650899static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
5082750900 TargetLowering::DAGCombinerInfo &DCI,
5082850901 const X86Subtarget &Subtarget) {
50829- LoadSDNode *Ld = cast<LoadSDNode>(N);
50902+ auto *Ld = cast<LoadSDNode>(N);
5083050903 EVT RegVT = Ld->getValueType(0);
5083150904 EVT MemVT = Ld->getMemoryVT();
5083250905 SDLoc dl(Ld);
@@ -50885,7 +50958,7 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
5088550958 }
5088650959 }
5088750960
50888- // If we also load/ broadcast this to a wider type, then just extract the
50961+ // If we also broadcast this vector to a wider type, then just extract the
5088950962 // lowest subvector.
5089050963 if (Ext == ISD::NON_EXTLOAD && Subtarget.hasAVX() && Ld->isSimple() &&
5089150964 (RegVT.is128BitVector() || RegVT.is256BitVector())) {
@@ -50894,61 +50967,23 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
5089450967 for (SDNode *User : Chain->uses()) {
5089550968 auto *UserLd = dyn_cast<MemSDNode>(User);
5089650969 if (User != N && UserLd &&
50897- ( User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD ||
50898- User->getOpcode () == X86ISD::VBROADCAST_LOAD ||
50899- ISD::isNormalLoad(User) ) &&
50900- UserLd->getChain() == Chain && !User->hasAnyUseOfValue(1) &&
50970+ User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
50971+ UserLd->getChain() == Chain && UserLd->getBasePtr () == Ptr &&
50972+ UserLd->getMemoryVT().getSizeInBits() == MemVT.getSizeInBits( ) &&
50973+ !User->hasAnyUseOfValue(1) &&
5090150974 User->getValueSizeInBits(0).getFixedValue() >
5090250975 RegVT.getFixedSizeInBits()) {
50903- if (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
50904- UserLd->getBasePtr() == Ptr &&
50905- UserLd->getMemoryVT().getSizeInBits() == MemVT.getSizeInBits()) {
50906- SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, SDLoc(N),
50907- RegVT.getSizeInBits());
50908- Extract = DAG.getBitcast(RegVT, Extract);
50909- return DCI.CombineTo(N, Extract, SDValue(User, 1));
50910- }
50911- auto MatchingBits = [](const APInt &Undefs, const APInt &UserUndefs,
50912- ArrayRef<APInt> Bits, ArrayRef<APInt> UserBits) {
50913- for (unsigned I = 0, E = Undefs.getBitWidth(); I != E; ++I) {
50914- if (Undefs[I])
50915- continue;
50916- if (UserUndefs[I] || Bits[I] != UserBits[I])
50917- return false;
50918- }
50919- return true;
50920- };
50921- // See if we are loading a constant that matches in the lower
50922- // bits of a longer constant (but from a different constant pool ptr).
50923- EVT UserVT = User->getValueType(0);
50924- SDValue UserPtr = UserLd->getBasePtr();
50925- const Constant *LdC = getTargetConstantFromBasePtr(Ptr);
50926- const Constant *UserC = getTargetConstantFromBasePtr(UserPtr);
50927- if (LdC && UserC && UserPtr != Ptr) {
50928- unsigned LdSize = LdC->getType()->getPrimitiveSizeInBits();
50929- unsigned UserSize = UserC->getType()->getPrimitiveSizeInBits();
50930- if (LdSize < UserSize || !ISD::isNormalLoad(User)) {
50931- APInt Undefs, UserUndefs;
50932- SmallVector<APInt> Bits, UserBits;
50933- unsigned NumBits = std::min(RegVT.getScalarSizeInBits(),
50934- UserVT.getScalarSizeInBits());
50935- if (getTargetConstantBitsFromNode(SDValue(N, 0), NumBits, Undefs,
50936- Bits) &&
50937- getTargetConstantBitsFromNode(SDValue(User, 0), NumBits,
50938- UserUndefs, UserBits)) {
50939- if (MatchingBits(Undefs, UserUndefs, Bits, UserBits)) {
50940- SDValue Extract = extractSubVector(
50941- SDValue(User, 0), 0, DAG, SDLoc(N), RegVT.getSizeInBits());
50942- Extract = DAG.getBitcast(RegVT, Extract);
50943- return DCI.CombineTo(N, Extract, SDValue(User, 1));
50944- }
50945- }
50946- }
50947- }
50976+ SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, dl,
50977+ RegVT.getSizeInBits());
50978+ Extract = DAG.getBitcast(RegVT, Extract);
50979+ return DCI.CombineTo(N, Extract, SDValue(User, 1));
5094850980 }
5094950981 }
5095050982 }
5095150983
50984+ if (SDValue V = combineConstantPoolLoads(Ld, dl, DAG, DCI, Subtarget))
50985+ return V;
50986+
5095250987 // Cast ptr32 and ptr64 pointers to the default address space before a load.
5095350988 unsigned AddrSpace = Ld->getAddressSpace();
5095450989 if (AddrSpace == X86AS::PTR64 || AddrSpace == X86AS::PTR32_SPTR ||
0 commit comments