@@ -212,21 +212,23 @@ var testTmpl = template.Must(template.New("test").Funcs(template.FuncMap{
212212
213213// AddTestForFunc adds a test for the function enclosing the given input range.
214214// It creates a _test.go file if one does not already exist.
215- func AddTestForFunc (ctx context.Context , snapshot * cache.Snapshot , loc protocol.Location ) (changes []protocol.DocumentChange , _ error ) {
215+ // It returns the required text edits and the predicted location of the new test
216+ // function, which is only valid after the edits have been successfully applied.
217+ func AddTestForFunc (ctx context.Context , snapshot * cache.Snapshot , loc protocol.Location ) (changes []protocol.DocumentChange , show * protocol.Location , _ error ) {
216218 pkg , pgf , err := NarrowestPackageForFile (ctx , snapshot , loc .URI )
217219 if err != nil {
218- return nil , err
220+ return nil , nil , err
219221 }
220222
221223 if metadata .IsCommandLineArguments (pkg .Metadata ().ID ) {
222- return nil , fmt .Errorf ("current file in command-line-arguments package" )
224+ return nil , nil , fmt .Errorf ("current file in command-line-arguments package" )
223225 }
224226
225227 if errors := pkg .ParseErrors (); len (errors ) > 0 {
226- return nil , fmt .Errorf ("package has parse errors: %v" , errors [0 ])
228+ return nil , nil , fmt .Errorf ("package has parse errors: %v" , errors [0 ])
227229 }
228230 if errors := pkg .TypeErrors (); len (errors ) > 0 {
229- return nil , fmt .Errorf ("package has type errors: %v" , errors [0 ])
231+ return nil , nil , fmt .Errorf ("package has type errors: %v" , errors [0 ])
230232 }
231233
232234 // All three maps map the path of an imported package to
@@ -262,15 +264,15 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
262264
263265 // Collect all the imports from the x.go, keep track of the local package name.
264266 if fileImports , err = collectImports (pgf .File ); err != nil {
265- return nil , err
267+ return nil , nil , err
266268 }
267269
268270 testBase := strings .TrimSuffix (loc .URI .Base (), ".go" ) + "_test.go"
269271 goTestFileURI := protocol .URIFromPath (filepath .Join (loc .URI .DirPath (), testBase ))
270272
271273 testFH , err := snapshot .ReadFile (ctx , goTestFileURI )
272274 if err != nil {
273- return nil , err
275+ return nil , nil , err
274276 }
275277
276278 // TODO(hxjiang): use a fresh name if the same test function name already
@@ -289,17 +291,17 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
289291
290292 start , end , err := pgf .RangePos (loc .Range )
291293 if err != nil {
292- return nil , err
294+ return nil , nil , err
293295 }
294296
295297 path , _ := astutil .PathEnclosingInterval (pgf .File , start , end )
296298 if len (path ) < 2 {
297- return nil , fmt .Errorf ("no enclosing function" )
299+ return nil , nil , fmt .Errorf ("no enclosing function" )
298300 }
299301
300302 decl , ok := path [len (path )- 2 ].(* ast.FuncDecl )
301303 if ! ok {
302- return nil , fmt .Errorf ("no enclosing function" )
304+ return nil , nil , fmt .Errorf ("no enclosing function" )
303305 }
304306
305307 fn := pkg .TypesInfo ().Defs [decl .Name ].(* types.Func )
@@ -308,7 +310,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
308310 testPGF , err := snapshot .ParseGo (ctx , testFH , parsego .Header )
309311 if err != nil {
310312 if ! errors .Is (err , os .ErrNotExist ) {
311- return nil , err
313+ return nil , nil , err
312314 }
313315 changes = append (changes , protocol .DocumentChangeCreate (goTestFileURI ))
314316
@@ -322,7 +324,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
322324 if c := CopyrightComment (pgf .File ); c != nil {
323325 text , err := pgf .NodeText (c )
324326 if err != nil {
325- return nil , err
327+ return nil , nil , err
326328 }
327329 header .Write (text )
328330 // One empty line between copyright header and following.
@@ -334,7 +336,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
334336 if c := buildConstraintComment (pgf .File ); c != nil {
335337 text , err := pgf .NodeText (c )
336338 if err != nil {
337- return nil , err
339+ return nil , nil , err
338340 }
339341 header .Write (text )
340342 // One empty line between build constraint and following.
@@ -397,25 +399,25 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
397399 } else { // existing _test.go file.
398400 file := testPGF .File
399401 if ! file .Name .NamePos .IsValid () {
400- return nil , fmt .Errorf ("missing package declaration" )
402+ return nil , nil , fmt .Errorf ("missing package declaration" )
401403 }
402404 switch file .Name .Name {
403405 case pgf .File .Name .Name :
404406 xtest = false
405407 case pgf .File .Name .Name + "_test" :
406408 xtest = true
407409 default :
408- return nil , fmt .Errorf ("invalid package declaration %q in test file %q" , file .Name , testPGF )
410+ return nil , nil , fmt .Errorf ("invalid package declaration %q in test file %q" , file .Name , testPGF )
409411 }
410412
411413 eofRange , err = testPGF .PosRange (file .FileEnd , file .FileEnd )
412414 if err != nil {
413- return nil , err
415+ return nil , nil , err
414416 }
415417
416418 // Collect all the imports from the foo_test.go.
417419 if testImports , err = collectImports (file ); err != nil {
418- return nil , err
420+ return nil , nil , err
419421 }
420422 }
421423
@@ -453,13 +455,13 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
453455 if xtest {
454456 // Reject if function/method is unexported.
455457 if ! fn .Exported () {
456- return nil , fmt .Errorf ("cannot add test of unexported function %s to external test package %s_test" , decl .Name , pgf .File .Name )
458+ return nil , nil , fmt .Errorf ("cannot add test of unexported function %s to external test package %s_test" , decl .Name , pgf .File .Name )
457459 }
458460
459461 // Reject if receiver is unexported.
460462 if sig .Recv () != nil {
461463 if _ , ident , _ := goplsastutil .UnpackRecv (decl .Recv .List [0 ].Type ); ident == nil || ! ident .IsExported () {
462- return nil , fmt .Errorf ("cannot add external test for method %s.%s as receiver type is not exported" , ident .Name , decl .Name )
464+ return nil , nil , fmt .Errorf ("cannot add external test for method %s.%s as receiver type is not exported" , ident .Name , decl .Name )
463465 }
464466 }
465467 // TODO(hxjiang): reject if the any input parameter type is unexported.
@@ -469,7 +471,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
469471
470472 testName , err := testName (fn )
471473 if err != nil {
472- return nil , err
474+ return nil , nil , err
473475 }
474476
475477 data := testInfo {
@@ -525,7 +527,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
525527
526528 t , ok := recvType .(typesinternal.NamedOrAlias )
527529 if ! ok {
528- return nil , fmt .Errorf ("the receiver type is neither named type nor alias type" )
530+ return nil , nil , fmt .Errorf ("the receiver type is neither named type nor alias type" )
529531 }
530532
531533 var varName string
@@ -707,7 +709,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
707709 }
708710 importEdits , err := ComputeImportFixEdits (snapshot .Options ().Local , testPGF .Src , importFixes ... )
709711 if err != nil {
710- return nil , fmt .Errorf ("could not compute the import fix edits: %w" , err )
712+ return nil , nil , fmt .Errorf ("could not compute the import fix edits: %w" , err )
711713 }
712714 edits = append (edits , importEdits ... )
713715 } else {
@@ -740,21 +742,41 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
740742
741743 var test bytes.Buffer
742744 if err := testTmpl .Execute (& test , data ); err != nil {
743- return nil , err
745+ return nil , nil , err
744746 }
745747
746748 formatted , err := format .Source (test .Bytes ())
747749 if err != nil {
748- return nil , err
750+ return nil , nil , err
749751 }
750752
751753 edits = append (edits ,
752754 protocol.TextEdit {
753755 Range : eofRange ,
754756 NewText : string (formatted ),
755- })
757+ },
758+ )
759+
760+ // Show the line of generated test function.
761+ {
762+ line := eofRange .Start .Line
763+ for i := range len (edits ) - 1 { // last edits is the func decl
764+ e := edits [i ]
765+ oldLines := e .Range .End .Line - e .Range .Start .Line
766+ newLines := uint32 (strings .Count (e .NewText , "\n " ))
767+ line += (newLines - oldLines )
768+ }
769+ show = & protocol.Location {
770+ URI : testFH .URI (),
771+ Range : protocol.Range {
772+ // Test function template have a new line at beginning.
773+ Start : protocol.Position {Line : line + 1 },
774+ End : protocol.Position {Line : line + 1 },
775+ },
776+ }
777+ }
756778
757- return append (changes , protocol .DocumentChangeEdit (testFH , edits )), nil
779+ return append (changes , protocol .DocumentChangeEdit (testFH , edits )), show , nil
758780}
759781
760782// testName returns the name of the function to use for the new function that
0 commit comments