Skip to content

Commit 9447ff9

Browse files
cuishuanggopherbot
authored andcommitted
go/analysis/passes/modernize: directly remove user-defined min/max functions
If the parameters, return values, and logic of the user-defined min/max are identical to those of the standard library min/max, then remove the user-defined functions. Change-Id: I881e463489e963f4eb033188e77ee205675d0738 Reviewed-on: https://go-review.googlesource.com/c/tools/+/707915 Reviewed-by: Alan Donovan <adonovan@google.com> Auto-Submit: Alan Donovan <adonovan@google.com> Reviewed-by: Carlos Amedee <carlos@golang.org> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
1 parent 1605eae commit 9447ff9

File tree

10 files changed

+401
-3
lines changed

10 files changed

+401
-3
lines changed

go/analysis/passes/modernize/minmax.go

Lines changed: 166 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ import (
1717
"golang.org/x/tools/go/ast/inspector"
1818
"golang.org/x/tools/internal/analysisinternal"
1919
"golang.org/x/tools/internal/analysisinternal/generated"
20+
typeindexanalyzer "golang.org/x/tools/internal/analysisinternal/typeindex"
2021
"golang.org/x/tools/internal/typeparams"
22+
"golang.org/x/tools/internal/typesinternal/typeindex"
2123
)
2224

2325
var MinMaxAnalyzer = &analysis.Analyzer{
@@ -26,14 +28,16 @@ var MinMaxAnalyzer = &analysis.Analyzer{
2628
Requires: []*analysis.Analyzer{
2729
generated.Analyzer,
2830
inspect.Analyzer,
31+
typeindexanalyzer.Analyzer,
2932
},
3033
Run: minmax,
3134
URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/modernize#minmax",
3235
}
3336

34-
// The minmax pass replaces if/else statements with calls to min or max.
37+
// The minmax pass replaces if/else statements with calls to min or max,
38+
// and removes user-defined min/max functions that are equivalent to built-ins.
3539
//
36-
// Patterns:
40+
// If/else replacement patterns:
3741
//
3842
// 1. if a < b { x = a } else { x = b } => x = min(a, b)
3943
// 2. x = a; if a < b { x = b } => x = max(a, b)
@@ -42,13 +46,20 @@ var MinMaxAnalyzer = &analysis.Analyzer{
4246
// is not Nan. Since this is hard to prove, we reject floating-point
4347
// numbers.
4448
//
49+
// Function removal:
50+
// User-defined min/max functions are suggested for removal if they may
51+
// be safely replaced by their built-in namesake.
52+
//
4553
// Variants:
4654
// - all four ordered comparisons
4755
// - "x := a" or "x = a" or "var x = a" in pattern 2
4856
// - "x < b" or "a < b" in pattern 2
4957
func minmax(pass *analysis.Pass) (any, error) {
5058
skipGenerated(pass)
5159

60+
// Check for user-defined min/max functions that can be removed
61+
checkUserDefinedMinMax(pass)
62+
5263
// check is called for all statements of this form:
5364
// if a < b { lhs = rhs }
5465
check := func(file *ast.File, curIfStmt inspector.Cursor, compare *ast.BinaryExpr) {
@@ -275,6 +286,144 @@ func maybeNaN(t types.Type) bool {
275286
return false
276287
}
277288

289+
// checkUserDefinedMinMax looks for user-defined min/max functions that are
290+
// equivalent to the built-in functions and suggests removing them.
291+
func checkUserDefinedMinMax(pass *analysis.Pass) {
292+
index := pass.ResultOf[typeindexanalyzer.Analyzer].(*typeindex.Index)
293+
294+
// Look up min and max functions by name in package scope
295+
for _, funcName := range []string{"min", "max"} {
296+
if fn, ok := pass.Pkg.Scope().Lookup(funcName).(*types.Func); ok {
297+
// Use typeindex to get the FuncDecl directly
298+
if def, ok := index.Def(fn); ok {
299+
decl := def.Parent().Node().(*ast.FuncDecl)
300+
// Check if this function matches the built-in min/max signature and behavior
301+
if canUseBuiltinMinMax(fn, decl.Body) {
302+
// Expand to include leading doc comment
303+
pos := decl.Pos()
304+
if docs := docComment(decl); docs != nil {
305+
pos = docs.Pos()
306+
}
307+
308+
pass.Report(analysis.Diagnostic{
309+
Pos: decl.Pos(),
310+
End: decl.End(),
311+
Message: fmt.Sprintf("user-defined %s function is equivalent to built-in %s and can be removed", funcName, funcName),
312+
SuggestedFixes: []analysis.SuggestedFix{{
313+
Message: fmt.Sprintf("Remove user-defined %s function", funcName),
314+
TextEdits: []analysis.TextEdit{{
315+
Pos: pos,
316+
End: decl.End(),
317+
}},
318+
}},
319+
})
320+
}
321+
}
322+
}
323+
}
324+
}
325+
326+
// canUseBuiltinMinMax reports whether it is safe to replace a call
327+
// to this min or max function by its built-in namesake.
328+
func canUseBuiltinMinMax(fn *types.Func, body *ast.BlockStmt) bool {
329+
sig := fn.Type().(*types.Signature)
330+
331+
// Only consider the most common case: exactly 2 parameters
332+
if sig.Params().Len() != 2 {
333+
return false
334+
}
335+
336+
// Check if any parameter might be floating-point
337+
for param := range sig.Params().Variables() {
338+
if maybeNaN(param.Type()) {
339+
return false // Don't suggest removal for float types due to NaN handling
340+
}
341+
}
342+
343+
// Must have exactly one return value
344+
if sig.Results().Len() != 1 {
345+
return false
346+
}
347+
348+
// Check that the function body implements the expected min/max logic
349+
if body == nil {
350+
return false
351+
}
352+
353+
return hasMinMaxLogic(body, fn.Name())
354+
}
355+
356+
// hasMinMaxLogic checks if the function body implements simple min/max logic.
357+
func hasMinMaxLogic(body *ast.BlockStmt, funcName string) bool {
358+
// Pattern 1: Single if/else statement
359+
if len(body.List) == 1 {
360+
if ifStmt, ok := body.List[0].(*ast.IfStmt); ok {
361+
// Get the "false" result from the else block
362+
if elseBlock, ok := ifStmt.Else.(*ast.BlockStmt); ok && len(elseBlock.List) == 1 {
363+
if elseRet, ok := elseBlock.List[0].(*ast.ReturnStmt); ok && len(elseRet.Results) == 1 {
364+
return checkMinMaxPattern(ifStmt, elseRet.Results[0], funcName)
365+
}
366+
}
367+
}
368+
}
369+
370+
// Pattern 2: if statement followed by return
371+
if len(body.List) == 2 {
372+
if ifStmt, ok := body.List[0].(*ast.IfStmt); ok && ifStmt.Else == nil {
373+
if retStmt, ok := body.List[1].(*ast.ReturnStmt); ok && len(retStmt.Results) == 1 {
374+
return checkMinMaxPattern(ifStmt, retStmt.Results[0], funcName)
375+
}
376+
}
377+
}
378+
379+
return false
380+
}
381+
382+
// checkMinMaxPattern checks if an if statement implements min/max logic.
383+
// ifStmt: the if statement to check
384+
// falseResult: the expression returned when the condition is false
385+
// funcName: "min" or "max"
386+
func checkMinMaxPattern(ifStmt *ast.IfStmt, falseResult ast.Expr, funcName string) bool {
387+
// Must have condition with comparison
388+
cmp, ok := ifStmt.Cond.(*ast.BinaryExpr)
389+
if !ok {
390+
return false
391+
}
392+
393+
// Check if then branch returns one of the compared values
394+
if len(ifStmt.Body.List) != 1 {
395+
return false
396+
}
397+
398+
thenRet, ok := ifStmt.Body.List[0].(*ast.ReturnStmt)
399+
if !ok || len(thenRet.Results) != 1 {
400+
return false
401+
}
402+
403+
// Use the same logic as the existing minmax analyzer
404+
sign := isInequality(cmp.Op)
405+
if sign == 0 {
406+
return false // Not a comparison operator
407+
}
408+
409+
t := thenRet.Results[0] // "true" result
410+
f := falseResult // "false" result
411+
x := cmp.X // left operand
412+
y := cmp.Y // right operand
413+
414+
// Check operand order and adjust sign accordingly
415+
if equalSyntax(t, x) && equalSyntax(f, y) {
416+
sign = +sign
417+
} else if equalSyntax(t, y) && equalSyntax(f, x) {
418+
sign = -sign
419+
} else {
420+
return false
421+
}
422+
423+
// Check if the sign matches the function name
424+
return cond(sign < 0, "min", "max") == funcName
425+
}
426+
278427
// -- utils --
279428

280429
func is[T any](x any) bool {
@@ -289,3 +438,18 @@ func cond[T any](cond bool, t, f T) T {
289438
return f
290439
}
291440
}
441+
442+
// docComment returns the doc comment for a node, if any.
443+
func docComment(n ast.Node) *ast.CommentGroup {
444+
switch n := n.(type) {
445+
case *ast.FuncDecl:
446+
return n.Doc
447+
case *ast.GenDecl:
448+
return n.Doc
449+
case *ast.ValueSpec:
450+
return n.Doc
451+
case *ast.TypeSpec:
452+
return n.Doc
453+
}
454+
return nil // includes File, ImportSpec, Field
455+
}

go/analysis/passes/modernize/modernize_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func TestMapsLoop(t *testing.T) {
3636
}
3737

3838
func TestMinMax(t *testing.T) {
39-
RunWithSuggestedFixes(t, TestData(), modernize.MinMaxAnalyzer, "minmax")
39+
RunWithSuggestedFixes(t, TestData(), modernize.MinMaxAnalyzer, "minmax", "minmax/userdefined", "minmax/wrongoperators", "minmax/nonstrict", "minmax/wrongreturn")
4040
}
4141

4242
func TestOmitZero(t *testing.T) {
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package nonstrict
2+
3+
// min with <= operator - should be detected and removed
4+
func min(a, b int) int { // want "user-defined min function is equivalent to built-in min and can be removed"
5+
if a <= b {
6+
return a
7+
} else {
8+
return b
9+
}
10+
}
11+
12+
// max with >= operator - should be detected and removed
13+
func max(a, b int) int { // want "user-defined max function is equivalent to built-in max and can be removed"
14+
if a >= b {
15+
return a
16+
}
17+
return b
18+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
package nonstrict
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package userdefined
2+
3+
// User-defined min with float parameters - should NOT be removed due to NaN handling
4+
func minFloat(a, b float64) float64 {
5+
if a < b {
6+
return a
7+
} else {
8+
return b
9+
}
10+
}
11+
12+
// User-defined max with float parameters - should NOT be removed due to NaN handling
13+
func maxFloat(a, b float64) float64 {
14+
if a > b {
15+
return a
16+
} else {
17+
return b
18+
}
19+
}
20+
21+
// User-defined function with different name - should NOT be removed
22+
func minimum(a, b int) int {
23+
if a < b {
24+
return a
25+
} else {
26+
return b
27+
}
28+
}
29+
30+
// User-defined min with different logic - should NOT be removed
31+
func minDifferent(a, b int) int {
32+
return a + b // Completely different logic
33+
}
34+
35+
// Method on a type - should NOT be removed
36+
type MyType struct{}
37+
38+
func (m MyType) min(a, b int) int {
39+
if a < b {
40+
return a
41+
} else {
42+
return b
43+
}
44+
}
45+
46+
// Function with wrong signature - should NOT be removed
47+
func minWrongSig(a int) int {
48+
return a
49+
}
50+
51+
// Function with complex body that doesn't match pattern - should NOT be removed
52+
func minComplex(a, b int) int {
53+
println("choosing min")
54+
if a < b {
55+
return a
56+
} else {
57+
return b
58+
}
59+
}
60+
61+
// min returns the smaller of two values.
62+
func min(a, b int) int { // want "user-defined min function is equivalent to built-in min and can be removed"
63+
if a < b {
64+
return a
65+
} else {
66+
return b
67+
}
68+
}
69+
70+
// max returns the larger of two values.
71+
func max(a, b int) int { // want "user-defined max function is equivalent to built-in max and can be removed"
72+
if a > b {
73+
return a
74+
}
75+
return b
76+
}
77+
78+
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package userdefined
2+
3+
// User-defined min with float parameters - should NOT be removed due to NaN handling
4+
func minFloat(a, b float64) float64 {
5+
if a < b {
6+
return a
7+
} else {
8+
return b
9+
}
10+
}
11+
12+
// User-defined max with float parameters - should NOT be removed due to NaN handling
13+
func maxFloat(a, b float64) float64 {
14+
if a > b {
15+
return a
16+
} else {
17+
return b
18+
}
19+
}
20+
21+
// User-defined function with different name - should NOT be removed
22+
func minimum(a, b int) int {
23+
if a < b {
24+
return a
25+
} else {
26+
return b
27+
}
28+
}
29+
30+
// User-defined min with different logic - should NOT be removed
31+
func minDifferent(a, b int) int {
32+
return a + b // Completely different logic
33+
}
34+
35+
// Method on a type - should NOT be removed
36+
type MyType struct{}
37+
38+
func (m MyType) min(a, b int) int {
39+
if a < b {
40+
return a
41+
} else {
42+
return b
43+
}
44+
}
45+
46+
// Function with wrong signature - should NOT be removed
47+
func minWrongSig(a int) int {
48+
return a
49+
}
50+
51+
// Function with complex body that doesn't match pattern - should NOT be removed
52+
func minComplex(a, b int) int {
53+
println("choosing min")
54+
if a < b {
55+
return a
56+
} else {
57+
return b
58+
}
59+
}
60+
61+

0 commit comments

Comments
 (0)