@@ -229,25 +229,16 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
229229 return nil , fmt .Errorf ("package has type errors: %v" , errors [0 ])
230230 }
231231
232- type packageInfo struct {
233- name string
234- renamed bool
235- }
236-
232+ // All three maps map the path of an imported package to
233+ // the local name if explicit or "" otherwise.
237234 var (
238- // fileImports is a map contains all the path imported in the original
239- // file foo.go.
240- fileImports map [string ]packageInfo
241- // testImports is a map contains all the path already imported in test
242- // file foo_test.go.
243- testImports map [string ]packageInfo
244- // extraImportsis a map from package path to local package name that
245- // need to be imported for the test function.
246- extraImports = make (map [string ]packageInfo )
235+ fileImports map [string ]string // imports in foo.go file
236+ testImports map [string ]string // imports in foo_test.go file
237+ extraImports = make (map [string ]string ) // imports to add to test file
247238 )
248239
249- var collectImports = func (file * ast.File ) (map [string ]packageInfo , error ) {
250- imps := make (map [string ]packageInfo )
240+ var collectImports = func (file * ast.File ) (map [string ]string , error ) {
241+ imps := make (map [string ]string )
251242 for _ , spec := range file .Imports {
252243 // TODO(hxjiang): support dot imports.
253244 if spec .Name != nil && spec .Name .Name == "." {
@@ -261,29 +252,9 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
261252 if spec .Name .Name == "_" {
262253 continue
263254 }
264- imps [path ] = packageInfo { spec .Name .Name , true }
255+ imps [path ] = spec .Name .Name
265256 } else {
266- // The package name might differ from the base of its import
267- // path. For example, "/path/to/package/foo" could declare a
268- // package named "bar". Look up the target package ensures the
269- // accurate package name reference.
270- //
271- // While it's best practice to rename imported packages when
272- // their name differs from the base path (e.g.,
273- // "import bar \"path/to/package/foo\""), this is not mandatory.
274- id := pkg .Metadata ().DepsByImpPath [metadata .ImportPath (path )]
275- if metadata .IsCommandLineArguments (id ) {
276- return nil , fmt .Errorf ("can not import command-line-arguments package" )
277- }
278- if id == "" { // guess upon missing.
279- imps [path ] = packageInfo {imports .ImportPathToAssumedName (path ), false }
280- } else {
281- fromPkg , ok := snapshot .MetadataGraph ().Packages [id ]
282- if ! ok {
283- return nil , fmt .Errorf ("package id %v does not exist" , id )
284- }
285- imps [path ] = packageInfo {string (fromPkg .Name ), false }
286- }
257+ imps [path ] = ""
287258 }
288259 }
289260 return imps , nil
@@ -454,25 +425,27 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
454425 // If the package is not imported in test file foo_test.go, it is added to
455426 // extraImports map.
456427 qf := func (p * types.Package ) string {
457- // When generating test in x packages, any type/function defined in the same
458- // x package can emit package name.
428+ // References from an in-package test should not be qualified.
459429 if ! xtest && p == pkg .Types () {
460430 return ""
461431 }
462432 // Prefer using the package name if already defined in foo_test.go
463433 if local , ok := testImports [p .Path ()]; ok {
464- return local .name
434+ if local != "" {
435+ return local
436+ } else {
437+ return p .Name ()
438+ }
465439 }
466440 // TODO(hxjiang): we should consult the scope of the test package to
467441 // ensure these new imports do not shadow any package-level names.
468- // If not already imported by foo_test.go, consult the foo.go import map.
469- if local , ok := fileImports [p .Path ()]; ok {
470- // The package that contains this type need to be added to the import
471- // list in foo_test.go.
442+ // Prefer the local import name (if any) used in the package under test.
443+ if local , ok := fileImports [p .Path ()]; ok && local != "" {
472444 extraImports [p .Path ()] = local
473- return local . name
445+ return local
474446 }
475- extraImports [p .Path ()] = packageInfo {name : p .Name ()}
447+ // Fall back to the package name since there is no renaming.
448+ extraImports [p .Path ()] = ""
476449 return p .Name ()
477450 }
478451
@@ -728,11 +701,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
728701 // imports. Otherwise, we can simply write out the imports to the new file.
729702 if testPGF != nil {
730703 var importFixes []* imports.ImportFix
731- for path , info := range extraImports {
732- name := ""
733- if info .renamed {
734- name = info .name
735- }
704+ for path , name := range extraImports {
736705 importFixes = append (importFixes , & imports.ImportFix {
737706 StmtInfo : imports.ImportInfo {
738707 ImportPath : path ,
@@ -750,9 +719,9 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
750719 var importsBuffer bytes.Buffer
751720 if len (extraImports ) == 1 {
752721 importsBuffer .WriteString ("\n import " )
753- for path , info := range extraImports {
754- if info . renamed {
755- importsBuffer .WriteString (info . name + " " )
722+ for path , name := range extraImports {
723+ if name != "" {
724+ importsBuffer .WriteString (name + " " )
756725 }
757726 importsBuffer .WriteString (fmt .Sprintf ("\" %s\" \n " , path ))
758727 }
@@ -766,8 +735,8 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
766735 sort .Strings (paths )
767736 for _ , path := range paths {
768737 importsBuffer .WriteString ("\n \t " )
769- if extraImports [path ]. renamed {
770- importsBuffer .WriteString (extraImports [ path ]. name + " " )
738+ if name := extraImports [path ]; name != "" {
739+ importsBuffer .WriteString (name + " " )
771740 }
772741 importsBuffer .WriteString (fmt .Sprintf ("\" %s\" " , path ))
773742 }
0 commit comments