diff --git a/src/cmd/compile/internal/riscv64/ssa.go b/src/cmd/compile/internal/riscv64/ssa.go
index 9aa77c3d02bd91..81a71cf72b0ec5 100644
--- a/src/cmd/compile/internal/riscv64/ssa.go
+++ b/src/cmd/compile/internal/riscv64/ssa.go
@@ -294,7 +294,8 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
 		ssa.OpRISCV64FADDD, ssa.OpRISCV64FSUBD, ssa.OpRISCV64FMULD, ssa.OpRISCV64FDIVD,
 		ssa.OpRISCV64FEQD, ssa.OpRISCV64FNED, ssa.OpRISCV64FLTD, ssa.OpRISCV64FLED, ssa.OpRISCV64FSGNJD,
 		ssa.OpRISCV64MIN, ssa.OpRISCV64MAX, ssa.OpRISCV64MINU, ssa.OpRISCV64MAXU,
-		ssa.OpRISCV64SH1ADD, ssa.OpRISCV64SH2ADD, ssa.OpRISCV64SH3ADD:
+		ssa.OpRISCV64SH1ADD, ssa.OpRISCV64SH2ADD, ssa.OpRISCV64SH3ADD,
+		ssa.OpRISCV64ADDUW, ssa.OpRISCV64SH1ADDUW, ssa.OpRISCV64SH2ADDUW, ssa.OpRISCV64SH3ADDUW:
 		r := v.Reg()
 		r1 := v.Args[0].Reg()
 		r2 := v.Args[1].Reg()
@@ -433,7 +434,7 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
 	case ssa.OpRISCV64ADDI, ssa.OpRISCV64ADDIW, ssa.OpRISCV64XORI, ssa.OpRISCV64ORI, ssa.OpRISCV64ANDI,
 		ssa.OpRISCV64SLLI, ssa.OpRISCV64SLLIW, ssa.OpRISCV64SRAI, ssa.OpRISCV64SRAIW,
 		ssa.OpRISCV64SRLI, ssa.OpRISCV64SRLIW, ssa.OpRISCV64SLTI, ssa.OpRISCV64SLTIU,
-		ssa.OpRISCV64RORI, ssa.OpRISCV64RORIW:
+		ssa.OpRISCV64RORI, ssa.OpRISCV64RORIW, ssa.OpRISCV64SLLIUW:
 		p := s.Prog(v.Op.Asm())
 		p.From.Type = obj.TYPE_CONST
 		p.From.Offset = v.AuxInt
diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
index 31829a5eed7d0f..0a9a21c7eb95bf 100644
--- a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
+++ b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
@@ -838,10 +838,14 @@
 // Optimisations for rva22u64 and above.
 //
 
+// Combine truncate and logical left shift.
+(SLLI [i] (MOVWUreg x)) && i < 64 && buildcfg.GORISCV64 >= 22 => (SLLIUW [i] x)
+
 // Combine left shift and addition.
-(ADD (SLLI [1] x) y) && buildcfg.GORISCV64 >= 22 => (SH1ADD x y)
-(ADD (SLLI [2] x) y) && buildcfg.GORISCV64 >= 22 => (SH2ADD x y)
-(ADD (SLLI [3] x) y) && buildcfg.GORISCV64 >= 22 => (SH3ADD x y)
+(ADD (MOVWUreg x) y) && buildcfg.GORISCV64 >= 22 => (ADDUW x y)
+(ADD (SLLIUW [1] x) y) && buildcfg.GORISCV64 >= 22 => (SH1ADDUW x y)
+(ADD (SLLIUW [2] x) y) && buildcfg.GORISCV64 >= 22 => (SH2ADDUW x y)
+(ADD (SLLIUW [3] x) y) && buildcfg.GORISCV64 >= 22 => (SH3ADDUW x y)
 
 // Integer minimum and maximum.
 (Min64 x y) && buildcfg.GORISCV64 >= 22 => (MIN x y)
diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go b/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
index a0e1ab9754d349..edaf7ad93fb07f 100644
--- a/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
+++ b/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
@@ -151,6 +151,7 @@ func init() {
 		{name: "ADD", argLength: 2, reg: gp21, asm: "ADD", commutative: true}, // arg0 + arg1
 		{name: "ADDI", argLength: 1, reg: gp11sb, asm: "ADDI", aux: "Int64"},  // arg0 + auxint
 		{name: "ADDIW", argLength: 1, reg: gp11, asm: "ADDIW", aux: "Int64"},  // 32 low bits of arg0 + auxint, sign extended to 64 bits
+		{name: "ADDUW", argLength: 2, reg: gp21, asm: "ADDUW"},                // add least significant word of arg0 to arg1
 		{name: "NEG", argLength: 1, reg: gp11, asm: "NEG"},                    // -arg0
 		{name: "NEGW", argLength: 1, reg: gp11, asm: "NEGW"},                  // -arg0 of 32 bits, sign extended to 64 bits
 		{name: "SUB", argLength: 2, reg: gp21, asm: "SUB"},                    // arg0 - arg1
@@ -222,6 +223,7 @@
 		{name: "SRLW", argLength: 2, reg: gp21, asm: "SRLW"},                   // arg0 >> (aux1 & 31), logical right shift of 32 bit value, sign extended to 64 bits
 		{name: "SLLI", argLength: 1, reg: gp11, asm: "SLLI", aux: "Int64"},     // arg0 << auxint, shift amount 0-63, logical left shift
 		{name: "SLLIW", argLength: 1, reg: gp11, asm: "SLLIW", aux: "Int64"},   // arg0 << auxint, shift amount 0-31, logical left shift of 32 bit value, sign extended to 64 bits
+		{name: "SLLIUW", argLength: 1, reg: gp11, asm: "SLLIUW", aux: "Int64"}, // arg0 << auxint, shift amount 0-63, logical left shift of 32 bit value, zero extended to 64 bits
 		{name: "SRAI", argLength: 1, reg: gp11, asm: "SRAI", aux: "Int64"},     // arg0 >> auxint, shift amount 0-63, arithmetic right shift
 		{name: "SRAIW", argLength: 1, reg: gp11, asm: "SRAIW", aux: "Int64"},   // arg0 >> auxint, shift amount 0-31, arithmetic right shift of 32 bit value, sign extended to 64 bits
 		{name: "SRLI", argLength: 1, reg: gp11, asm: "SRLI", aux: "Int64"},     // arg0 >> auxint, shift amount 0-63, logical right shift
@@ -231,6 +233,9 @@
 		{name: "SH1ADD", argLength: 2, reg: gp21, asm: "SH1ADD"}, // arg0 << 1 + arg1
 		{name: "SH2ADD", argLength: 2, reg: gp21, asm: "SH2ADD"}, // arg0 << 2 + arg1
 		{name: "SH3ADD", argLength: 2, reg: gp21, asm: "SH3ADD"}, // arg0 << 3 + arg1
+		{name: "SH1ADDUW", argLength: 2, reg: gp21, asm: "SH1ADDUW"}, // shift the least significant word of arg0 left by 1 and add it to arg1
+		{name: "SH2ADDUW", argLength: 2, reg: gp21, asm: "SH2ADDUW"}, // shift the least significant word of arg0 left by 2 and add it to arg1
+		{name: "SH3ADDUW", argLength: 2, reg: gp21, asm: "SH3ADDUW"}, // shift the least significant word of arg0 left by 3 and add it to arg1
 
 		// Bitwise ops
 		{name: "AND", argLength: 2, reg: gp21, asm: "AND", commutative: true}, // arg0 & arg1
diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64latelower.rules b/src/cmd/compile/internal/ssa/_gen/RISCV64latelower.rules
index 7acaa2f3fec546..55b69fa1807e9f 100644
--- a/src/cmd/compile/internal/ssa/_gen/RISCV64latelower.rules
+++ b/src/cmd/compile/internal/ssa/_gen/RISCV64latelower.rules
@@ -23,3 +23,8 @@
 (SRAI [0] x) => x
 (SRLI [0] x) => x
 (SLLI [0] x) => x
+
+// Combine left shift and addition.
+(ADD (SLLI [1] x) y) && buildcfg.GORISCV64 >= 22 => (SH1ADD x y)
+(ADD (SLLI [2] x) y) && buildcfg.GORISCV64 >= 22 => (SH2ADD x y)
+(ADD (SLLI [3] x) y) && buildcfg.GORISCV64 >= 22 => (SH3ADD x y)
diff --git a/src/cmd/compile/internal/ssa/opGen.go b/src/cmd/compile/internal/ssa/opGen.go
index 264f4b3bf378f1..7af1f623b16af9 100644
--- a/src/cmd/compile/internal/ssa/opGen.go
+++ b/src/cmd/compile/internal/ssa/opGen.go
@@ -2473,6 +2473,7 @@ const (
 	OpRISCV64ADD
 	OpRISCV64ADDI
 	OpRISCV64ADDIW
+	OpRISCV64ADDUW
 	OpRISCV64NEG
 	OpRISCV64NEGW
 	OpRISCV64SUB
@@ -2526,6 +2527,7 @@
 	OpRISCV64SRLW
 	OpRISCV64SLLI
 	OpRISCV64SLLIW
+	OpRISCV64SLLIUW
 	OpRISCV64SRAI
 	OpRISCV64SRAIW
 	OpRISCV64SRLI
@@ -2533,6 +2535,9 @@
 	OpRISCV64SH1ADD
 	OpRISCV64SH2ADD
 	OpRISCV64SH3ADD
+	OpRISCV64SH1ADDUW
+	OpRISCV64SH2ADDUW
+	OpRISCV64SH3ADDUW
 	OpRISCV64AND
 	OpRISCV64ANDN
 	OpRISCV64ANDI
@@ -33218,6 +33223,20 @@
 			},
 		},
 	},
+	{
+		name:   "ADDUW",
+		argLen: 2,
+		asm:    riscv.AADDUW,
+		reg: regInfo{
+			inputs: []inputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+				{1, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+			outputs: []outputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+		},
+	},
 	{
 		name:   "NEG",
 		argLen: 1,
@@ -33962,6 +33981,20 @@
 			},
 		},
 	},
+	{
+		name:    "SLLIUW",
+		auxType: auxInt64,
+		argLen:  1,
+		asm:     riscv.ASLLIUW,
+		reg: regInfo{
+			inputs: []inputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+			outputs: []outputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+		},
+	},
 	{
 		name:    "SRAI",
 		auxType: auxInt64,
@@ -34060,6 +34093,48 @@
 			},
 		},
 	},
+	{
+		name:   "SH1ADDUW",
+		argLen: 2,
+		asm:    riscv.ASH1ADDUW,
+		reg: regInfo{
+			inputs: []inputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+				{1, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+			outputs: []outputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+		},
+	},
+	{
+		name:   "SH2ADDUW",
+		argLen: 2,
+		asm:    riscv.ASH2ADDUW,
+		reg: regInfo{
+			inputs: []inputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+				{1, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+			outputs: []outputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+		},
+	},
+	{
+		name:   "SH3ADDUW",
+		argLen: 2,
+		asm:    riscv.ASH3ADDUW,
+		reg: regInfo{
+			inputs: []inputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+				{1, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+			outputs: []outputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+		},
+	},
 	{
 		name:        "AND",
 		argLen:      2,
diff --git a/src/cmd/compile/internal/ssa/rewriteRISCV64.go b/src/cmd/compile/internal/ssa/rewriteRISCV64.go
index 52870fe19921ce..b9414d5227c545 100644
--- a/src/cmd/compile/internal/ssa/rewriteRISCV64.go
+++ b/src/cmd/compile/internal/ssa/rewriteRISCV64.go
@@ -3315,12 +3315,12 @@ func rewriteValueRISCV64_OpRISCV64ADD(v *Value) bool {
 		}
 		break
 	}
-	// match: (ADD (SLLI [1] x) y)
+	// match: (ADD (MOVWUreg x) y)
 	// cond: buildcfg.GORISCV64 >= 22
-	// result: (SH1ADD x y)
+	// result: (ADDUW x y)
 	for {
 		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
-			if v_0.Op != OpRISCV64SLLI || auxIntToInt64(v_0.AuxInt) != 1 {
+			if v_0.Op != OpRISCV64MOVWUreg {
 				continue
 			}
 			x := v_0.Args[0]
@@ -3328,18 +3328,18 @@ func rewriteValueRISCV64_OpRISCV64ADD(v *Value) bool {
 			if !(buildcfg.GORISCV64 >= 22) {
 				continue
 			}
-			v.reset(OpRISCV64SH1ADD)
+			v.reset(OpRISCV64ADDUW)
 			v.AddArg2(x, y)
 			return true
 		}
 		break
 	}
-	// match: (ADD (SLLI [2] x) y)
+	// match: (ADD (SLLIUW [1] x) y)
 	// cond: buildcfg.GORISCV64 >= 22
-	// result: (SH2ADD x y)
+	// result: (SH1ADDUW x y)
 	for {
 		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
-			if v_0.Op != OpRISCV64SLLI || auxIntToInt64(v_0.AuxInt) != 2 {
+			if v_0.Op != OpRISCV64SLLIUW || auxIntToInt64(v_0.AuxInt) != 1 {
 				continue
 			}
 			x := v_0.Args[0]
@@ -3347,18 +3347,18 @@ func rewriteValueRISCV64_OpRISCV64ADD(v *Value) bool {
 			if !(buildcfg.GORISCV64 >= 22) {
 				continue
 			}
-			v.reset(OpRISCV64SH2ADD)
+			v.reset(OpRISCV64SH1ADDUW)
 			v.AddArg2(x, y)
 			return true
 		}
 		break
 	}
-	// match: (ADD (SLLI [3] x) y)
+	// match: (ADD (SLLIUW [2] x) y)
 	// cond: buildcfg.GORISCV64 >= 22
-	// result: (SH3ADD x y)
+	// result: (SH2ADDUW x y)
 	for {
 		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
-			if v_0.Op != OpRISCV64SLLI || auxIntToInt64(v_0.AuxInt) != 3 {
+			if v_0.Op != OpRISCV64SLLIUW || auxIntToInt64(v_0.AuxInt) != 2 {
 				continue
 			}
 			x := v_0.Args[0]
@@ -3366,7 +3366,26 @@ func rewriteValueRISCV64_OpRISCV64ADD(v *Value) bool {
 			if !(buildcfg.GORISCV64 >= 22) {
 				continue
 			}
-			v.reset(OpRISCV64SH3ADD)
+			v.reset(OpRISCV64SH2ADDUW)
+			v.AddArg2(x, y)
+			return true
+		}
+		break
+	}
+	// match: (ADD (SLLIUW [3] x) y)
+	// cond: buildcfg.GORISCV64 >= 22
+	// result: (SH3ADDUW x y)
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			if v_0.Op != OpRISCV64SLLIUW || auxIntToInt64(v_0.AuxInt) != 3 {
+				continue
+			}
+			x := v_0.Args[0]
+			y := v_1
+			if !(buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SH3ADDUW)
 			v.AddArg2(x, y)
 			return true
 		}
@@ -7163,6 +7182,23 @@ func rewriteValueRISCV64_OpRISCV64SLLI(v *Value) bool {
 		v.AuxInt = int64ToAuxInt(0)
 		return true
 	}
+	// match: (SLLI [i] (MOVWUreg x))
+	// cond: i < 64 && buildcfg.GORISCV64 >= 22
+	// result: (SLLIUW [i] x)
+	for {
+		i := auxIntToInt64(v.AuxInt)
+		if v_0.Op != OpRISCV64MOVWUreg {
+			break
+		}
+		x := v_0.Args[0]
+		if !(i < 64 && buildcfg.GORISCV64 >= 22) {
+			break
+		}
+		v.reset(OpRISCV64SLLIUW)
+		v.AuxInt = int64ToAuxInt(i)
+		v.AddArg(x)
+		return true
+	}
 	return false
 }
 func rewriteValueRISCV64_OpRISCV64SLLW(v *Value) bool {
diff --git a/src/cmd/compile/internal/ssa/rewriteRISCV64latelower.go b/src/cmd/compile/internal/ssa/rewriteRISCV64latelower.go
index d2c3a8f73df2e9..aa45b4d6b838d8 100644
--- a/src/cmd/compile/internal/ssa/rewriteRISCV64latelower.go
+++ b/src/cmd/compile/internal/ssa/rewriteRISCV64latelower.go
@@ -2,8 +2,12 @@
 
 package ssa
 
+import "internal/buildcfg"
+
 func rewriteValueRISCV64latelower(v *Value) bool {
 	switch v.Op {
+	case OpRISCV64ADD:
+		return rewriteValueRISCV64latelower_OpRISCV64ADD(v)
 	case OpRISCV64AND:
 		return rewriteValueRISCV64latelower_OpRISCV64AND(v)
 	case OpRISCV64NOT:
@@ -21,6 +25,68 @@ func rewriteValueRISCV64latelower(v *Value) bool {
 	}
 	return false
 }
+func rewriteValueRISCV64latelower_OpRISCV64ADD(v *Value) bool {
+	v_1 := v.Args[1]
+	v_0 := v.Args[0]
+	// match: (ADD (SLLI [1] x) y)
+	// cond: buildcfg.GORISCV64 >= 22
+	// result: (SH1ADD x y)
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			if v_0.Op != OpRISCV64SLLI || auxIntToInt64(v_0.AuxInt) != 1 {
+				continue
+			}
+			x := v_0.Args[0]
+			y := v_1
+			if !(buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SH1ADD)
+			v.AddArg2(x, y)
+			return true
+		}
+		break
+	}
+	// match: (ADD (SLLI [2] x) y)
+	// cond: buildcfg.GORISCV64 >= 22
+	// result: (SH2ADD x y)
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			if v_0.Op != OpRISCV64SLLI || auxIntToInt64(v_0.AuxInt) != 2 {
+				continue
+			}
+			x := v_0.Args[0]
+			y := v_1
+			if !(buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SH2ADD)
+			v.AddArg2(x, y)
+			return true
+		}
+		break
+	}
+	// match: (ADD (SLLI [3] x) y)
+	// cond: buildcfg.GORISCV64 >= 22
+	// result: (SH3ADD x y)
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			if v_0.Op != OpRISCV64SLLI || auxIntToInt64(v_0.AuxInt) != 3 {
+				continue
+			}
+			x := v_0.Args[0]
+			y := v_1
+			if !(buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SH3ADD)
+			v.AddArg2(x, y)
+			return true
+		}
+		break
+	}
+	return false
+}
 func rewriteValueRISCV64latelower_OpRISCV64AND(v *Value) bool {
 	v_1 := v.Args[1]
 	v_0 := v.Args[0]
diff --git a/test/codegen/arithmetic.go b/test/codegen/arithmetic.go
index 42d5d2ef65848b..ee51e637919ce8 100644
--- a/test/codegen/arithmetic.go
+++ b/test/codegen/arithmetic.go
@@ -220,6 +220,12 @@ func NegToInt32(a int) int {
 	return r
 }
 
+func AddWithLeastSignificantWord(a uint64, b int64) uint64 {
+	// riscv64/rva20u64:"MOVWU","ADD"
+	// riscv64/rva22u64,riscv64/rva23u64:"ADDUW"
+	return a + uint64(uint32(b))
+}
+
 // -------------------- //
 //    Multiplication    //
 // -------------------- //
diff --git a/test/codegen/shift.go b/test/codegen/shift.go
index 1877247af4d820..db47e3f9740d83 100644
--- a/test/codegen/shift.go
+++ b/test/codegen/shift.go
@@ -623,6 +623,19 @@ func checkLeftShiftWithAddition(a int64, b int64) int64 {
 	return a
 }
 
+func checkLeftShiftLeastSignificantWordWithAddition(a uint64, b []int64) uint64 {
+	// riscv64/rva20u64: "SLLI","SRLI","ADD"
+	// riscv64/rva22u64,riscv64/rva23u64: "SH1ADDUW"
+	x := a + uint64(uint32(b[0]))<<1
+	// riscv64/rva20u64: "SLLI","SRLI","ADD"
+	// riscv64/rva22u64,riscv64/rva23u64: "SH2ADDUW"
+	y := a + uint64(uint32(b[1]))<<2
+	// riscv64/rva20u64: "SLLI","SRLI","ADD"
+	// riscv64/rva22u64,riscv64/rva23u64: "SH3ADDUW"
+	z := a + uint64(uint32(b[2]))<<3
+	return x + y + z
+}
+
 //
 // Convert and shift.
 //
@@ -687,6 +700,12 @@ func rsh64to8(v int64) int8 {
 	return x
 }
 
+func lsh32Uto64U(a int64) uint64 {
+	// riscv64/rva20u64:"SLLI","SRLI"
+	// riscv64/rva22u64,riscv64/rva23u64:"SLLIUW"
+	return uint64(uint32(a)) << 6
+}
+
 // We don't need to worry about shifting
 // more than the type size.
 // (There is still a negative shift test, but
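
For reference, the Zba *.uw instructions selected above all operate on the zero-extended least significant word of their first source register, and the compiler only generates them under buildcfg.GORISCV64 >= 22, i.e. when building with GORISCV64=rva22u64 or rva23u64, profiles in which Zba is mandatory. Their semantics can be modelled in plain Go as follows (a sketch; the function names are illustrative, not part of this change):

	// addUW models ADD.UW: rd = zext32(rs1) + rs2.
	func addUW(rs1, rs2 uint64) uint64 {
		return uint64(uint32(rs1)) + rs2
	}

	// slliUW models SLLI.UW: rd = zext32(rs1) << shamt, for shamt 0-63,
	// which is why the (SLLI [i] (MOVWUreg x)) rule accepts any i < 64.
	func slliUW(rs1 uint64, shamt uint) uint64 {
		return uint64(uint32(rs1)) << shamt
	}

	// shaddUW models SH1ADD.UW, SH2ADD.UW and SH3ADD.UW for n = 1, 2, 3:
	// rd = (zext32(rs1) << n) + rs2.
	func shaddUW(n uint, rs1, rs2 uint64) uint64 {
		return uint64(uint32(rs1))<<n + rs2
	}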
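Moving the plain SH1ADD/SH2ADD/SH3ADD fusion from RISCV64.rules into RISCV64latelower.rules is what makes the new patterns reachable: both rule sets match (ADD (SLLI ...) y), and if the plain fusion still ran in the main lowering pass it could consume a (SLLI [n] (MOVWUreg x)) before that shift was rewritten to SLLIUW, leaving a zero-extension that no later rule folds away. Deferring the generic fusion to late lower gives the zero-extension-aware rules first claim. A sketch of the intended rewrite chain for a + uint64(uint32(b))<<2 at GORISCV64 >= 22 (SSA shorthand, not verbatim compiler output):

	(ADD (SLLI [2] (MOVWUreg b)) a)
	=> (ADD (SLLIUW [2] b) a)    // main pass: SLLI of MOVWUreg becomes SLLIUW
	=> (SH2ADDUW b a)            // main pass: ADD of SLLIUW becomes SH2ADDUW

A shift-and-add with no truncation, for example a + b<<2 on int64 operands, still lowers to (SH2ADD b a); it now simply happens one pass later, via the rules added at the end of RISCV64latelower.rules.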