@@ -383,7 +383,11 @@ func TestHandshake(t *testing.T) {
383383 }
384384 defer c .Close (websocket .StatusInternalError , "" )
385385
386- go c .Reader (r .Context ())
386+ errc := make (chan error , 1 )
387+ go func () {
388+ _ , _ , err2 := c .Read (r .Context ())
389+ errc <- err2
390+ }()
387391
388392 err = c .Ping (r .Context ())
389393 if err != nil {
@@ -395,8 +399,12 @@ func TestHandshake(t *testing.T) {
395399 return err
396400 }
397401
398- c .Close (websocket .StatusNormalClosure , "" )
399- return nil
402+ err = <- errc
403+ var ce websocket.CloseError
404+ if xerrors .As (err , & ce ) && ce .Code == websocket .StatusNormalClosure {
405+ return nil
406+ }
407+ return xerrors .Errorf ("unexpected error: %w" , err )
400408 },
401409 client : func (ctx context.Context , u string ) error {
402410 c , _ , err := websocket .Dial (ctx , u , websocket.DialOptions {})
@@ -405,19 +413,30 @@ func TestHandshake(t *testing.T) {
405413 }
406414 defer c .Close (websocket .StatusInternalError , "" )
407415
408- errc := make (chan error , 1 )
416+ // We read a message from the connection and then keep reading until
417+ // the Ping completes.
418+ done := make (chan struct {})
409419 go func () {
410- errc <- c .Ping (ctx )
420+ _ , _ , err := c .Read (ctx )
421+ if err != nil {
422+ c .Close (websocket .StatusInternalError , err .Error ())
423+ return
424+ }
425+
426+ close (done )
427+
428+ c .Read (ctx )
411429 }()
412430
413- _ , _ , err = c .Read (ctx )
431+ err = c .Ping (ctx )
414432 if err != nil {
415433 return err
416434 }
417435
418- err = <- errc
436+ <- done
437+
419438 c .Close (websocket .StatusNormalClosure , "" )
420- return err
439+ return nil
421440 },
422441 },
423442 {
0 commit comments