@@ -125,8 +125,6 @@ def doF32FTZ : Predicate<"useF32FTZ()">;
// Selection predicates (unchanged context lines): doNoF32FTZ holds the
// condition string "!useF32FTZ()", doRsqrtOpt holds "doRsqrtOpt()".
125125def doNoF32FTZ : Predicate<"!useF32FTZ()">;
126126def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
127127
128- def doMulWide : Predicate<"doMulWide">;
129-
// Subtarget-query predicates (unchanged context lines). noHWROT32 is the
// negation of hasHWROT32; hasDotInstructions tests
// Subtarget->hasDotInstructions().
130128def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
131129def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
132130def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
@@ -836,36 +834,28 @@ def MULWIDES64 :
836834 BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.s32">;
// "mul.wide.s32" with an immediate second operand: B32 register $a and
// i32imm $b as inputs, B64 register $dst as output.
837835def MULWIDES64Imm :
838836 BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.s32">;
839- def MULWIDES64Imm64 :
840- BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i64imm:$b), "mul.wide.s32">;
841837
// "mul.wide.u32" in register/register and register/immediate forms:
// B32 inputs (second operand is i32imm in the Imm variant), B64 output.
842838def MULWIDEU64 :
843839 BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.u32">;
844840def MULWIDEU64Imm :
845841 BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.u32">;
846- def MULWIDEU64Imm64 :
847- BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i64imm:$b), "mul.wide.u32">;
848842
// "mul.wide.s16" in register/register and register/immediate forms:
// B16 inputs (second operand is i16imm in the Imm variant), B32 output.
849843def MULWIDES32 :
850844 BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.s16">;
851845def MULWIDES32Imm :
852846 BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.s16">;
853- def MULWIDES32Imm32 :
854- BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i32imm:$b), "mul.wide.s16">;
855847
// "mul.wide.u16" in register/register and register/immediate forms:
// B16 inputs (second operand is i16imm in the Imm variant), B32 output.
856848def MULWIDEU32 :
857849 BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.u16">;
858850def MULWIDEU32Imm :
859851 BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.u16">;
860- def MULWIDEU32Imm32 :
861- BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i32imm:$b), "mul.wide.u16">;
862852
863- def SDTMulWide : SDTypeProfile<1, 2, [SDTCisSameAs<1, 2>]>;
864- def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide>;
865- def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide>;
// Type profile shared by both mul.wide SelectionDAG nodes: one result and two
// operands. The result (index 0) and the first operand (index 1) must be
// integer types, and the two operands (indices 1 and 2) must have the same
// type. The profile itself does not constrain the result width relative to
// the operand width.
853+ def SDTMulWide : SDTypeProfile<1, 2, [SDTCisInt<0>, SDTCisInt<1>, SDTCisSameAs<1, 2>]>;
// Signed and unsigned mul.wide nodes, both flagged SDNPCommutative.
854+ def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide, [SDNPCommutative] >;
855+ def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide, [SDNPCommutative] >;
866856
867857// Matchers for signed, unsigned mul.wide ISD nodes.
868- let Predicates = [doMulWide ] in {
858+ let Predicates = [hasOptEnabled ] in {
869859 def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)), (MULWIDES32 $a, $b)>;
870860 def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)), (MULWIDES32Imm $a, imm:$b)>;
871861 def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)), (MULWIDEU32 $a, $b)>;
@@ -877,85 +867,6 @@ let Predicates = [doMulWide] in {
877867 def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)), (MULWIDEU64Imm $a, imm:$b)>;
878868}
879869
880- // Predicates used for converting some patterns to mul.wide.
881- def SInt32Const : PatLeaf<(imm), [{
882- const APInt &v = N->getAPIntValue();
883- return v.isSignedIntN(32);
884- }]>;
885-
886- def UInt32Const : PatLeaf<(imm), [{
887- const APInt &v = N->getAPIntValue();
888- return v.isIntN(32);
889- }]>;
890-
891- def SInt16Const : PatLeaf<(imm), [{
892- const APInt &v = N->getAPIntValue();
893- return v.isSignedIntN(16);
894- }]>;
895-
896- def UInt16Const : PatLeaf<(imm), [{
897- const APInt &v = N->getAPIntValue();
898- return v.isIntN(16);
899- }]>;
900-
901- def IntConst_0_30 : PatLeaf<(imm), [{
902- // Check if 0 <= v < 31; only then will the result of (x << v) be an int32.
903- const APInt &v = N->getAPIntValue();
904- return v.sge(0) && v.slt(31);
905- }]>;
906-
907- def IntConst_0_14 : PatLeaf<(imm), [{
908- // Check if 0 <= v < 15; only then will the result of (x << v) be an int16.
909- const APInt &v = N->getAPIntValue();
910- return v.sge(0) && v.slt(15);
911- }]>;
912-
913- def SHL2MUL32 : SDNodeXForm<imm, [{
914- const APInt &v = N->getAPIntValue();
915- APInt temp(32, 1);
916- return CurDAG->getTargetConstant(temp.shl(v), SDLoc(N), MVT::i32);
917- }]>;
918-
919- def SHL2MUL16 : SDNodeXForm<imm, [{
920- const APInt &v = N->getAPIntValue();
921- APInt temp(16, 1);
922- return CurDAG->getTargetConstant(temp.shl(v), SDLoc(N), MVT::i16);
923- }]>;
924-
925- // Convert "sign/zero-extend, then shift left by an immediate" to mul.wide.
926- let Predicates = [doMulWide] in {
927- def : Pat<(shl (sext i32:$a), (i32 IntConst_0_30:$b)),
928- (MULWIDES64Imm $a, (SHL2MUL32 $b))>;
929- def : Pat<(shl (zext i32:$a), (i32 IntConst_0_30:$b)),
930- (MULWIDEU64Imm $a, (SHL2MUL32 $b))>;
931-
932- def : Pat<(shl (sext i16:$a), (i16 IntConst_0_14:$b)),
933- (MULWIDES32Imm $a, (SHL2MUL16 $b))>;
934- def : Pat<(shl (zext i16:$a), (i16 IntConst_0_14:$b)),
935- (MULWIDEU32Imm $a, (SHL2MUL16 $b))>;
936-
937- // Convert "sign/zero-extend then multiply" to mul.wide.
938- def : Pat<(mul (sext i32:$a), (sext i32:$b)),
939- (MULWIDES64 $a, $b)>;
940- def : Pat<(mul (sext i32:$a), (i64 SInt32Const:$b)),
941- (MULWIDES64Imm64 $a, (i64 SInt32Const:$b))>;
942-
943- def : Pat<(mul (zext i32:$a), (zext i32:$b)),
944- (MULWIDEU64 $a, $b)>;
945- def : Pat<(mul (zext i32:$a), (i64 UInt32Const:$b)),
946- (MULWIDEU64Imm64 $a, (i64 UInt32Const:$b))>;
947-
948- def : Pat<(mul (sext i16:$a), (sext i16:$b)),
949- (MULWIDES32 $a, $b)>;
950- def : Pat<(mul (sext i16:$a), (i32 SInt16Const:$b)),
951- (MULWIDES32Imm32 $a, (i32 SInt16Const:$b))>;
952-
953- def : Pat<(mul (zext i16:$a), (zext i16:$b)),
954- (MULWIDEU32 $a, $b)>;
955- def : Pat<(mul (zext i16:$a), (i32 UInt16Const:$b)),
956- (MULWIDEU32Imm32 $a, (i32 UInt16Const:$b))>;
957- }
958-
959870//
960871// Integer multiply-add
961872//
@@ -991,6 +902,39 @@ defm MAD32 : MAD<"mad.lo.s32", i32, B32, i32imm>;
991902defm MAD64 : MAD<"mad.lo.s64", i64, B64, i64imm>;
992903}
993904
// MAD_WIDE: defines four "mad.wide.<PtxSuffix>" instructions that all select
// the pattern (add (Op $a, $b), $c).
//   PtxSuffix - text appended to "mad.wide." to form the mnemonic.
//   Op        - the multiply fragment matched inside the add; a OneUse2
//               (defined outside this view; presumably restricts the multiply
//               to a single use - TODO confirm).
//   BigT      - register type info for $dst and the addend $c.
//   SmallT    - register type info for the multiply operands $a and $b.
// The variant suffix names the operand kinds of ($a, $b, $c): r = register,
// i = immediate. $a is always a register.
905+ multiclass MAD_WIDE<string PtxSuffix, OneUse2 Op, RegTyInfo BigT, RegTyInfo SmallT> {
// rrr: $b and $c are both registers.
906+ def rrr:
907+ BasicNVPTXInst<(outs BigT.RC:$dst),
908+ (ins SmallT.RC:$a, SmallT.RC:$b, BigT.RC:$c),
909+ "mad.wide." # PtxSuffix,
910+ [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, SmallT.Ty:$b), BigT.Ty:$c))]>;
// rri: register $b, immediate addend $c (BigT.Imm).
911+ def rri:
912+ BasicNVPTXInst<(outs BigT.RC:$dst),
913+ (ins SmallT.RC:$a, SmallT.RC:$b, BigT.Imm:$c),
914+ "mad.wide." # PtxSuffix,
915+ [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, SmallT.Ty:$b), imm:$c))]>;
// rir: immediate multiplier $b (SmallT.Imm), register addend $c.
916+ def rir:
917+ BasicNVPTXInst<(outs BigT.RC:$dst),
918+ (ins SmallT.RC:$a, SmallT.Imm:$b, BigT.RC:$c),
919+ "mad.wide." # PtxSuffix,
920+ [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, imm:$b), BigT.Ty:$c))]>;
// rii: both $b and $c are immediates.
921+ def rii:
922+ BasicNVPTXInst<(outs BigT.RC:$dst),
923+ (ins SmallT.RC:$a, SmallT.Imm:$b, BigT.Imm:$c),
924+ "mad.wide." # PtxSuffix,
925+ [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, imm:$b), imm:$c))]>;
926+ }
927+
// OneUse2 wrappers around the unsigned and signed mul.wide nodes. These are
// the fragments passed as the Op argument of the MAD_WIDE instantiations.
928+ def mul_wide_unsigned_oneuse : OneUse2<mul_wide_unsigned>;
929+ def mul_wide_signed_oneuse : OneUse2<mul_wide_signed>;
930+
// Instantiate MAD_WIDE for each signedness and width, all gated on the
// hasOptEnabled predicate. The 16-bit forms take I16RT multiply operands with
// an I32RT addend/result; the 32-bit forms take I32RT multiply operands with
// an I64RT addend/result. Unsigned variants use mul_wide_unsigned_oneuse and
// signed variants use mul_wide_signed_oneuse.
931+ let Predicates = [hasOptEnabled] in {
932+ defm MAD_WIDE_U16 : MAD_WIDE<"u16", mul_wide_unsigned_oneuse, I32RT, I16RT>;
933+ defm MAD_WIDE_S16 : MAD_WIDE<"s16", mul_wide_signed_oneuse, I32RT, I16RT>;
934+ defm MAD_WIDE_U32 : MAD_WIDE<"u32", mul_wide_unsigned_oneuse, I64RT, I32RT>;
935+ defm MAD_WIDE_S32 : MAD_WIDE<"s32", mul_wide_signed_oneuse, I64RT, I32RT>;
936+ }
937+
994938foreach t = [I16RT, I32RT, I64RT] in {
995939 def NEG_S # t.Size :
996940 BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src),
0 commit comments