Skip to content

Commit 839d88e

Browse files
adonovangopherbot
authored andcommitted
go/analysis/passes/hostport: fix four bugs
All arise from assuming that just because the program is well typed, that the fmt.Sprintf call is valid, specifically: 1. don't assume sufficient call arguments; 2. don't assume portExpr.String() == Format(portExpr); 3. don't assume that if port is a constant, it's an integer. 4. don't transform a constant port kPort into a string "123" since it loses the connection to the symbolic constant. + tests of all four Change-Id: I7c6e448efe30065d1c82ed144382788e16c70acf Reviewed-on: https://go-review.googlesource.com/c/tools/+/702895 Reviewed-by: Robert Findley <rfindley@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Auto-Submit: Alan Donovan <adonovan@google.com>
1 parent ba63d13 commit 839d88e

File tree

3 files changed

+69
-16
lines changed

3 files changed

+69
-16
lines changed

go/analysis/passes/hostport/hostport.go

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ import (
1010
"fmt"
1111
"go/ast"
1212
"go/constant"
13+
"go/token"
1314
"go/types"
15+
"strconv"
1416

1517
"golang.org/x/tools/go/analysis"
1618
"golang.org/x/tools/go/analysis/passes/inspect"
@@ -57,13 +59,16 @@ func run(pass *analysis.Pass) (any, error) {
5759
}
5860

5961
// checkAddr reports a diagnostic (and returns true) if e
60-
// is a call of the form fmt.Sprintf("%d:%d", ...).
62+
// is a call of the form fmt.Sprintf("%s:%d", ...).
6163
// The diagnostic includes a fix.
6264
//
6365
// dialCall is non-nil if the Dial call is non-local
6466
// but within the same file.
6567
checkAddr := func(e ast.Expr, dialCall *ast.CallExpr) {
66-
if call, ok := e.(*ast.CallExpr); ok && typeutil.Callee(info, call) == fmtSprintf {
68+
if call, ok := e.(*ast.CallExpr); ok &&
69+
len(call.Args) == 3 &&
70+
typeutil.Callee(info, call) == fmtSprintf {
71+
6772
// Examine format string.
6873
formatArg := call.Args[0]
6974
if tv := info.Types[formatArg]; tv.Value != nil {
@@ -99,21 +104,41 @@ func run(pass *analysis.Pass) (any, error) {
99104

100105
// Turn numeric port into a string.
101106
if numericPort {
102-
// port => fmt.Sprintf("%d", port)
103-
// 123 => "123"
104107
port := call.Args[2]
105-
newPort := fmt.Sprintf(`fmt.Sprintf("%%d", %s)`, port)
106-
if port := info.Types[port].Value; port != nil {
107-
if i, ok := constant.Int64Val(port); ok {
108-
newPort = fmt.Sprintf(`"%d"`, i) // numeric constant
108+
109+
// Is port an integer literal?
110+
//
111+
// (Don't allow arbitrary constants k otherwise the
112+
// transformation k => fmt.Sprintf("%d", "123")
113+
// loses the symbolic connection to k.)
114+
var kPort int64 = -1
115+
if lit, ok := port.(*ast.BasicLit); ok && lit.Kind == token.INT {
116+
if v, err := strconv.ParseInt(lit.Value, 0, 64); err == nil {
117+
kPort = v
109118
}
110119
}
111-
112-
edits = append(edits, analysis.TextEdit{
113-
Pos: port.Pos(),
114-
End: port.End(),
115-
NewText: []byte(newPort),
116-
})
120+
if kPort >= 0 {
121+
// literal: 0x7B => "123"
122+
edits = append(edits, analysis.TextEdit{
123+
Pos: port.Pos(),
124+
End: port.End(),
125+
NewText: fmt.Appendf(nil, `"%d"`, kPort), // (decimal)
126+
})
127+
} else {
128+
// non-literal: port => fmt.Sprintf("%d", port)
129+
edits = append(edits, []analysis.TextEdit{
130+
{
131+
Pos: port.Pos(),
132+
End: port.Pos(),
133+
NewText: []byte(`fmt.Sprintf("%d", `),
134+
},
135+
{
136+
Pos: port.End(),
137+
End: port.End(),
138+
NewText: []byte(`)`),
139+
},
140+
}...)
141+
}
117142
}
118143

119144
// Refer to Dial call, if not adjacent.

go/analysis/passes/hostport/testdata/src/a/a.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ func direct(host string, port int, portStr string) {
1212
net.Dial("tcp", fmt.Sprintf("%s:%s", host, portStr)) // want `address format "%s:%s" does not work with IPv6`
1313
}
1414

15-
// port is a constant:
15+
// port is a literal:
1616
var addr4 = fmt.Sprintf("%s:%d", "localhost", 123) // want `address format "%s:%d" does not work with IPv6 \(passed to net.Dial at L39\)`
1717

1818
func indirect(host string, port int) {
@@ -38,3 +38,17 @@ func indirect(host string, port int) {
3838
// Dialer.Dial again, addr is declared at package level.
3939
dialer.Dial("tcp", addr4)
4040
}
41+
42+
// Regression tests for crashes in well-typed code that nonetheless mis-uses Sprintf:
43+
// too few arguments, or port is not an integer.
44+
var (
45+
_, _ = net.Dial("tcp", fmt.Sprintf("%s:%d"))
46+
_, _ = net.Dial("tcp", fmt.Sprintf("%s:%d", "host"))
47+
_, _ = net.Dial("tcp", fmt.Sprintf("%s:%d", "host", "port")) // want `address format "%s:%d" does not work with IPv6`
48+
)
49+
50+
func _() {
51+
// port is a non-constant literal
52+
const port = 0x7B
53+
_, _ = net.Dial("tcp", fmt.Sprintf("%s:%d", "localhost", port)) // want `address format "%s:%d" does not work with IPv6`
54+
}

go/analysis/passes/hostport/testdata/src/a/a.go.golden

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ func direct(host string, port int, portStr string) {
1212
net.Dial("tcp", net.JoinHostPort(host, portStr)) // want `address format "%s:%s" does not work with IPv6`
1313
}
1414

15-
// port is a constant:
15+
// port is a literal:
1616
var addr4 = net.JoinHostPort("localhost", "123") // want `address format "%s:%d" does not work with IPv6 \(passed to net.Dial at L39\)`
1717

1818
func indirect(host string, port int) {
@@ -38,3 +38,17 @@ func indirect(host string, port int) {
3838
// Dialer.Dial again, addr is declared at package level.
3939
dialer.Dial("tcp", addr4)
4040
}
41+
42+
// Regression tests for crashes in well-typed code that nonetheless mis-uses Sprintf:
43+
// too few arguments, or port is not an integer.
44+
var (
45+
_, _ = net.Dial("tcp", fmt.Sprintf("%s:%d"))
46+
_, _ = net.Dial("tcp", fmt.Sprintf("%s:%d", "host"))
47+
_, _ = net.Dial("tcp", net.JoinHostPort("host", fmt.Sprintf("%d", "port"))) // want `address format "%s:%d" does not work with IPv6`
48+
)
49+
50+
func _() {
51+
// port is a non-constant literal
52+
const port = 0x7B
53+
_, _ = net.Dial("tcp", net.JoinHostPort("localhost", fmt.Sprintf("%d", port))) // want `address format "%s:%d" does not work with IPv6`
54+
}

0 commit comments

Comments
 (0)