@@ -11,14 +11,14 @@ import (
1111 "runtime"
1212 "strconv"
1313 "sync"
14- "sync/atomic"
1514 "time"
1615
1716 "golang.org/x/xerrors"
1817)
1918
2019// Conn represents a WebSocket connection.
21- // All methods may be called concurrently.
20+ // All methods may be called concurrently except for Reader, Read
21+ // and SetReadLimit.
2222//
2323// Please be sure to call Close on the connection when you
2424// are finished with it to release the associated resources.
@@ -29,29 +29,30 @@ type Conn struct {
2929 closer io.Closer
3030 client bool
3131
32- // In bytes.
32+ // read limit for a message in bytes.
3333 msgReadLimit int64
3434
3535 closeOnce sync.Once
3636 closeErr error
3737 closed chan struct {}
3838
39- // writeMsgLock is acquired to write a multi frame message.
40- writeMsgLock chan struct {}
39+ // writeMsgLock is acquired to write a data message.
40+ writeMsgLock chan struct {}
4141 // writeFrameLock is acquired to write a single frame.
4242 // Effectively meaning whoever holds it gets to write to bw.
4343 writeFrameLock chan struct {}
4444
45- // readMsgLock is acquired to read a message with Reader.
46- readMsgLock chan struct {}
45+ // Used to ensure the previous reader is read till EOF before allowing
46+ // a new one.
47+ previousReader * messageReader
4748 // readFrameLock is acquired to read from bw.
4849 readFrameLock chan struct {}
4950 // readMsg is used by messageReader to receive frames from
5051 // readLoop.
51- readMsg chan header
52+ readMsg chan header
5253 // readMsgDone is used to tell the readLoop to continue after
5354 // messageReader has read a frame.
54- readMsgDone chan struct {}
55+ readMsgDone chan struct {}
5556
5657 setReadTimeout chan context.Context
5758 setWriteTimeout chan context.Context
@@ -129,7 +130,6 @@ func (c *Conn) init() {
129130 c .writeMsgLock = make (chan struct {}, 1 )
130131 c .writeFrameLock = make (chan struct {}, 1 )
131132
132- c .readMsgLock = make (chan struct {}, 1 )
133133 c .readFrameLock = make (chan struct {}, 1 )
134134 c .readMsg = make (chan header )
135135 c .readMsgDone = make (chan struct {})
@@ -271,7 +271,7 @@ func (c *Conn) handleControl(h header) {
271271
272272 b := make ([]byte , h .payloadLength )
273273
274- _ , err := c .readPayload (ctx , b )
274+ _ , err := c .readFramePayload (ctx , b )
275275 if err != nil {
276276 return
277277 }
@@ -427,13 +427,11 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error {
427427 defer cancel ()
428428
429429 err := c .writeControl (ctx , opClose , p )
430-
431- c .close (cerr )
432-
433430 if err != nil {
434431 return err
435432 }
436433
434+ c .close (cerr )
437435 if ! xerrors .Is (c .closeErr , cerr ) {
438436 return c .closeErr
439437 }
@@ -444,6 +442,16 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error {
444442func (c * Conn ) acquireLock (ctx context.Context , lock chan struct {}) error {
445443 select {
446444 case <- ctx .Done ():
445+ var err error
446+ switch lock {
447+ case c .writeFrameLock , c .writeMsgLock :
448+ err = xerrors .Errorf ("could not acquire write lock: %v" , ctx .Err ())
449+ case c .readFrameLock :
450+ err = xerrors .Errorf ("could not acquire read lock: %v" , ctx .Err ())
451+ default :
452+ panic (fmt .Sprintf ("websocket: failed to acquire unknown lock: %v" , ctx .Err ()))
453+ }
454+ c .close (err )
447455 return ctx .Err ()
448456 case <- c .closed :
449457 return c .closeErr
@@ -490,7 +498,7 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
490498// Read is a convenience method to read a single message from the connection.
491499//
492500// See the Reader method if you want to be able to reuse buffers or want to stream a message.
493- // The docs on Reader apply to this metohd as well.
501+ // The docs on Reader apply to this method as well.
494502//
495503// This is an experimental API, please let me know how you feel about it in
496504// https://github.com/nhooyr/websocket/issues/62
@@ -501,11 +509,7 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
501509 }
502510
503511 b , err := ioutil .ReadAll (r )
504- if err != nil {
505- return typ , b , err
506- }
507-
508- return typ , b , nil
512+ return typ , b , err
509513}
510514
511515// Write is a convenience method to write a message to the connection.
@@ -531,10 +535,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
531535 defer c .releaseLock (c .writeMsgLock )
532536
533537 err = c .writeFrame (ctx , true , opcode (typ ), p )
534- if err != nil {
535- return err
536- }
537- return nil
538+ return err
538539}
539540
540541// messageWriter enables writing to a WebSocket connection.
@@ -591,41 +592,34 @@ func (w *messageWriter) close() error {
591592 return nil
592593}
593594
594- // Reader will wait until there is a WebSocket data message to read from the connection.
595+ // Reader waits until there is a WebSocket data message to read
596+ // from the connection.
595597// It returns the type of the message and a reader to read it.
596598// The passed context will also bound the reader.
599+ // Ensure you read to EOF otherwise the connection will hang.
597600//
598601// Control (ping, pong, close) frames will be handled automatically
599602// in a separate goroutine so if you do not expect any data messages,
600603// you do not need to read from the connection. However, if the peer
601604// sends a data message, further pings, pongs and close frames will not
602605// be read if you do not read the message from the connection.
603606//
604- // If you do not read from the reader till EOF, nothing further will be read from the connection.
605- // Only one reader can be open at a time, multiple calls will block until the previous reader
606- // is read to completion.
607- // TODO remove concurrent reads.
607+ // Only one Reader may be open at a time.
608608func (c * Conn ) Reader (ctx context.Context ) (MessageType , io.Reader , error ) {
609- // We could handle the case of json.Decoder where the message may not be read
610- // till EOF but would still be read till the end of data. E.g. if the other side
611- // sends a fin frame after the message, we could allow the code to continue and
612- // just pick off but the code for that gets complicated and if there is real data
613- // after the JSON object, Reader would block until the timeout is hit
614609 typ , r , err := c .reader (ctx )
615610 if err != nil {
616611 return 0 , nil , xerrors .Errorf ("failed to get reader: %w" , err )
617612 }
618613 return typ , & limitedReader {
619614 c : c ,
620615 r : r ,
621- left : atomic . LoadInt64 ( & c .msgReadLimit ) ,
616+ left : c .msgReadLimit ,
622617 }, nil
623618}
624619
625- func (c * Conn ) reader (ctx context.Context ) (_ MessageType , _ io.Reader , err error ) {
626- err = c .acquireLock (ctx , c .readMsgLock )
627- if err != nil {
628- return 0 , nil , err
620+ func (c * Conn ) reader (ctx context.Context ) (MessageType , io.Reader , error ) {
621+ if c .previousReader .h != nil && c .previousReader .h .payloadLength > 0 {
622+ return 0 , nil , xerrors .Errorf ("previous message not read to completion" )
629623 }
630624
631625 select {
@@ -634,26 +628,42 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro
634628 case <- ctx .Done ():
635629 return 0 , nil , ctx .Err ()
636630 case h := <- c .readMsg :
637- if h .opcode == opContinuation {
631+ if c .previousReader != nil && ! c .previousReader .done {
632+ if h .opcode != opContinuation {
633+ err := xerrors .Errorf ("received new data message without finishing the previous message" )
634+ c .Close (StatusProtocolError , err .Error ())
635+ return 0 , nil , err
636+ }
637+
638+ if ! h .fin || h .payloadLength > 0 {
639+ return 0 , nil , xerrors .Errorf ("previous message not read to completion" )
640+ }
641+
642+ c .previousReader .done = true
643+ return c .reader (ctx )
644+ } else if h .opcode == opContinuation {
638645 err := xerrors .Errorf ("received continuation frame not after data or text frame" )
639646 c .Close (StatusProtocolError , err .Error ())
640647 return 0 , nil , err
641648 }
642- return MessageType ( h . opcode ), & messageReader {
649+ r := & messageReader {
643650 ctx : ctx ,
644651 h : & h ,
645652 c : c ,
646- }, nil
653+ }
654+ c .previousReader = r
655+ return MessageType (h .opcode ), r , nil
647656 }
648657}
649658
650659// messageReader enables reading a data frame from the WebSocket connection.
651660type messageReader struct {
652- ctx context.Context
653- maskPos int
661+ ctx context.Context
662+ c * Conn
663+
654664 h * header
655- c * Conn
656- eofed bool
665+ maskPos int
666+ done bool
657667}
658668
659669// Read reads as many bytes as possible into p.
@@ -665,13 +675,15 @@ func (r *messageReader) Read(p []byte) (int, error) {
665675 if xerrors .Is (err , io .EOF ) {
666676 return n , io .EOF
667677 }
668- return n , xerrors .Errorf ("failed to read: %w" , err )
678+ err = xerrors .Errorf ("failed to read: %w" , err )
679+ r .c .close (err )
680+ return n , err
669681 }
670682 return n , nil
671683}
672684
673685func (r * messageReader ) read (p []byte ) (int , error ) {
674- if r .eofed {
686+ if r .done {
675687 return 0 , xerrors .Errorf ("cannot use EOFed reader" )
676688 }
677689
@@ -695,16 +707,14 @@ func (r *messageReader) read(p []byte) (int, error) {
695707 p = p [:r .h .payloadLength ]
696708 }
697709
698- n , err := r .readPayload ( p )
710+ n , err := r .c . readFramePayload ( r . ctx , p )
699711
700712 r .h .payloadLength -= int64 (n )
701713 if r .h .masked {
702714 r .maskPos = fastXOR (r .h .maskKey , r .maskPos , p )
703715 }
704716
705717 if err != nil {
706- err := xerrors .Errorf ("failed to read frame payload: %w" , err )
707- r .c .close (err )
708718 return n , err
709719 }
710720
@@ -716,8 +726,7 @@ func (r *messageReader) read(p []byte) (int, error) {
716726 }
717727
718728 if r .h .fin {
719- r .eofed = true
720- r .c .releaseLock (r .c .readMsgLock )
729+ r .done = true
721730 return n , io .EOF
722731 }
723732
@@ -728,16 +737,7 @@ func (r *messageReader) read(p []byte) (int, error) {
728737 return n , nil
729738}
730739
731- func (c * Conn ) isClosed () bool {
732- select {
733- case <- c .closed :
734- return true
735- default :
736- return false
737- }
738- }
739-
740- func (c * Conn ) readPayload (ctx context.Context , p []byte ) (int , error ) {
740+ func (c * Conn ) readFramePayload (ctx context.Context , p []byte ) (int , error ) {
741741 err := c .acquireLock (ctx , c .readFrameLock )
742742 if err != nil {
743743 return 0 , err
@@ -779,7 +779,7 @@ func (c *Conn) readPayload(ctx context.Context, p []byte) (int, error) {
779779//
780780// When the limit is hit, the connection will be closed with StatusPolicyViolation.
781781func (c * Conn ) SetReadLimit (n int64 ) {
782- atomic . StoreInt64 ( & c .msgReadLimit , n )
782+ c .msgReadLimit = n
783783}
784784
785785func init () {
@@ -794,7 +794,9 @@ func init() {
794794func (c * Conn ) Ping (ctx context.Context ) error {
795795 err := c .ping (ctx )
796796 if err != nil {
797- return xerrors .Errorf ("failed to ping: %w" , err )
797+ err = xerrors .Errorf ("failed to ping: %w" , err )
798+ c .close (err )
799+ return err
798800 }
799801 return nil
800802}
0 commit comments