@@ -11,14 +11,19 @@ import (
1111 "context"
1212 "errors"
1313 "fmt"
14+ "go/ast"
1415 "go/token"
16+ "go/types"
1517 "os"
1618 "path/filepath"
1719 "strings"
1820
21+ "golang.org/x/tools/go/ast/astutil"
1922 "golang.org/x/tools/gopls/internal/cache"
2023 "golang.org/x/tools/gopls/internal/cache/parsego"
2124 "golang.org/x/tools/gopls/internal/protocol"
25+ goplsastutil "golang.org/x/tools/gopls/internal/util/astutil"
26+ "golang.org/x/tools/internal/typesinternal"
2227)
2328
2429// AddTestForFunc adds a test for the function enclosing the given input range.
@@ -29,6 +34,13 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
2934 return nil , err
3035 }
3136
37+ if errors := pkg .ParseErrors (); len (errors ) > 0 {
38+ return nil , fmt .Errorf ("package has parse errors: %v" , errors [0 ])
39+ }
40+ if errors := pkg .TypeErrors (); len (errors ) > 0 {
41+ return nil , fmt .Errorf ("package has type errors: %v" , errors [0 ])
42+ }
43+
3244 testBase := strings .TrimSuffix (filepath .Base (loc .URI .Path ()), ".go" ) + "_test.go"
3345 goTestFileURI := protocol .URIFromPath (filepath .Join (loc .URI .Dir ().Path (), testBase ))
3446
@@ -41,32 +53,37 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
4153 // exist.
4254
4355 var (
56+ eofRange protocol.Range // empty selection at end of new file
4457 // edits contains all the text edits to be applied to the test file.
4558 edits []protocol.TextEdit
46- // header is the buffer containing the text edit to the beginning of the file.
47- header bytes.Buffer
59+ // xtest indicates whether the test file use package x or x_test.
60+ // TODO(hxjiang): For now, we try to interpret the user's intention by
61+ // reading the foo_test.go's package name. Instead, we can discuss the option
62+ // to interpret the user's intention by which function they are selecting.
63+ // Have one file for x_test package testing, one file for x package testing.
64+ xtest = true
4865 )
4966
50- testPgf , err := snapshot .ParseGo (ctx , testFH , parsego .Header )
51- if err != nil {
67+ if testPGF , err := snapshot .ParseGo (ctx , testFH , parsego .Header ); err != nil {
5268 if ! errors .Is (err , os .ErrNotExist ) {
5369 return nil , err
5470 }
55-
5671 changes = append (changes , protocol .DocumentChangeCreate (goTestFileURI ))
5772
58- // If this test file was created by the gopls, add a copyright header based
59- // on the originating file.
73+ // header is the buffer containing the text to add to the beginning of the file.
74+ var header bytes.Buffer
75+
76+ // If this test file was created by the gopls, add a copyright header and
77+ // package decl based on the originating file.
6078 // Search for something that looks like a copyright header, to replicate
6179 // in the new file.
62- // TODO(hxjiang): should we refine this heuristic, for example by checking for
63- // the word 'copyright'?
6480 if groups := pgf .File .Comments ; len (groups ) > 0 {
6581 // Copyright should appear before package decl and must be the first
6682 // comment group.
6783 // Avoid copying any other comment like package doc or directive comment.
6884 if c := groups [0 ]; c .Pos () < pgf .File .Package && c != pgf .File .Doc &&
69- ! isDirective (groups [0 ].List [0 ].Text ) {
85+ ! isDirective (c .List [0 ].Text ) &&
86+ strings .Contains (strings .ToLower (c .List [0 ].Text ), "copyright" ) {
7087 start , end , err := pgf .NodeOffsets (c )
7188 if err != nil {
7289 return nil , err
@@ -76,42 +93,117 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
7693 header .WriteString ("\n \n " )
7794 }
7895 }
79- }
80-
81- // If the test file does not have package decl, use the originating file to
82- // determine a package decl for the new file. Prefer xtest package.s
83- if testPgf == nil || testPgf .File == nil || testPgf .File .Package == token .NoPos {
8496 // One empty line between package decl and rest of the file.
8597 fmt .Fprintf (& header , "package %s_test\n \n " , pkg .Types ().Name ())
86- }
8798
88- // Write the copyright and package decl to the beginning of the file.
89- if text := header .String (); len (text ) != 0 {
99+ // Write the copyright and package decl to the beginning of the file.
90100 edits = append (edits , protocol.TextEdit {
91101 Range : protocol.Range {},
92- NewText : text ,
102+ NewText : header . String () ,
93103 })
104+ } else { // existing _test.go file.
105+ if testPGF .File .Name == nil || testPGF .File .Name .NamePos == token .NoPos {
106+ return nil , fmt .Errorf ("missing package declaration" )
107+ }
108+ switch testPGF .File .Name .Name {
109+ case pgf .File .Name .Name :
110+ xtest = false
111+ case pgf .File .Name .Name + "_test" :
112+ xtest = true
113+ default :
114+ return nil , fmt .Errorf ("invalid package declaration %q in test file %q" , testPGF .File .Name , testPGF )
115+ }
116+
117+ eofRange , err = testPGF .PosRange (testPGF .File .FileEnd , testPGF .File .FileEnd )
118+ if err != nil {
119+ return nil , err
120+ }
94121 }
95122
96- // TODO(hxjiang): reject if the function/method is unexported.
97123 // TODO(hxjiang): modify existing imports or add new imports.
98124
99- // If the parse go file is missing, the fileEnd is the file start (zero value).
100- fileEnd := protocol.Range {}
101- if testPgf != nil {
102- fileEnd , err = testPgf .PosRange (testPgf .File .FileEnd , testPgf .File .FileEnd )
103- if err != nil {
104- return nil , err
125+ start , end , err := pgf .RangePos (loc .Range )
126+ if err != nil {
127+ return nil , err
128+ }
129+
130+ path , _ := astutil .PathEnclosingInterval (pgf .File , start , end )
131+ if len (path ) < 2 {
132+ return nil , fmt .Errorf ("no enclosing function" )
133+ }
134+
135+ decl , ok := path [len (path )- 2 ].(* ast.FuncDecl )
136+ if ! ok {
137+ return nil , fmt .Errorf ("no enclosing function" )
138+ }
139+
140+ fn := pkg .TypesInfo ().Defs [decl .Name ].(* types.Func )
141+
142+ if xtest {
143+ // Reject if function/method is unexported.
144+ if ! fn .Exported () {
145+ return nil , fmt .Errorf ("cannot add test of unexported function %s to external test package %s_test" , decl .Name , pgf .File .Name )
105146 }
147+
148+ // Reject if receiver is unexported.
149+ if fn .Signature ().Recv () != nil {
150+ if _ , ident , _ := goplsastutil .UnpackRecv (decl .Recv .List [0 ].Type ); ! ident .IsExported () {
151+ return nil , fmt .Errorf ("cannot add external test for method %s.%s as receiver type is not exported" , ident .Name , decl .Name )
152+ }
153+ }
154+
155+ // TODO(hxjiang): reject if the any input parameter type is unexported.
156+ // TODO(hxjiang): reject if any return value type is unexported. Explore
157+ // the option to drop the return value if the type is unexported.
106158 }
107159
108- // test is the buffer containing the text edit to the test function.
109- var test bytes.Buffer
110- // TODO(hxjiang): replace test foo function with table-driven test.
111- test .WriteString ("\n func TestFoo(*testing.T) {}" )
160+ testName , err := testName (fn )
161+ if err != nil {
162+ return nil , err
163+ }
164+ // TODO(hxjiang): replace test function with table-driven test.
112165 edits = append (edits , protocol.TextEdit {
113- Range : fileEnd ,
114- NewText : test .String (),
166+ Range : eofRange ,
167+ NewText : fmt .Sprintf (`
168+ func %s(*testing.T) {
169+ // TODO: implement test
170+ }
171+ ` , testName ),
115172 })
116173 return append (changes , protocol .DocumentChangeEdit (testFH , edits )), nil
117174}
175+
176+ // testName returns the name of the function to use for the new function that
177+ // tests fn.
178+ // Returns empty string if the fn is ill typed or nil.
179+ func testName (fn * types.Func ) (string , error ) {
180+ if fn == nil {
181+ return "" , fmt .Errorf ("input nil function" )
182+ }
183+ testName := "Test"
184+ if recv := fn .Signature ().Recv (); recv != nil { // method declaration.
185+ // Retrieve the unpointered receiver type to ensure the test name is based
186+ // on the topmost alias or named type, not the alias' RHS type (potentially
187+ // unexported) type.
188+ // For example:
189+ // type Foo = foo // Foo is an exported alias for the unexported type foo
190+ recvType := recv .Type ()
191+ if ptr , ok := recv .Type ().(* types.Pointer ); ok {
192+ recvType = ptr .Elem ()
193+ }
194+
195+ t , ok := recvType .(typesinternal.NamedOrAlias )
196+ if ! ok {
197+ return "" , fmt .Errorf ("receiver type is not named type or alias type" )
198+ }
199+
200+ if ! t .Obj ().Exported () {
201+ testName += "_"
202+ }
203+
204+ testName += t .Obj ().Name () + "_"
205+ } else if ! fn .Exported () { // unexported function declaration.
206+ testName += "_"
207+ }
208+ return testName + fn .Name (), nil
209+ }
0 commit comments