@@ -16,6 +16,10 @@ import (
1616 "sync"
1717 "sync/atomic"
1818 "time"
19+
20+ "golang.org/x/xerrors"
21+
22+ "nhooyr.io/websocket/internal/bpool"
1923)
2024
2125// Conn represents a WebSocket connection.
@@ -62,13 +66,15 @@ type Conn struct {
6266 writeMsgOpcode opcode
6367 writeMsgCtx context.Context
6468 readMsgLeft int64
69+ readCloseFrame CloseError
6570
6671 // Used to ensure the previous reader is read till EOF before allowing
6772 // a new one.
6873 activeReader * messageReader
6974 // readFrameLock is acquired to read from bw.
7075 readFrameLock chan struct {}
7176 isReadClosed * atomicInt64
77+ isCloseHandshake * atomicInt64
7278 readHeaderBuf []byte
7379 controlPayloadBuf []byte
7480
@@ -96,6 +102,7 @@ func (c *Conn) init() {
96102 c .writeFrameLock = make (chan struct {}, 1 )
97103
98104 c .readFrameLock = make (chan struct {}, 1 )
105+ c .isCloseHandshake = & atomicInt64 {}
99106
100107 c .setReadTimeout = make (chan context.Context )
101108 c .setWriteTimeout = make (chan context.Context )
@@ -230,7 +237,7 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
230237}
231238
232239func (c * Conn ) readFrameHeader (ctx context.Context ) (header , error ) {
233- err := c .acquireLock (context . Background () , c .readFrameLock )
240+ err := c .acquireLock (ctx , c .readFrameLock )
234241 if err != nil {
235242 return header {}, err
236243 }
@@ -308,11 +315,22 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
308315 c .Close (StatusProtocolError , err .Error ())
309316 return c .closeErr
310317 }
318+
311319 // This ensures the closeErr of the Conn is always the received CloseError
312320 // in case the echo close frame write fails.
313321 // See https://github.com/nhooyr/websocket/issues/109
314322 c .setCloseErr (fmt .Errorf ("received close frame: %w" , ce ))
315- c .writeClose (b , nil )
323+
324+ c .readCloseFrame = ce
325+
326+ func () {
327+ ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
328+ defer cancel ()
329+ c .writeControl (ctx , opClose , b )
330+ }()
331+
332+ // We close with nil since the error is already set above.
333+ c .close (nil )
316334 return c .closeErr
317335 default :
318336 panic (fmt .Sprintf ("websocket: unexpected control opcode: %#v" , h ))
@@ -347,6 +365,15 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
347365 return 0 , nil , fmt .Errorf ("websocket connection read closed" )
348366 }
349367
368+ if c .isCloseHandshake .Load () == 1 {
369+ select {
370+ case <- ctx .Done ():
371+ return 0 , nil , fmt .Errorf ("failed to get reader: %w" , ctx .Err ())
372+ case <- c .closed :
373+ return 0 , nil , fmt .Errorf ("failed to get reader: %w" , c .closeErr )
374+ }
375+ }
376+
350377 typ , r , err := c .reader (ctx )
351378 if err != nil {
352379 return 0 , nil , fmt .Errorf ("failed to get reader: %w" , err )
@@ -772,27 +799,28 @@ func (c *Conn) writePong(p []byte) error {
772799
773800// Close closes the WebSocket connection with the given status code and reason.
774801//
775- // It will write a WebSocket close frame with a timeout of 5 seconds.
802+ // It will write a WebSocket close frame and then wait for the peer to respond
803+ // with its own close frame. The entire process must complete within 10 seconds.
804+ // Thus, it implements the full WebSocket close handshake.
805+ //
776806// The connection can only be closed once. Additional calls to Close
777807// are no-ops.
778808//
779- // This does not perform a WebSocket close handshake.
780- // See https://github.com/nhooyr/websocket/issues/103 for details on why.
781- //
782809// The maximum length of reason must be 125 bytes otherwise an internal
783810// error will be sent to the peer. For this reason, you should avoid
784811// sending a dynamic reason.
785812//
786- // Close will unblock all goroutines interacting with the connection.
813+ // Close will unblock all goroutines interacting with the connection once
814+ // complete.
787815func (c * Conn ) Close (code StatusCode , reason string ) error {
788- err := c .exportedClose (code , reason )
816+ err := c .closeHandshake (code , reason )
789817 if err != nil {
790818 return fmt .Errorf ("failed to close websocket connection: %w" , err )
791819 }
792820 return nil
793821}
794822
795- func (c * Conn ) exportedClose (code StatusCode , reason string ) error {
823+ func (c * Conn ) closeHandshake (code StatusCode , reason string ) error {
796824 ce := CloseError {
797825 Code : code ,
798826 Reason : reason ,
@@ -810,34 +838,64 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error {
810838 p , _ = ce .bytes ()
811839 }
812840
841+ ctx , cancel := context .WithTimeout (context .Background (), time .Second * 10 )
842+ defer cancel ()
843+
844+ // Ensures the connection is closed if everything below succeeds.
845+ // Up here because we must release the read lock first.
846+ // nil because of the setCloseErr call below.
847+ defer c .close (nil )
848+
813849 // CloseErrors sent are made opaque to prevent applications from thinking
814850 // they received a given status.
815851 sentErr := fmt .Errorf ("sent close frame: %v" , ce )
816- err = c .writeClose (p , sentErr )
852+ // Other connections should only see this error.
853+ c .setCloseErr (sentErr )
854+
855+ err = c .writeControl (ctx , opClose , p )
817856 if err != nil {
818857 return err
819858 }
820859
821- if ! errors .Is (c .closeErr , sentErr ) {
822- return c .closeErr
860+ // Wait for close frame from peer.
861+ err = c .waitClose (ctx )
862+ // We didn't read a close frame.
863+ if c .readCloseFrame == (CloseError {}) {
864+ if ctx .Err () != nil {
865+ return xerrors .Errorf ("failed to wait for peer close frame: %w" , ctx .Err ())
866+ }
867+ // We need to make the err returned from c.waitClose accurate.
868+ return xerrors .Errorf ("failed to read peer close frame for unknown reason" )
823869 }
824-
825870 return nil
826871}
827872
828- func (c * Conn ) writeClose (p []byte , cerr error ) error {
829- ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
830- defer cancel ()
873+ func (c * Conn ) waitClose (ctx context.Context ) error {
874+ b := bpool .Get ()
875+ buf := b .Bytes ()
876+ buf = buf [:cap (buf )]
877+ defer bpool .Put (b )
831878
832- // If this fails, the connection had to have died.
833- err := c .writeControl (ctx , opClose , p )
834- if err != nil {
835- return err
836- }
879+ // Prevent reads from user code as we are going to be
880+ // discarding all messages so they cannot rely on any ordering.
881+ c .isCloseHandshake .Store (1 )
837882
838- c .close (cerr )
883+ // From this point forward, any reader we receive means we are
884+ // now the sole readers of the connection and so it is safe
885+ // to discard all payloads.
839886
840- return nil
887+ for {
888+ _ , r , err := c .reader (ctx )
889+ if err != nil {
890+ return err
891+ }
892+
893+ // Discard all payloads.
894+ _ , err = io .CopyBuffer (ioutil .Discard , r , buf )
895+ if err != nil {
896+ return err
897+ }
898+ }
841899}
842900
843901// Ping sends a ping to the peer and waits for a pong.
0 commit comments