@@ -17,19 +17,23 @@ import (
1717 "html/template"
1818 "os"
1919 "path/filepath"
20+ "sort"
2021 "strconv"
2122 "strings"
2223 "unicode"
2324
2425 "golang.org/x/tools/go/ast/astutil"
2526 "golang.org/x/tools/gopls/internal/cache"
27+ "golang.org/x/tools/gopls/internal/cache/metadata"
2628 "golang.org/x/tools/gopls/internal/cache/parsego"
2729 "golang.org/x/tools/gopls/internal/protocol"
2830 goplsastutil "golang.org/x/tools/gopls/internal/util/astutil"
31+ "golang.org/x/tools/internal/imports"
2932 "golang.org/x/tools/internal/typesinternal"
3033)
3134
32- const testTmplString = `func {{.TestFuncName}}(t *testing.T) {
35+ const testTmplString = `
36+ func {{.TestFuncName}}(t *{{.TestingPackageName}}.T) {
3337 {{- /* Constructor input parameters struct declaration. */}}
3438 {{- if and .Receiver .Receiver.Constructor}}
3539 {{- if gt (len .Receiver.Constructor.Args) 1}}
@@ -83,7 +87,7 @@ const testTmplString = `func {{.TestFuncName}}(t *testing.T) {
8387
8488 {{- /* Loop over all the test cases. */}}
8589 for _, tt := range tests {
86- t.Run(tt.name, func(t *testing .T) {
90+ t.Run(tt.name, func(t *{{.TestingPackageName}} .T) {
8791 {{- /* Constructor or empty initialization. */}}
8892 {{- if .Receiver}}
8993 {{- if .Receiver.Constructor}}
@@ -170,6 +174,10 @@ type receiver struct {
170174}
171175
172176type testInfo struct {
177+ // TestingPackageName is the package name should be used when referencing
178+ // package "testing"
179+ TestingPackageName string
180+ // PackageName is the package name the target function/method is delcared from.
173181 PackageName string
174182 TestFuncName string
175183 // Func holds information about the function or method being tested.
@@ -202,37 +210,79 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
202210 return nil , err
203211 }
204212
213+ if metadata .IsCommandLineArguments (pkg .Metadata ().ID ) {
214+ return nil , fmt .Errorf ("current file in command-line-arguments package" )
215+ }
216+
205217 if errors := pkg .ParseErrors (); len (errors ) > 0 {
206218 return nil , fmt .Errorf ("package has parse errors: %v" , errors [0 ])
207219 }
208220 if errors := pkg .TypeErrors (); len (errors ) > 0 {
209221 return nil , fmt .Errorf ("package has type errors: %v" , errors [0 ])
210222 }
211223
212- // imports is a map from package path to local package name.
213- var imports = make (map [string ]string )
224+ type packageInfo struct {
225+ name string
226+ renamed bool
227+ }
228+
229+ var (
230+ // fileImports is a map contains all the path imported in the original
231+ // file foo.go.
232+ fileImports map [string ]packageInfo
233+ // testImports is a map contains all the path already imported in test
234+ // file foo_test.go.
235+ testImports map [string ]packageInfo
236+ // extraImportsis a map from package path to local package name that
237+ // need to be imported for the test function.
238+ extraImports = make (map [string ]packageInfo )
239+ )
214240
215- var collectImports = func (file * ast.File ) error {
241+ var collectImports = func (file * ast.File ) (map [string ]packageInfo , error ) {
242+ imps := make (map [string ]packageInfo )
216243 for _ , spec := range file .Imports {
217244 // TODO(hxjiang): support dot imports.
218245 if spec .Name != nil && spec .Name .Name == "." {
219- return fmt .Errorf ("\" add a test for FUNC \" does not support files containing dot imports" )
246+ return nil , fmt .Errorf ("\" add a test for func \" does not support files containing dot imports" )
220247 }
221248 path , err := strconv .Unquote (spec .Path .Value )
222249 if err != nil {
223- return err
250+ return nil , err
224251 }
225- if spec .Name != nil && spec .Name .Name != "_" {
226- imports [path ] = spec .Name .Name
252+ if spec .Name != nil {
253+ if spec .Name .Name == "_" {
254+ continue
255+ }
256+ imps [path ] = packageInfo {spec .Name .Name , true }
227257 } else {
228- imports [path ] = filepath .Base (path )
258+ // The package name might differ from the base of its import
259+ // path. For example, "/path/to/package/foo" could declare a
260+ // package named "bar". Look up the target package ensures the
261+ // accurate package name reference.
262+ //
263+ // While it's best practice to rename imported packages when
264+ // their name differs from the base path (e.g.,
265+ // "import bar \"path/to/package/foo\""), this is not mandatory.
266+ id := pkg .Metadata ().DepsByImpPath [metadata .ImportPath (path )]
267+ if metadata .IsCommandLineArguments (id ) {
268+ return nil , fmt .Errorf ("can not import command-line-arguments package" )
269+ }
270+ if id == "" { // guess upon missing.
271+ imps [path ] = packageInfo {imports .ImportPathToAssumedName (path ), false }
272+ } else {
273+ fromPkg , ok := snapshot .MetadataGraph ().Packages [id ]
274+ if ! ok {
275+ return nil , fmt .Errorf ("package id %v does not exist" , id )
276+ }
277+ imps [path ] = packageInfo {string (fromPkg .Name ), false }
278+ }
229279 }
230280 }
231- return nil
281+ return imps , nil
232282 }
233283
234284 // Collect all the imports from the x.go, keep track of the local package name.
235- if err : = collectImports (pgf .File ); err != nil {
285+ if fileImports , err = collectImports (pgf .File ); err != nil {
236286 return nil , err
237287 }
238288
@@ -259,7 +309,8 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
259309 xtest = true
260310 )
261311
262- if testPGF , err := snapshot .ParseGo (ctx , testFH , parsego .Header ); err != nil {
312+ testPGF , err := snapshot .ParseGo (ctx , testFH , parsego .Header )
313+ if err != nil {
263314 if ! errors .Is (err , os .ErrNotExist ) {
264315 return nil , err
265316 }
@@ -288,8 +339,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
288339 header .WriteString ("\n \n " )
289340 }
290341 }
291- // One empty line between package decl and rest of the file.
292- fmt .Fprintf (& header , "package %s_test\n \n " , pkg .Types ().Name ())
342+ fmt .Fprintf (& header , "package %s_test\n " , pkg .Types ().Name ())
293343
294344 // Write the copyright and package decl to the beginning of the file.
295345 edits = append (edits , protocol.TextEdit {
@@ -314,29 +364,41 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
314364 return nil , err
315365 }
316366
317- // Collect all the imports from the x_test.go, overwrite the local pakcage
318- // name collected from x.go.
319- if err := collectImports (testPGF .File ); err != nil {
367+ // Collect all the imports from the foo_test.go.
368+ if testImports , err = collectImports (testPGF .File ); err != nil {
320369 return nil , err
321370 }
322371 }
323372
324- // qf qualifier returns the local package name need to use in x_test.go by
325- // consulting the consolidated imports map.
373+ // qf qualifier determines the correct package name to use for a type in
374+ // foo_test.go. It does this by:
375+ // - Consult imports map from test file foo_test.go.
376+ // - If not found, consult imports map from original file foo.go.
377+ // If the package is not imported in test file foo_test.go, it is added to
378+ // extraImports map.
326379 qf := func (p * types.Package ) string {
327380 // When generating test in x packages, any type/function defined in the same
328381 // x package can emit package name.
329382 if ! xtest && p == pkg .Types () {
330383 return ""
331384 }
332- if local , ok := imports [p .Path ()]; ok {
333- return local
385+ // Prefer using the package name if already defined in foo_test.go
386+ if local , ok := testImports [p .Path ()]; ok {
387+ return local .name
334388 }
389+ // TODO(hxjiang): we should consult the scope of the test package to
390+ // ensure these new imports do not shadow any package-level names.
391+ // If not already imported by foo_test.go, consult the foo.go import map.
392+ if local , ok := fileImports [p .Path ()]; ok {
393+ // The package that contains this type need to be added to the import
394+ // list in foo_test.go.
395+ extraImports [p .Path ()] = local
396+ return local .name
397+ }
398+ extraImports [p .Path ()] = packageInfo {name : p .Name ()}
335399 return p .Name ()
336400 }
337401
338- // TODO(hxjiang): modify existing imports or add new imports.
339-
340402 start , end , err := pgf .RangePos (loc .Range )
341403 if err != nil {
342404 return nil , err
@@ -378,8 +440,9 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
378440 }
379441
380442 data := testInfo {
381- PackageName : qf (pkg .Types ()),
382- TestFuncName : testName ,
443+ TestingPackageName : qf (types .NewPackage ("testing" , "testing" )),
444+ PackageName : qf (pkg .Types ()),
445+ TestFuncName : testName ,
383446 Func : function {
384447 Name : fn .Name (),
385448 },
@@ -557,15 +620,73 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
557620 }
558621 }
559622
623+ // Compute edits to update imports.
624+ //
625+ // If we're adding to an existing test file, we need to adjust existing
626+ // imports. Otherwise, we can simply write out the imports to the new file.
627+ if testPGF != nil {
628+ var importFixes []* imports.ImportFix
629+ for path , info := range extraImports {
630+ name := ""
631+ if info .renamed {
632+ name = info .name
633+ }
634+ importFixes = append (importFixes , & imports.ImportFix {
635+ StmtInfo : imports.ImportInfo {
636+ ImportPath : path ,
637+ Name : name ,
638+ },
639+ FixType : imports .AddImport ,
640+ })
641+ }
642+ importEdits , err := ComputeImportFixEdits (snapshot .Options ().Local , testPGF .Src , importFixes ... )
643+ if err != nil {
644+ return nil , fmt .Errorf ("could not compute the import fix edits: %w" , err )
645+ }
646+ edits = append (edits , importEdits ... )
647+ } else {
648+ var importsBuffer bytes.Buffer
649+ if len (extraImports ) == 1 {
650+ importsBuffer .WriteString ("\n import " )
651+ for path , info := range extraImports {
652+ if info .renamed {
653+ importsBuffer .WriteString (info .name + " " )
654+ }
655+ importsBuffer .WriteString (fmt .Sprintf ("\" %s\" \n " , path ))
656+ }
657+ } else {
658+ importsBuffer .WriteString ("\n import(" )
659+ // Loop over the map in sorted order ensures deterministic outcome.
660+ paths := make ([]string , 0 , len (extraImports ))
661+ for key := range extraImports {
662+ paths = append (paths , key )
663+ }
664+ sort .Strings (paths )
665+ for _ , path := range paths {
666+ importsBuffer .WriteString ("\n \t " )
667+ if extraImports [path ].renamed {
668+ importsBuffer .WriteString (extraImports [path ].name + " " )
669+ }
670+ importsBuffer .WriteString (fmt .Sprintf ("\" %s\" " , path ))
671+ }
672+ importsBuffer .WriteString ("\n )\n " )
673+ }
674+ edits = append (edits , protocol.TextEdit {
675+ Range : protocol.Range {},
676+ NewText : importsBuffer .String (),
677+ })
678+ }
679+
560680 var test bytes.Buffer
561681 if err := testTmpl .Execute (& test , data ); err != nil {
562682 return nil , err
563683 }
564684
565- edits = append (edits , protocol.TextEdit {
566- Range : eofRange ,
567- NewText : test .String (),
568- })
685+ edits = append (edits ,
686+ protocol.TextEdit {
687+ Range : eofRange ,
688+ NewText : test .String (),
689+ })
569690
570691 return append (changes , protocol .DocumentChangeEdit (testFH , edits )), nil
571692}
0 commit comments