@@ -14,12 +14,12 @@ import (
1414 "go/ast"
1515 "go/token"
1616 "go/types"
17- "html/template"
1817 "os"
1918 "path/filepath"
2019 "sort"
2120 "strconv"
2221 "strings"
22+ "text/template"
2323 "unicode"
2424
2525 "golang.org/x/tools/go/ast/astutil"
@@ -34,44 +34,34 @@ import (
3434
3535const testTmplString = `
3636func {{.TestFuncName}}(t *{{.TestingPackageName}}.T) {
37- {{- /* Constructor input parameters struct declaration. */}}
38- {{- if and .Receiver .Receiver.Constructor}}
39- {{- if gt (len .Receiver.Constructor.Args) 1}}
40- type constructorArgs struct {
41- {{- range .Receiver.Constructor.Args}}
42- {{.Name}} {{.Type}}
43- {{- end}}
44- }
45- {{- end}}
46- {{- end}}
47-
48- {{- /* Functions/methods input parameters struct declaration. */}}
49- {{- if gt (len .Func.Args) 1}}
50- type args struct {
51- {{- range .Func.Args}}
52- {{.Name}} {{.Type}}
53- {{- end}}
54- }
55- {{- end}}
56-
5737 {{- /* Test cases struct declaration and empty initialization. */}}
5838 tests := []struct {
5939 name string // description of this test case
40+
41+ {{- $commentPrinted := false }}
6042 {{- if and .Receiver .Receiver.Constructor}}
61- {{- if gt (len .Receiver.Constructor.Args) 1}}
62- constructorArgs constructorArgs
43+ {{- range .Receiver.Constructor.Args}}
44+ {{- if .Name}}
45+ {{- if not $commentPrinted}}
46+ // Named input parameters for receiver constructor.
47+ {{- $commentPrinted = true }}
48+ {{- end}}
49+ {{.Name}} {{.Type}}
6350 {{- end}}
64- {{- if eq (len .Receiver.Constructor.Args) 1}}
65- constructorArg {{(index .Receiver.Constructor.Args 0).Type}}
6651 {{- end}}
6752 {{- end}}
6853
69- {{- if gt (len .Func.Args) 1}}
70- args args
54+ {{- $commentPrinted := false }}
55+ {{- range .Func.Args}}
56+ {{- if .Name}}
57+ {{- if not $commentPrinted}}
58+ // Named input parameters for target function.
59+ {{- $commentPrinted = true }}
60+ {{- end}}
61+ {{.Name}} {{.Type}}
7162 {{- end}}
72- {{- if eq (len .Func.Args) 1}}
73- arg {{(index .Func.Args 0).Type}}
7463 {{- end}}
64+
7565 {{- range $index, $res := .Func.Results}}
7666 {{- if eq $res.Name "gotErr"}}
7767 wantErr bool
@@ -96,7 +86,12 @@ func {{.TestFuncName}}(t *{{.TestingPackageName}}.T) {
9686 {{- .Receiver.Constructor.Name}}
9787
9888 {{- /* Constructor input parameters. */ -}}
99- ({{- if eq (len .Receiver.Constructor.Args) 1}}tt.constructorArg{{end}}{{if gt (len .Func.Args) 1}}{{fieldNames .Receiver.Constructor.Args "tt.constructorArgs."}}{{end}})
89+ (
90+ {{- range $index, $arg := .Receiver.Constructor.Args}}
91+ {{- if ne $index 0}}, {{end}}
92+ {{- if .Name}}tt.{{.Name}}{{else}}{{.Value}}{{end}}
93+ {{- end -}}
94+ )
10095
10196 {{- /* Handles the error return from constructor. */}}
10297 {{- $last := last .Receiver.Constructor.Results}}
@@ -123,7 +118,12 @@ func {{.TestFuncName}}(t *{{.TestingPackageName}}.T) {
123118 {{- end}}{{.Func.Name}}
124119
125120 {{- /* Input parameters. */ -}}
126- ({{- if eq (len .Func.Args) 1}}tt.arg{{end}}{{if gt (len .Func.Args) 1}}{{fieldNames .Func.Args "tt.args."}}{{end}})
121+ (
122+ {{- range $index, $arg := .Func.Args}}
123+ {{- if ne $index 0}}, {{end}}
124+ {{- if .Name}}tt.{{.Name}}{{else}}{{.Value}}{{end}}
125+ {{- end -}}
126+ )
127127
128128 {{- /* Handles the returned error before the rest of return value. */}}
129129 {{- $last := last .Func.Results}}
@@ -155,8 +155,12 @@ func {{.TestFuncName}}(t *{{.TestingPackageName}}.T) {
155155}
156156`
157157
158+ // Name is the name of the field this input parameter should reference.
159+ // Value is the expression this input parameter should accept.
160+ //
161+ // Exactly one of Name or Value must be set.
158162type field struct {
159- Name , Type string
163+ Name , Type , Value string
160164}
161165
162166type function struct {
@@ -191,6 +195,9 @@ type testInfo struct {
191195var testTmpl = template .Must (template .New ("test" ).Funcs (template.FuncMap {
192196 "add" : func (a , b int ) int { return a + b },
193197 "last" : func (slice []field ) field {
198+ if len (slice ) == 0 {
199+ return field {}
200+ }
194201 return slice [len (slice )- 1 ]
195202 },
196203 "fieldNames" : func (fields []field , qualifier string ) (res string ) {
@@ -450,36 +457,32 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
450457
451458 errorType := types .Universe .Lookup ("error" ).Type ()
452459
453- // TODO(hxjiang): if input parameter is not named (meaning it's not used),
454- // pass the zero value to the function call.
455- // TODO(hxjiang): if the input parameter is named, define the field by using
456- // the parameter's name instead of in%d.
457460 // TODO(hxjiang): handle special case for ctx.Context input.
458- for index := range sig .Params ().Len () {
459- var name string
460- if index == 0 {
461- name = "in"
461+ for i := range sig .Params ().Len () {
462+ param := sig .Params ().At (i )
463+ name , typ := param .Name (), param .Type ()
464+ f := field {Type : types .TypeString (typ , qf )}
465+ if name == "" || name == "_" {
466+ f .Value = typesinternal .ZeroString (typ , qf )
462467 } else {
463- name = fmt . Sprintf ( "in%d" , index + 1 )
468+ f . Name = name
464469 }
465- data .Func .Args = append (data .Func .Args , field {
466- Name : name ,
467- Type : types .TypeString (sig .Params ().At (index ).Type (), qf ),
468- })
470+ data .Func .Args = append (data .Func .Args , f )
469471 }
470472
471- for index := range sig .Results ().Len () {
473+ for i := range sig .Results ().Len () {
474+ typ := sig .Results ().At (i ).Type ()
472475 var name string
473- if index == sig .Results ().Len ()- 1 && types .Identical (sig . Results (). At ( index ). Type () , errorType ) {
476+ if i == sig .Results ().Len ()- 1 && types .Identical (typ , errorType ) {
474477 name = "gotErr"
475- } else if index == 0 {
478+ } else if i == 0 {
476479 name = "got"
477480 } else {
478- name = fmt .Sprintf ("got%d" , index + 1 )
481+ name = fmt .Sprintf ("got%d" , i + 1 )
479482 }
480483 data .Func .Results = append (data .Func .Results , field {
481484 Name : name ,
482- Type : types .TypeString (sig . Results (). At ( index ). Type () , qf ),
485+ Type : types .TypeString (typ , qf ),
483486 })
484487 }
485488
@@ -587,25 +590,25 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
587590
588591 if constructor != nil {
589592 data .Receiver .Constructor = & function {Name : constructor .Name ()}
590- for index := range constructor .Signature ().Params ().Len () {
591- var name string
592- if index == 0 {
593- name = "in"
593+ for i := range constructor .Signature ().Params ().Len () {
594+ param := constructor .Signature ().Params ().At (i )
595+ name , typ := param .Name (), param .Type ()
596+ f := field {Type : types .TypeString (typ , qf )}
597+ if name == "" || name == "_" {
598+ f .Value = typesinternal .ZeroString (typ , qf )
594599 } else {
595- name = fmt . Sprintf ( "in%d" , index + 1 )
600+ f . Name = name
596601 }
597- data .Receiver .Constructor .Args = append (data .Receiver .Constructor .Args , field {
598- Name : name ,
599- Type : types .TypeString (constructor .Signature ().Params ().At (index ).Type (), qf ),
600- })
602+ data .Receiver .Constructor .Args = append (data .Receiver .Constructor .Args , f )
601603 }
602- for index := range constructor .Signature ().Results ().Len () {
604+ for i := range constructor .Signature ().Results ().Len () {
605+ typ := constructor .Signature ().Results ().At (i ).Type ()
603606 var name string
604- if index == 0 {
607+ if i == 0 {
605608 // The first return value must be of type T, *T, or a type whose named
606609 // type is the same as named type of T.
607610 name = varName
608- } else if index == constructor .Signature ().Results ().Len ()- 1 && types .Identical (constructor . Signature (). Results (). At ( index ). Type () , errorType ) {
611+ } else if i == constructor .Signature ().Results ().Len ()- 1 && types .Identical (typ , errorType ) {
609612 name = "err"
610613 } else {
611614 // Drop any return values beyond the first and the last.
@@ -614,12 +617,48 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
614617 }
615618 data .Receiver .Constructor .Results = append (data .Receiver .Constructor .Results , field {
616619 Name : name ,
617- Type : types .TypeString (constructor . Signature (). Results (). At ( index ). Type () , qf ),
620+ Type : types .TypeString (typ , qf ),
618621 })
619622 }
620623 }
621624 }
622625
626+ // Resolves duplicate parameter names between the function and its
627+ // receiver's constructor. It adds prefix to the constructor's parameters
628+ // until no conflicts remain.
629+ if data .Receiver != nil && data .Receiver .Constructor != nil {
630+ seen := map [string ]bool {}
631+ for _ , f := range data .Func .Args {
632+ if f .Name == "" {
633+ continue
634+ }
635+ seen [f .Name ] = true
636+ }
637+
638+ // "" for no change, "c" for constructor, "i" for input.
639+ for _ , prefix := range []string {"" , "c" , "c_" , "i" , "i_" } {
640+ conflict := false
641+ for _ , f := range data .Receiver .Constructor .Args {
642+ if f .Name == "" {
643+ continue
644+ }
645+ if seen [prefix + f .Name ] {
646+ conflict = true
647+ break
648+ }
649+ }
650+ if ! conflict {
651+ for i , f := range data .Receiver .Constructor .Args {
652+ if f .Name == "" {
653+ continue
654+ }
655+ data .Receiver .Constructor .Args [i ].Name = prefix + data .Receiver .Constructor .Args [i ].Name
656+ }
657+ break
658+ }
659+ }
660+ }
661+
623662 // Compute edits to update imports.
624663 //
625664 // If we're adding to an existing test file, we need to adjust existing
0 commit comments