@@ -14,6 +14,7 @@ import (
1414 "go/ast"
1515 "go/token"
1616 "go/types"
17+ "html/template"
1718 "os"
1819 "path/filepath"
1920 "strings"
@@ -26,6 +27,79 @@ import (
2627 "golang.org/x/tools/internal/typesinternal"
2728)
2829
30+ const testTmplString = `func {{.TestFuncName}}(t *testing.T) {
31+ {{- /* Functions/methods input parameters struct declaration. */}}
32+ {{- if gt (len .Args) 1}}
33+ type args struct {
34+ {{- range .Args}}
35+ {{.Name}} {{.Type}}
36+ {{- end}}
37+ }
38+ {{- end}}
39+ {{- /* Test cases struct declaration and empty initialization. */}}
40+ tests := []struct {
41+ name string // description of this test case
42+ {{- if gt (len .Args) 1}}
43+ args args
44+ {{- end}}
45+ {{- if eq (len .Args) 1}}
46+ arg {{(index .Args 0).Type}}
47+ {{- end}}
48+ {{- range $index, $res := .Results}}
49+ {{if eq $index 0}}want{{else}}want{{add $index 1}}{{end}} {{$res.Type}}
50+ {{- /* TODO(hxjiang): check whether the last return type is error and handle it using field "wantErr". */}}
51+ {{- end}}
52+ }{
53+ // TODO: Add test cases.
54+ }
55+ {{- /* Loop over all the test cases. */}}
56+ for _, tt := range tests {
57+ {{/* Got variables. */}}
58+ {{- if .Results}}{{fieldNames .Results ""}} := {{end}}
59+
60+ {{- /* Call expression. In xtest package test, call function by PACKAGE.FUNC. */}}
61+ {{- /* TODO(hxjiang): consider any renaming in existing xtest package imports. E.g. import renamedfoo "foo". */}}
62+ {{- /* TODO(hxjiang): support add test for methods by calling the right constructor. */}}
63+ {{- if .PackageName}}{{.PackageName}}.{{end}}{{.FuncName}}
64+
65+ {{- /* Input parameters. */ -}}
66+ ({{if eq (len .Args) 1}}tt.arg{{end}}{{if gt (len .Args) 1}}{{fieldNames .Args "tt.args."}}{{end}})
67+
68+ {{- if .Results}}
69+ // TODO: update the condition below to compare got with tt.want.
70+ {{- range $index, $res := .Results}}
71+ if true {
72+ t.Errorf("%s: {{$.FuncName}}() = %v, want %v", tt.name, {{.Name}}, tt.{{if eq $index 0}}want{{else}}want{{add $index 1}}{{end}})
73+ }
74+ {{- end}}
75+ {{- end}}
76+ }
77+ }
78+ `
79+
80+ type field struct {
81+ Name , Type string
82+ }
83+
84+ type testInfo struct {
85+ PackageName string
86+ FuncName string
87+ TestFuncName string
88+ Args []field
89+ Results []field
90+ }
91+
92+ var testTmpl = template .Must (template .New ("test" ).Funcs (template.FuncMap {
93+ "add" : func (a , b int ) int { return a + b },
94+ "fieldNames" : func (fields []field , qualifier string ) (res string ) {
95+ var names []string
96+ for _ , f := range fields {
97+ names = append (names , qualifier + f .Name )
98+ }
99+ return strings .Join (names , ", " )
100+ },
101+ }).Parse (testTmplString ))
102+
29103// AddTestForFunc adds a test for the function enclosing the given input range.
30104// It creates a _test.go file if one does not already exist.
31105func AddTestForFunc (ctx context.Context , snapshot * cache.Snapshot , loc protocol.Location ) (changes []protocol.DocumentChange , _ error ) {
@@ -138,6 +212,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
138212 }
139213
140214 fn := pkg .TypesInfo ().Defs [decl .Name ].(* types.Func )
215+ sig := fn .Signature ()
141216
142217 if xtest {
143218 // Reject if function/method is unexported.
@@ -146,30 +221,77 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
146221 }
147222
148223 // Reject if receiver is unexported.
149- if fn . Signature () .Recv () != nil {
224+ if sig .Recv () != nil {
150225 if _ , ident , _ := goplsastutil .UnpackRecv (decl .Recv .List [0 ].Type ); ! ident .IsExported () {
151226 return nil , fmt .Errorf ("cannot add external test for method %s.%s as receiver type is not exported" , ident .Name , decl .Name )
152227 }
153228 }
154-
155229 // TODO(hxjiang): reject if the any input parameter type is unexported.
156230 // TODO(hxjiang): reject if any return value type is unexported. Explore
157231 // the option to drop the return value if the type is unexported.
158232 }
159233
234+ // TODO(hxjiang): qualifier should consolidate existing imports from x
235+ // package and existing x_test package. The existing x_test package imports
236+ // should overwrite x package imports.
237+ var qf types.Qualifier
238+ if xtest {
239+ qf = (* types .Package ).Name
240+ } else {
241+ qf = typesinternal .NameRelativeTo (pkg .Types ())
242+ }
243+
160244 testName , err := testName (fn )
161245 if err != nil {
162246 return nil , err
163247 }
164- // TODO(hxjiang): replace test function with table-driven test.
248+ data := testInfo {
249+ FuncName : fn .Name (),
250+ TestFuncName : testName ,
251+ }
252+
253+ if sig .Recv () == nil && xtest {
254+ data .PackageName = pkg .Types ().Name ()
255+ }
256+
257+ for i := range sig .Params ().Len () {
258+ if i == 0 {
259+ data .Args = append (data .Args , field {
260+ Name : "in" ,
261+ Type : types .TypeString (sig .Params ().At (i ).Type (), qf ),
262+ })
263+ } else {
264+ data .Args = append (data .Args , field {
265+ Name : fmt .Sprintf ("in%d" , i + 1 ),
266+ Type : types .TypeString (sig .Params ().At (i ).Type (), qf ),
267+ })
268+ }
269+ }
270+
271+ for i := range sig .Results ().Len () {
272+ if i == 0 {
273+ data .Results = append (data .Results , field {
274+ Name : "got" ,
275+ Type : types .TypeString (sig .Results ().At (i ).Type (), qf ),
276+ })
277+ } else {
278+ data .Results = append (data .Results , field {
279+ Name : fmt .Sprintf ("got%d" , i + 1 ),
280+ Type : types .TypeString (sig .Results ().At (i ).Type (), qf ),
281+ })
282+ }
283+ }
284+
285+ var test bytes.Buffer
286+ if err := testTmpl .Execute (& test , data ); err != nil {
287+ return nil , err
288+ }
289+
165290 edits = append (edits , protocol.TextEdit {
166- Range : eofRange ,
167- NewText : fmt .Sprintf (`
168- func %s(*testing.T) {
169- // TODO: implement test
170- }
171- ` , testName ),
291+ Range : eofRange ,
292+ NewText : test .String (),
172293 })
294+
173295 return append (changes , protocol .DocumentChangeEdit (testFH , edits )), nil
174296}
175297
0 commit comments