Skip to content

Commit 235b4e7

Browse files
rscgopherbot
authored andcommitted
cmd/compile/internal/ssa: model right shift more precisely
Prove currently checks for 0 sign bit extraction (x>>63) at the end of the pass, but it is more general and more useful (and not really more work) to model right shift during value range tracking. This handles sign bit extraction (both 0 and -1) but also makes the value ranges available for proving bounds checks. 'go build -a -gcflags=-d=ssa/prove/debug=1 std' finds 105 new things to prove. https://gist.github.com/rsc/8ac41176e53ed9c2f1a664fc668e8336 For example, the compiler now recognizes that this code in strconv does not need to check the second shift for being ≥ 64. msb := xHi >> 63 retMantissa := xHi >> (msb + 38) nor does this code in regexp: return b < utf8.RuneSelf && specialBytes[b%16]&(1<<(b/16)) != 0 This code in math no longer has a bounds check on the first index: if 0 <= n && n <= 308 { return pow10postab32[uint(n)/32] * pow10tab[uint(n)%32] } The diff shows one "lost" proof in ycbcr.go but it's not really lost: the expression was folded to a constant instead, and that only shows up with debug=2. A diff of that output is at https://gist.github.com/rsc/9139ed46c6019ae007f5a1ba4bb3250f Change-Id: I84087311e0a303f00e2820d957a6f8b29ee22519 Reviewed-on: https://go-review.googlesource.com/c/go/+/716140 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Auto-Submit: Russ Cox <rsc@golang.org> Reviewed-by: David Chase <drchase@google.com>
1 parent d44db29 commit 235b4e7

File tree

3 files changed

+188
-54
lines changed

3 files changed

+188
-54
lines changed

src/cmd/compile/internal/ssa/prove.go

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"math"
1313
"math/bits"
1414
"slices"
15+
"strings"
1516
)
1617

1718
type branch int
@@ -132,7 +133,7 @@ type limit struct {
132133
}
133134

134135
func (l limit) String() string {
135-
return fmt.Sprintf("sm,SM,um,UM=%d,%d,%d,%d", l.min, l.max, l.umin, l.umax)
136+
return fmt.Sprintf("sm,SM=%d,%d um,UM=%d,%d", l.min, l.max, l.umin, l.umax)
136137
}
137138

138139
func (l limit) intersect(l2 limit) limit {
@@ -1965,6 +1966,30 @@ func (ft *factsTable) flowLimit(v *Value) bool {
19651966
b := ft.limits[v.Args[1].ID]
19661967
bitsize := uint(v.Type.Size()) * 8
19671968
return ft.newLimit(v, a.mul(b.exp2(bitsize), bitsize))
1969+
case OpRsh64x64, OpRsh64x32, OpRsh64x16, OpRsh64x8,
1970+
OpRsh32x64, OpRsh32x32, OpRsh32x16, OpRsh32x8,
1971+
OpRsh16x64, OpRsh16x32, OpRsh16x16, OpRsh16x8,
1972+
OpRsh8x64, OpRsh8x32, OpRsh8x16, OpRsh8x8:
1973+
a := ft.limits[v.Args[0].ID]
1974+
b := ft.limits[v.Args[1].ID]
1975+
if b.min >= 0 {
1976+
// Shift of negative makes a value closer to 0 (greater),
1977+
// so if a.min is negative, v.min is a.min>>b.min instead of a.min>>b.max,
1978+
// and similarly if a.max is negative, v.max is a.max>>b.max.
1979+
// Easier to compute min and max of both than to write sign logic.
1980+
vmin := min(a.min>>b.min, a.min>>b.max)
1981+
vmax := max(a.max>>b.min, a.max>>b.max)
1982+
return ft.signedMinMax(v, vmin, vmax)
1983+
}
1984+
case OpRsh64Ux64, OpRsh64Ux32, OpRsh64Ux16, OpRsh64Ux8,
1985+
OpRsh32Ux64, OpRsh32Ux32, OpRsh32Ux16, OpRsh32Ux8,
1986+
OpRsh16Ux64, OpRsh16Ux32, OpRsh16Ux16, OpRsh16Ux8,
1987+
OpRsh8Ux64, OpRsh8Ux32, OpRsh8Ux16, OpRsh8Ux8:
1988+
a := ft.limits[v.Args[0].ID]
1989+
b := ft.limits[v.Args[1].ID]
1990+
if b.min >= 0 {
1991+
return ft.unsignedMinMax(v, a.umin>>b.max, a.umax>>b.min)
1992+
}
19681993
case OpDiv64, OpDiv32, OpDiv16, OpDiv8:
19691994
a := ft.limits[v.Args[0].ID]
19701995
b := ft.limits[v.Args[1].ID]
@@ -2621,6 +2646,17 @@ var bytesizeToAnd = [...]Op{
26212646
func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) {
26222647
for iv, v := range b.Values {
26232648
switch v.Op {
2649+
case OpStaticLECall:
2650+
if b.Func.pass.debug > 0 && len(v.Args) == 2 {
2651+
fn := auxToCall(v.Aux).Fn
2652+
if fn != nil && strings.Contains(fn.String(), "prove") {
2653+
// Print bounds of any argument to single-arg function with "prove" in name,
2654+
// for debugging and especially for test/prove.go.
2655+
// (v.Args[1] is mem).
2656+
x := v.Args[0]
2657+
b.Func.Warnl(v.Pos, "Proved %v (%v)", ft.limits[x.ID], x)
2658+
}
2659+
}
26242660
case OpSlicemask:
26252661
// Replace OpSlicemask operations in b with constants where possible.
26262662
cap := v.Args[0]
@@ -2670,21 +2706,8 @@ func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) {
26702706
case OpRsh8x8, OpRsh8x16, OpRsh8x32, OpRsh8x64,
26712707
OpRsh16x8, OpRsh16x16, OpRsh16x32, OpRsh16x64,
26722708
OpRsh32x8, OpRsh32x16, OpRsh32x32, OpRsh32x64,
2673-
OpRsh64x8, OpRsh64x16, OpRsh64x32, OpRsh64x64:
2674-
// Check whether, for a >> b, we know that a is non-negative
2675-
// and b is all of a's bits except the MSB. If so, a is shifted to zero.
2676-
bits := 8 * v.Args[0].Type.Size()
2677-
if v.Args[1].isGenericIntConst() && v.Args[1].AuxInt >= bits-1 && ft.isNonNegative(v.Args[0]) {
2678-
if b.Func.pass.debug > 0 {
2679-
b.Func.Warnl(v.Pos, "Proved %v shifts to zero", v.Op)
2680-
}
2681-
v.reset(bytesizeToConst[bits/8])
2682-
v.AuxInt = 0
2683-
break // Be sure not to fallthrough - this is no longer OpRsh.
2684-
}
2685-
// If the Rsh hasn't been replaced with 0, still check if it is bounded.
2686-
fallthrough
2687-
case OpLsh8x8, OpLsh8x16, OpLsh8x32, OpLsh8x64,
2709+
OpRsh64x8, OpRsh64x16, OpRsh64x32, OpRsh64x64,
2710+
OpLsh8x8, OpLsh8x16, OpLsh8x32, OpLsh8x64,
26882711
OpLsh16x8, OpLsh16x16, OpLsh16x32, OpLsh16x64,
26892712
OpLsh32x8, OpLsh32x16, OpLsh32x32, OpLsh32x64,
26902713
OpLsh64x8, OpLsh64x16, OpLsh64x32, OpLsh64x64,

test/prove.go

Lines changed: 97 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -971,40 +971,6 @@ func negIndex2(n int) {
971971
useSlice(c)
972972
}
973973

974-
// Check that prove is zeroing these right shifts of positive ints by bit-width - 1.
975-
// e.g (Rsh64x64 <t> n (Const64 <typ.UInt64> [63])) && ft.isNonNegative(n) -> 0
976-
func sh64(n int64) int64 {
977-
if n < 0 {
978-
return n
979-
}
980-
return n >> 63 // ERROR "Proved Rsh64x64 shifts to zero"
981-
}
982-
983-
func sh32(n int32) int32 {
984-
if n < 0 {
985-
return n
986-
}
987-
return n >> 31 // ERROR "Proved Rsh32x64 shifts to zero"
988-
}
989-
990-
func sh32x64(n int32) int32 {
991-
if n < 0 {
992-
return n
993-
}
994-
return n >> uint64(31) // ERROR "Proved Rsh32x64 shifts to zero"
995-
}
996-
997-
func sh16(n int16) int16 {
998-
if n < 0 {
999-
return n
1000-
}
1001-
return n >> 15 // ERROR "Proved Rsh16x64 shifts to zero"
1002-
}
1003-
1004-
func sh64noopt(n int64) int64 {
1005-
return n >> 63 // not optimized; n could be negative
1006-
}
1007-
1008974
// These cases are division of a positive signed integer by a power of 2.
1009975
// The opt pass doesnt have sufficient information to see that n is positive.
1010976
// So, instead, opt rewrites the division with a less-than-optimal replacement.
@@ -2584,6 +2550,103 @@ func swapbound(v []int) {
25842550
}
25852551
}
25862552

2553+
func rightshift(v *[256]int) int {
2554+
for i := range 1024 { // ERROR "Induction"
2555+
if v[i/32] == 0 { // ERROR "Proved Div64 is unsigned" "Proved IsInBounds"
2556+
return i
2557+
}
2558+
}
2559+
for i := range 1024 { // ERROR "Induction"
2560+
if v[i>>2] == 0 { // ERROR "Proved IsInBounds"
2561+
return i
2562+
}
2563+
}
2564+
return -1
2565+
}
2566+
2567+
func rightShiftBounds(v, s int) {
2568+
// The ignored "Proved" messages on the shift itself are about whether s >= 0 or s < 32 or 64.
2569+
// We care about the bounds for x printed on the prove(x) lines.
2570+
2571+
if -8 <= v && v <= -2 && 1 <= s && s <= 3 {
2572+
x := v>>s // ERROR "Proved"
2573+
prove(x) // ERROR "Proved sm,SM=-4,-1 "
2574+
}
2575+
if -80 <= v && v <= -20 && 1 <= s && s <= 3 {
2576+
x := v>>s // ERROR "Proved"
2577+
prove(x) // ERROR "Proved sm,SM=-40,-3 "
2578+
}
2579+
if -8 <= v && v <= 10 && 1 <= s && s <= 3 {
2580+
x := v>>s // ERROR "Proved"
2581+
prove(x) // ERROR "Proved sm,SM=-4,5 "
2582+
}
2583+
if 2 <= v && v <= 10 && 1 <= s && s <= 3 {
2584+
x := v>>s // ERROR "Proved"
2585+
prove(x) // ERROR "Proved sm,SM=0,5 "
2586+
}
2587+
2588+
if -8 <= v && v <= -2 && 0 <= s && s <= 3 {
2589+
x := v>>s // ERROR "Proved"
2590+
prove(x) // ERROR "Proved sm,SM=-8,-1 "
2591+
}
2592+
if -80 <= v && v <= -20 && 0 <= s && s <= 3 {
2593+
x := v>>s // ERROR "Proved"
2594+
prove(x) // ERROR "Proved sm,SM=-80,-3 "
2595+
}
2596+
if -8 <= v && v <= 10 && 0 <= s && s <= 3 {
2597+
x := v>>s // ERROR "Proved"
2598+
prove(x) // ERROR "Proved sm,SM=-8,10 "
2599+
}
2600+
if 2 <= v && v <= 10 && 0 <= s && s <= 3 {
2601+
x := v>>s // ERROR "Proved"
2602+
prove(x) // ERROR "Proved sm,SM=0,10 "
2603+
}
2604+
2605+
if -8 <= v && v <= -2 && -1 <= s && s <= 3 {
2606+
x := v>>s // ERROR "Proved"
2607+
prove(x) // ERROR "Proved sm,SM=-8,-1 "
2608+
}
2609+
if -80 <= v && v <= -20 && -1 <= s && s <= 3 {
2610+
x := v>>s // ERROR "Proved"
2611+
prove(x) // ERROR "Proved sm,SM=-80,-3 "
2612+
}
2613+
if -8 <= v && v <= 10 && -1 <= s && s <= 3 {
2614+
x := v>>s // ERROR "Proved"
2615+
prove(x) // ERROR "Proved sm,SM=-8,10 "
2616+
}
2617+
if 2 <= v && v <= 10 && -1 <= s && s <= 3 {
2618+
x := v>>s // ERROR "Proved"
2619+
prove(x) // ERROR "Proved sm,SM=0,10 "
2620+
}
2621+
}
2622+
2623+
func unsignedRightShiftBounds(v uint, s int) {
2624+
if 2 <= v && v <= 10 && -1 <= s && s <= 3 {
2625+
x := v>>s // ERROR "Proved"
2626+
proveu(x) // ERROR "Proved sm,SM=0,10 "
2627+
}
2628+
if 2 <= v && v <= 10 && 0 <= s && s <= 3 {
2629+
x := v>>s // ERROR "Proved"
2630+
proveu(x) // ERROR "Proved sm,SM=0,10 "
2631+
}
2632+
if 2 <= v && v <= 10 && 1 <= s && s <= 3 {
2633+
x := v>>s // ERROR "Proved"
2634+
proveu(x) // ERROR "Proved sm,SM=0,5 "
2635+
}
2636+
if 20 <= v && v <= 100 && 1 <= s && s <= 3 {
2637+
x := v>>s // ERROR "Proved"
2638+
proveu(x) // ERROR "Proved sm,SM=2,50 "
2639+
}
2640+
}
2641+
2642+
//go:noinline
2643+
func prove(x int) {
2644+
}
2645+
2646+
//go:noinline
2647+
func proveu(x uint) {
2648+
}
2649+
25872650
//go:noinline
25882651
func useInt(a int) {
25892652
}

test/prove_constant_folding.go

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,62 @@ func f0i(x int) int {
2020
return x + 1
2121
}
2222

23-
func f0u(x uint) uint {
23+
func f0u(x uint) int {
2424
if x == 20 {
25-
return x // ERROR "Proved.+is constant 20$"
25+
return int(x) // ERROR "Proved.+is constant 20$"
2626
}
2727

2828
if (x + 20) == 20 {
29-
return x + 5 // ERROR "Proved.+is constant 0$" "Proved.+is constant 5$" "x\+d >=? w"
29+
return int(x + 5) // ERROR "Proved.+is constant 0$" "Proved.+is constant 5$" "x\+d >=? w"
3030
}
3131

32-
return x + 1
32+
if x < 1000 {
33+
return int(x)>>31 // ERROR "Proved.+is constant 0$"
34+
}
35+
if x := int32(x); x < -1000 {
36+
return int(x>>31) // ERROR "Proved.+is constant -1$"
37+
}
38+
39+
return int(x) + 1
40+
}
41+
42+
// Check that prove is zeroing these right shifts of positive ints by bit-width - 1.
43+
// e.g (Rsh64x64 <t> n (Const64 <typ.UInt64> [63])) && ft.isNonNegative(n) -> 0
44+
func sh64(n int64) int64 {
45+
if n < 0 {
46+
return n
47+
}
48+
return n >> 63 // ERROR "Proved .+ is constant 0$"
49+
}
50+
51+
func sh32(n int32) int32 {
52+
if n < 0 {
53+
return n
54+
}
55+
return n >> 31 // ERROR "Proved .+ is constant 0$"
56+
}
57+
58+
func sh32x64(n int32) int32 {
59+
if n < 0 {
60+
return n
61+
}
62+
return n >> uint64(31) // ERROR "Proved .+ is constant 0$"
63+
}
64+
65+
func sh32x64n(n int32) int32 {
66+
if n >= 0 {
67+
return 0
68+
}
69+
return n >> 31// ERROR "Proved .+ is constant -1$"
70+
}
71+
72+
func sh16(n int16) int16 {
73+
if n < 0 {
74+
return n
75+
}
76+
return n >> 15 // ERROR "Proved .+ is constant 0$"
77+
}
78+
79+
func sh64noopt(n int64) int64 {
80+
return n >> 63 // not optimized; n could be negative
3381
}

0 commit comments

Comments
 (0)