@@ -17,8 +17,6 @@ import (
1717 "sync/atomic"
1818 "time"
1919
20- "golang.org/x/xerrors"
21-
2220 "nhooyr.io/websocket/internal/bpool"
2321)
2422
@@ -66,15 +64,13 @@ type Conn struct {
6664 writeMsgOpcode opcode
6765 writeMsgCtx context.Context
6866 readMsgLeft int64
69- readCloseFrame CloseError
7067
7168 // Used to ensure the previous reader is read till EOF before allowing
7269 // a new one.
7370 activeReader * messageReader
7471 // readFrameLock is acquired to read from bw.
7572 readFrameLock chan struct {}
7673 isReadClosed * atomicInt64
77- isCloseHandshake * atomicInt64
7874 readHeaderBuf []byte
7975 controlPayloadBuf []byte
8076
@@ -102,7 +98,6 @@ func (c *Conn) init() {
10298 c .writeFrameLock = make (chan struct {}, 1 )
10399
104100 c .readFrameLock = make (chan struct {}, 1 )
105- c .isCloseHandshake = & atomicInt64 {}
106101
107102 c .setReadTimeout = make (chan context.Context )
108103 c .setWriteTimeout = make (chan context.Context )
@@ -206,20 +201,20 @@ func (c *Conn) releaseLock(lock chan struct{}) {
206201 }
207202}
208203
209- func (c * Conn ) readTillMsg (ctx context.Context ) (header , error ) {
204+ func (c * Conn ) readTillMsg (ctx context.Context , lock bool ) (header , error ) {
210205 for {
211- h , err := c .readFrameHeader (ctx )
206+ h , err := c .readFrameHeader (ctx , lock )
212207 if err != nil {
213208 return header {}, err
214209 }
215210
216211 if h .rsv1 || h .rsv2 || h .rsv3 {
217- c .Close (StatusProtocolError , fmt .Sprintf ("received header with rsv bits set: %v:%v:%v" , h .rsv1 , h .rsv2 , h .rsv3 ))
212+ c .writeClose (StatusProtocolError , fmt .Sprintf ("received header with rsv bits set: %v:%v:%v" , h .rsv1 , h .rsv2 , h .rsv3 ), false )
218213 return header {}, c .closeErr
219214 }
220215
221216 if h .opcode .controlOp () {
222- err = c .handleControl (ctx , h )
217+ err = c .handleControl (ctx , h , lock )
223218 if err != nil {
224219 return header {}, fmt .Errorf ("failed to handle control frame: %w" , err )
225220 }
@@ -230,18 +225,20 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
230225 case opBinary , opText , opContinuation :
231226 return h , nil
232227 default :
233- c .Close (StatusProtocolError , fmt .Sprintf ("received unknown opcode %v" , h .opcode ))
228+ c .writeClose (StatusProtocolError , fmt .Sprintf ("received unknown opcode %v" , h .opcode ), false )
234229 return header {}, c .closeErr
235230 }
236231 }
237232}
238233
239- func (c * Conn ) readFrameHeader (ctx context.Context ) (header , error ) {
240- err := c .acquireLock (ctx , c .readFrameLock )
241- if err != nil {
242- return header {}, err
234+ func (c * Conn ) readFrameHeader (ctx context.Context , lock bool ) (header , error ) {
235+ if lock {
236+ err := c .acquireLock (ctx , c .readFrameLock )
237+ if err != nil {
238+ return header {}, err
239+ }
240+ defer c .releaseLock (c .readFrameLock )
243241 }
244- defer c .releaseLock (c .readFrameLock )
245242
246243 select {
247244 case <- c .closed :
@@ -273,22 +270,22 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
273270 return h , nil
274271}
275272
276- func (c * Conn ) handleControl (ctx context.Context , h header ) error {
273+ func (c * Conn ) handleControl (ctx context.Context , h header , lock bool ) error {
277274 if h .payloadLength > maxControlFramePayload {
278- c .Close (StatusProtocolError , fmt .Sprintf ("control frame too large at %v bytes" , h .payloadLength ))
275+ c .writeClose (StatusProtocolError , fmt .Sprintf ("control frame too large at %v bytes" , h .payloadLength ), false )
279276 return c .closeErr
280277 }
281278
282279 if ! h .fin {
283- c .Close (StatusProtocolError , "received fragmented control frame" )
280+ c .writeClose (StatusProtocolError , "received fragmented control frame" , false )
284281 return c .closeErr
285282 }
286283
287284 ctx , cancel := context .WithTimeout (ctx , time .Second * 5 )
288285 defer cancel ()
289286
290287 b := c .controlPayloadBuf [:h .payloadLength ]
291- _ , err := c .readFramePayload (ctx , b )
288+ _ , err := c .readFramePayload (ctx , b , lock )
292289 if err != nil {
293290 return err
294291 }
@@ -312,23 +309,24 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
312309 ce , err := parseClosePayload (b )
313310 if err != nil {
314311 err = fmt .Errorf ("received invalid close payload: %w" , err )
315- c .Close (StatusProtocolError , err .Error ())
312+ c .writeClose (StatusProtocolError , err .Error (), false )
316313 return c .closeErr
317314 }
318315
319316 // This ensures the closeErr of the Conn is always the received CloseError
320317 // in case the echo close frame write fails.
321318 // See https://github.com/nhooyr/websocket/issues/109
322- c .setCloseErr (fmt .Errorf ("received close frame: %w" , ce ))
323-
324- c .readCloseFrame = ce
319+ c .setCloseErr (ce )
325320
326321 func () {
327322 ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
328323 defer cancel ()
329324 c .writeControl (ctx , opClose , b )
330325 }()
331326
327+ if ! lock {
328+ c .releaseLock (c .readFrameLock )
329+ }
332330 // We close with nil since the error is already set above.
333331 c .close (nil )
334332 return c .closeErr
@@ -362,16 +360,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
362360// Most users should not need this.
363361func (c * Conn ) Reader (ctx context.Context ) (MessageType , io.Reader , error ) {
364362 if c .isReadClosed .Load () == 1 {
365- return 0 , nil , fmt .Errorf ("websocket connection read closed" )
366- }
367-
368- if c .isCloseHandshake .Load () == 1 {
369- select {
370- case <- ctx .Done ():
371- return 0 , nil , fmt .Errorf ("failed to get reader: %w" , ctx .Err ())
372- case <- c .closed :
373- return 0 , nil , fmt .Errorf ("failed to get reader: %w" , c .closeErr )
374- }
363+ return 0 , nil , errors .New ("websocket connection read closed" )
375364 }
376365
377366 typ , r , err := c .reader (ctx )
@@ -381,23 +370,23 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
381370 return typ , r , nil
382371}
383372
384- func (c * Conn ) reader (ctx context.Context ) (MessageType , io.Reader , error ) {
373+ func (c * Conn ) reader (ctx context.Context ) (_ MessageType , _ io.Reader , err error ) {
385374 if c .activeReader != nil && ! c .readerFrameEOF {
386375 // The only way we know for sure the previous reader is not yet complete is
387376 // if there is an active frame not yet fully read.
388377 // Otherwise, a user may have read the last byte but not the EOF if the EOF
389378 // is in the next frame so we check for that below.
390- return 0 , nil , fmt . Errorf ("previous message not read to completion" )
379+ return 0 , nil , errors . New ("previous message not read to completion" )
391380 }
392381
393- h , err := c .readTillMsg (ctx )
382+ h , err := c .readTillMsg (ctx , true )
394383 if err != nil {
395384 return 0 , nil , err
396385 }
397386
398387 if c .activeReader != nil && ! c .activeReader .eof () {
399388 if h .opcode != opContinuation {
400- c .Close (StatusProtocolError , "received new data message without finishing the previous message" )
389+ c .writeClose (StatusProtocolError , "received new data message without finishing the previous message" , false )
401390 return 0 , nil , c .closeErr
402391 }
403392
@@ -407,12 +396,12 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
407396
408397 c .activeReader = nil
409398
410- h , err = c .readTillMsg (ctx )
399+ h , err = c .readTillMsg (ctx , true )
411400 if err != nil {
412401 return 0 , nil , err
413402 }
414403 } else if h .opcode == opContinuation {
415- c .Close (StatusProtocolError , "received continuation frame not after data or text frame" )
404+ c .writeClose (StatusProtocolError , "received continuation frame not after data or text frame" , false )
416405 return 0 , nil , c .closeErr
417406 }
418407
@@ -458,7 +447,7 @@ func (r *messageReader) read(p []byte) (int, error) {
458447 }
459448
460449 if r .c .readMsgLeft <= 0 {
461- r .c .Close (StatusMessageTooBig , fmt .Sprintf ("read limited at %v bytes" , r .c .msgReadLimit ))
450+ r .c .writeClose (StatusMessageTooBig , fmt .Sprintf ("read limited at %v bytes" , r .c .msgReadLimit ), false )
462451 return 0 , r .c .closeErr
463452 }
464453
@@ -467,13 +456,13 @@ func (r *messageReader) read(p []byte) (int, error) {
467456 }
468457
469458 if r .c .readerFrameEOF {
470- h , err := r .c .readTillMsg (r .c .readerMsgCtx )
459+ h , err := r .c .readTillMsg (r .c .readerMsgCtx , true )
471460 if err != nil {
472461 return 0 , err
473462 }
474463
475464 if h .opcode != opContinuation {
476- r .c .Close (StatusProtocolError , "received new data message without finishing the previous message" )
465+ r .c .writeClose (StatusProtocolError , "received new data message without finishing the previous message" , false )
477466 return 0 , r .c .closeErr
478467 }
479468
@@ -487,7 +476,7 @@ func (r *messageReader) read(p []byte) (int, error) {
487476 p = p [:h .payloadLength ]
488477 }
489478
490- n , err := r .c .readFramePayload (r .c .readerMsgCtx , p )
479+ n , err := r .c .readFramePayload (r .c .readerMsgCtx , p , true )
491480
492481 h .payloadLength -= int64 (n )
493482 r .c .readMsgLeft -= int64 (n )
@@ -512,12 +501,14 @@ func (r *messageReader) read(p []byte) (int, error) {
512501 return n , nil
513502}
514503
515- func (c * Conn ) readFramePayload (ctx context.Context , p []byte ) (int , error ) {
516- err := c .acquireLock (ctx , c .readFrameLock )
517- if err != nil {
518- return 0 , err
504+ func (c * Conn ) readFramePayload (ctx context.Context , p []byte , lock bool ) (int , error ) {
505+ if lock {
506+ err := c .acquireLock (ctx , c .readFrameLock )
507+ if err != nil {
508+ return 0 , err
509+ }
510+ defer c .releaseLock (c .readFrameLock )
519511 }
520- defer c .releaseLock (c .readFrameLock )
521512
522513 select {
523514 case <- c .closed :
@@ -813,14 +804,14 @@ func (c *Conn) writePong(p []byte) error {
813804// Close will unblock all goroutines interacting with the connection once
814805// complete.
815806func (c * Conn ) Close (code StatusCode , reason string ) error {
816- err := c .closeHandshake (code , reason )
807+ err := c .writeClose (code , reason , true )
817808 if err != nil {
818809 return fmt .Errorf ("failed to close websocket connection: %w" , err )
819810 }
820811 return nil
821812}
822813
823- func (c * Conn ) closeHandshake (code StatusCode , reason string ) error {
814+ func (c * Conn ) writeClose (code StatusCode , reason string , handshake bool ) error {
824815 ce := CloseError {
825816 Code : code ,
826817 Reason : reason ,
@@ -838,60 +829,58 @@ func (c *Conn) closeHandshake(code StatusCode, reason string) error {
838829 p , _ = ce .bytes ()
839830 }
840831
832+ // Give the handshake 10 seconds.
841833 ctx , cancel := context .WithTimeout (context .Background (), time .Second * 10 )
842834 defer cancel ()
843835
844- // Ensures the connection is closed if everything below succeeds.
845- // Up here because we must release the read lock first.
846- // nil because of the setCloseErr call below.
847- defer c .close (nil )
848-
849- // CloseErrors sent are made opaque to prevent applications from thinking
850- // they received a given status.
851- sentErr := fmt .Errorf ("sent close frame: %v" , ce )
852- // Other connections should only see this error.
853- c .setCloseErr (sentErr )
854-
855836 err = c .writeControl (ctx , opClose , p )
856837 if err != nil {
857838 return err
858839 }
840+ c .setCloseErr (ce )
841+ defer c .close (nil )
859842
860- // Wait for close frame from peer.
861- err = c .waitClose (ctx )
862- // We didn't read a close frame.
863- if c .readCloseFrame == (CloseError {}) {
864- if ctx .Err () != nil {
865- return xerrors .Errorf ("failed to wait for peer close frame: %w" , ctx .Err ())
866- }
867- // We need to make the err returned from c.waitClose accurate.
868- return xerrors .Errorf ("failed to read peer close frame for unknown reason" )
843+ if handshake {
844+ // Try to wait for close frame peer but don't complain
845+ // if one is not received since we already decided the
846+ // close status of the connection above.
847+ c .waitClose (ctx )
869848 }
849+
870850 return nil
871851}
872852
873853func (c * Conn ) waitClose (ctx context.Context ) error {
854+ err := c .acquireLock (ctx , c .readFrameLock )
855+ if err != nil {
856+ return err
857+ }
858+ defer c .releaseLock (c .readFrameLock )
859+
874860 b := bpool .Get ()
875- buf := b .Bytes ()
876- buf = buf [:cap (buf )]
877861 defer bpool .Put (b )
878862
879- // Prevent reads from user code as we are going to be
880- // discarding all messages so they cannot rely on any ordering.
881- c .isCloseHandshake .Store (1 )
882-
883- // From this point forward, any reader we receive means we are
884- // now the sole readers of the connection and so it is safe
885- // to discard all payloads.
863+ var h header
864+ if c .activeReader != nil && ! c .readerFrameEOF {
865+ h = c .readerMsgHeader
866+ }
886867
887868 for {
888- _ , r , err := c .reader (ctx )
889- if err != nil {
890- return err
869+ for h .payloadLength > 0 {
870+ buf := b .Bytes ()
871+ if int64 (cap (buf )) > h .payloadLength {
872+ buf = buf [:h .payloadLength ]
873+ } else {
874+ buf = buf [:cap (buf )]
875+ }
876+ n , err := c .readFramePayload (ctx , buf , false )
877+ if err != nil {
878+ return err
879+ }
880+ h .payloadLength -= int64 (n )
891881 }
892882
893- // Discard all payloads.
894- _ , err = io .CopyBuffer (ioutil .Discard , r , buf )
883+ h , err = c .readTillMsg (ctx , false )
895884 if err != nil {
896885 return err
897886 }
0 commit comments