@@ -49,6 +49,11 @@ type Conn struct {
4949 // Effectively meaning whoever holds it gets to write to bw.
5050 writeFrameLock chan struct {}
5151 writeHeaderBuf []byte
52+ writeHeader * header
53+
54+ // messageWriter state.
55+ writeMsgOpcode opcode
56+ writeMsgCtx context.Context
5257
5358 // Used to ensure the previous reader is read till EOF before allowing
5459 // a new one.
@@ -58,6 +63,12 @@ type Conn struct {
5863 readHeaderBuf []byte
5964 controlPayloadBuf []byte
6065
66+ // messageReader state
67+ readMsgCtx context.Context
68+ readMsgHeader header
69+ readFrameEOF bool
70+ readMaskPos int
71+
6172 setReadTimeout chan context.Context
6273 setWriteTimeout chan context.Context
6374
@@ -81,6 +92,7 @@ func (c *Conn) init() {
8192 c .activePings = make (map [string ]chan <- struct {})
8293
8394 c .writeHeaderBuf = makeWriteHeaderBuf ()
95+ c .writeHeader = & header {}
8496 c .readHeaderBuf = makeReadHeaderBuf ()
8597 c .controlPayloadBuf = make ([]byte , maxControlFramePayload )
8698
@@ -315,15 +327,11 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
315327 if err != nil {
316328 return 0 , nil , xerrors .Errorf ("failed to get reader: %w" , err )
317329 }
318- return typ , & limitedReader {
319- c : c ,
320- r : r ,
321- left : c .msgReadLimit ,
322- }, nil
330+ return typ , r , nil
323331}
324332
325333func (c * Conn ) reader (ctx context.Context ) (MessageType , io.Reader , error ) {
326- if c .previousReader != nil && c . previousReader . h != nil {
334+ if c .previousReader != nil && ! c . readFrameEOF {
327335 // The only way we know for sure the previous reader is not yet complete is
328336 // if there is an active frame not yet fully read.
329337 // Otherwise, a user may have read the last byte but not the EOF if the EOF
@@ -336,7 +344,7 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
336344 return 0 , nil , err
337345 }
338346
339- if c .previousReader != nil && ! c .previousReader .done {
347+ if c .previousReader != nil && ! c .previousReader .eof {
340348 if h .opcode != opContinuation {
341349 err := xerrors .Errorf ("received new data message without finishing the previous message" )
342350 c .Close (StatusProtocolError , err .Error ())
@@ -347,33 +355,36 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
347355 return 0 , nil , xerrors .Errorf ("previous message not read to completion" )
348356 }
349357
350- c .previousReader .done = true
358+ c .previousReader .eof = true
351359
352- return c .reader (ctx )
360+ h , err = c .readTillMsg (ctx )
361+ if err != nil {
362+ return 0 , nil , err
363+ }
353364 } else if h .opcode == opContinuation {
354365 err := xerrors .Errorf ("received continuation frame not after data or text frame" )
355366 c .Close (StatusProtocolError , err .Error ())
356367 return 0 , nil , err
357368 }
358369
359- r := & messageReader {
360- ctx : ctx ,
361- c : c ,
370+ c .readMsgCtx = ctx
371+ c .readMsgHeader = h
372+ c .readFrameEOF = false
373+ c .readMaskPos = 0
362374
363- h : & h ,
375+ r := & messageReader {
376+ c : c ,
377+ left : c .msgReadLimit ,
364378 }
365379 c .previousReader = r
366380 return MessageType (h .opcode ), r , nil
367381}
368382
369383// messageReader enables reading a data frame from the WebSocket connection.
370384type messageReader struct {
371- ctx context.Context
372- c * Conn
373-
374- h * header
375- maskPos int
376- done bool
385+ c * Conn
386+ left int64
387+ eof bool
377388}
378389
379390// Read reads as many bytes as possible into p.
@@ -391,12 +402,22 @@ func (r *messageReader) Read(p []byte) (int, error) {
391402}
392403
393404func (r * messageReader ) read (p []byte ) (int , error ) {
394- if r .done {
405+ if r .eof {
395406 return 0 , xerrors .Errorf ("cannot use EOFed reader" )
396407 }
397408
398- if r .h == nil {
399- h , err := r .c .readTillMsg (r .ctx )
409+ if r .left <= 0 {
410+ err := xerrors .Errorf ("read limited at %v bytes" , r .c .msgReadLimit )
411+ r .c .Close (StatusMessageTooBig , err .Error ())
412+ return 0 , err
413+ }
414+
415+ if int64 (len (p )) > r .left {
416+ p = p [:r .left ]
417+ }
418+
419+ if r .c .readFrameEOF {
420+ h , err := r .c .readTillMsg (r .c .readMsgCtx )
400421 if err != nil {
401422 return 0 , err
402423 }
@@ -406,38 +427,37 @@ func (r *messageReader) read(p []byte) (int, error) {
406427 r .c .Close (StatusProtocolError , err .Error ())
407428 return 0 , err
408429 }
409- r .h = & h
430+
431+ r .c .readMsgHeader = h
432+ r .c .readFrameEOF = false
433+ r .c .readMaskPos = 0
410434 }
411435
412- if int64 (len (p )) > r .h .payloadLength {
413- p = p [:r .h .payloadLength ]
436+ h := r .c .readMsgHeader
437+ if int64 (len (p )) > h .payloadLength {
438+ p = p [:h .payloadLength ]
414439 }
415440
416- n , err := r .c .readFramePayload (r .ctx , p )
441+ n , err := r .c .readFramePayload (r .c . readMsgCtx , p )
417442
418- r .h .payloadLength -= int64 (n )
419- if r .h .masked {
420- r .maskPos = fastXOR (r .h .maskKey , r .maskPos , p )
443+ h .payloadLength -= int64 (n )
444+ r .left -= int64 (n )
445+ if h .masked {
446+ r .c .readMaskPos = fastXOR (h .maskKey , r .c .readMaskPos , p )
421447 }
448+ r .c .readMsgHeader = h
422449
423450 if err != nil {
424451 return n , err
425452 }
426453
427- if r .h .payloadLength == 0 {
428- fin := r .h .fin
429-
430- // Need to nil this as Reader uses it to check
431- // whether there is active data on the previous reader and
432- // now there isn't.
433- r .h = nil
454+ if h .payloadLength == 0 {
455+ r .c .readFrameEOF = true
434456
435- if fin {
436- r .done = true
457+ if h . fin {
458+ r .eof = true
437459 return n , io .EOF
438460 }
439-
440- r .maskPos = 0
441461 }
442462
443463 return n , nil
@@ -524,10 +544,10 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
524544 if err != nil {
525545 return nil , err
526546 }
547+ c .writeMsgCtx = ctx
548+ c .writeMsgOpcode = opcode (typ )
527549 return & messageWriter {
528- ctx : ctx ,
529- opcode : opcode (typ ),
530- c : c ,
550+ c : c ,
531551 }, nil
532552}
533553
@@ -556,8 +576,6 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
556576
557577// messageWriter enables writing to a WebSocket connection.
558578type messageWriter struct {
559- ctx context.Context
560- opcode opcode
561579 c * Conn
562580 closed bool
563581}
@@ -575,11 +593,11 @@ func (w *messageWriter) write(p []byte) (int, error) {
575593 if w .closed {
576594 return 0 , xerrors .Errorf ("cannot use closed writer" )
577595 }
578- n , err := w .c .writeFrame (w .ctx , false , w .opcode , p )
596+ n , err := w .c .writeFrame (w .c . writeMsgCtx , false , w .c . writeMsgOpcode , p )
579597 if err != nil {
580598 return n , xerrors .Errorf ("failed to write data frame: %w" , err )
581599 }
582- w .opcode = opContinuation
600+ w .c . writeMsgOpcode = opContinuation
583601 return n , nil
584602}
585603
@@ -599,7 +617,7 @@ func (w *messageWriter) close() error {
599617 }
600618 w .closed = true
601619
602- _ , err := w .c .writeFrame (w .ctx , true , w .opcode , nil )
620+ _ , err := w .c .writeFrame (w .c . writeMsgCtx , true , w .c . writeMsgOpcode , nil )
603621 if err != nil {
604622 return xerrors .Errorf ("failed to write fin frame: %w" , err )
605623 }
@@ -618,20 +636,6 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
618636
619637// writeFrame handles all writes to the connection.
620638func (c * Conn ) writeFrame (ctx context.Context , fin bool , opcode opcode , p []byte ) (int , error ) {
621- h := header {
622- fin : fin ,
623- opcode : opcode ,
624- masked : c .client ,
625- payloadLength : int64 (len (p )),
626- }
627-
628- if c .client {
629- _ , err := io .ReadFull (cryptorand .Reader , h .maskKey [:])
630- if err != nil {
631- return 0 , xerrors .Errorf ("failed to generate masking key: %w" , err )
632- }
633- }
634-
635639 err := c .acquireLock (ctx , c .writeFrameLock )
636640 if err != nil {
637641 return 0 , err
@@ -644,7 +648,19 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
644648 case c .setWriteTimeout <- ctx :
645649 }
646650
647- n , err := c .realWriteFrame (ctx , h , p )
651+ c .writeHeader .fin = fin
652+ c .writeHeader .opcode = opcode
653+ c .writeHeader .masked = c .client
654+ c .writeHeader .payloadLength = int64 (len (p ))
655+
656+ if c .client {
657+ _ , err := io .ReadFull (cryptorand .Reader , c .writeHeader .maskKey [:])
658+ if err != nil {
659+ return 0 , xerrors .Errorf ("failed to generate masking key: %w" , err )
660+ }
661+ }
662+
663+ n , err := c .realWriteFrame (ctx , * c .writeHeader , p )
648664 if err != nil {
649665 return n , err
650666 }
0 commit comments