@@ -3,6 +3,7 @@ package websocket
33import (
44 "bufio"
55 "context"
6+ cryptorand "crypto/rand"
67 "fmt"
78 "io"
89 "io/ioutil"
@@ -26,8 +27,11 @@ type Conn struct {
2627 subprotocol string
2728 br * bufio.Reader
2829 bw * bufio.Writer
29- closer io.Closer
30- client bool
30+ // writeBuf is used for masking, its the buffer in bufio.Writer.
31+ // Only used by the client.
32+ writeBuf []byte
33+ closer io.Closer
34+ client bool
3135
3236 // read limit for a message in bytes.
3337 msgReadLimit int64
@@ -581,22 +585,22 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
581585// See the Writer method if you want to stream a message. The docs on Writer
582586// regarding concurrency also apply to this method.
583587func (c * Conn ) Write (ctx context.Context , typ MessageType , p []byte ) error {
584- err := c .write (ctx , typ , p )
588+ _ , err := c .write (ctx , typ , p )
585589 if err != nil {
586590 return xerrors .Errorf ("failed to write msg: %w" , err )
587591 }
588592 return nil
589593}
590594
591- func (c * Conn ) write (ctx context.Context , typ MessageType , p []byte ) error {
595+ func (c * Conn ) write (ctx context.Context , typ MessageType , p []byte ) ( int , error ) {
592596 err := c .acquireLock (ctx , c .writeMsgLock )
593597 if err != nil {
594- return err
598+ return 0 , err
595599 }
596600 defer c .releaseLock (c .writeMsgLock )
597601
598- err = c .writeFrame (ctx , true , opcode (typ ), p )
599- return err
602+ n , err : = c .writeFrame (ctx , true , opcode (typ ), p )
603+ return n , err
600604}
601605
602606// messageWriter enables writing to a WebSocket connection.
@@ -620,12 +624,12 @@ func (w *messageWriter) write(p []byte) (int, error) {
620624 if w .closed {
621625 return 0 , xerrors .Errorf ("cannot use closed writer" )
622626 }
623- err := w .c .writeFrame (w .ctx , false , w .opcode , p )
627+ n , err := w .c .writeFrame (w .ctx , false , w .opcode , p )
624628 if err != nil {
625- return 0 , xerrors .Errorf ("failed to write data frame: %w" , err )
629+ return n , xerrors .Errorf ("failed to write data frame: %w" , err )
626630 }
627631 w .opcode = opContinuation
628- return len ( p ) , nil
632+ return n , nil
629633}
630634
631635// Close flushes the frame to the connection.
@@ -644,7 +648,7 @@ func (w *messageWriter) close() error {
644648 }
645649 w .closed = true
646650
647- err := w .c .writeFrame (w .ctx , true , w .opcode , nil )
651+ _ , err := w .c .writeFrame (w .ctx , true , w .opcode , nil )
648652 if err != nil {
649653 return xerrors .Errorf ("failed to write fin frame: %w" , err )
650654 }
@@ -654,34 +658,40 @@ func (w *messageWriter) close() error {
654658}
655659
656660func (c * Conn ) writeControl (ctx context.Context , opcode opcode , p []byte ) error {
657- err := c .writeFrame (ctx , true , opcode , p )
661+ _ , err := c .writeFrame (ctx , true , opcode , p )
658662 if err != nil {
659663 return xerrors .Errorf ("failed to write control frame: %w" , err )
660664 }
661665 return nil
662666}
663667
664668// writeFrame handles all writes to the connection.
665- // We never mask inside here because our mask key is always 0,0,0,0.
666- // See comment on secWebSocketKey for why.
667- func (c * Conn ) writeFrame (ctx context.Context , fin bool , opcode opcode , p []byte ) error {
669+ func (c * Conn ) writeFrame (ctx context.Context , fin bool , opcode opcode , p []byte ) (int , error ) {
668670 h := header {
669671 fin : fin ,
670672 opcode : opcode ,
671673 masked : c .client ,
672674 payloadLength : int64 (len (p )),
673675 }
676+
677+ if c .client {
678+ _ , err := io .ReadFull (cryptorand .Reader , h .maskKey [:])
679+ if err != nil {
680+ return 0 , xerrors .Errorf ("failed to generate masking key: %w" , err )
681+ }
682+ }
683+
674684 b2 := marshalHeader (h )
675685
676686 err := c .acquireLock (ctx , c .writeFrameLock )
677687 if err != nil {
678- return err
688+ return 0 , err
679689 }
680690 defer c .releaseLock (c .writeFrameLock )
681691
682692 select {
683693 case <- c .closed :
684- return c .closeErr
694+ return 0 , c .closeErr
685695 case c .setWriteTimeout <- ctx :
686696 }
687697
@@ -705,29 +715,61 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
705715
706716 _ , err = c .bw .Write (b2 )
707717 if err != nil {
708- return writeErr (err )
709- }
710- _ , err = c .bw .Write (p )
711- if err != nil {
712- return writeErr (err )
718+ return 0 , writeErr (err )
719+ }
720+
721+ var n int
722+ if c .client {
723+ var keypos int
724+ for len (p ) > 0 {
725+ if c .bw .Available () == 0 {
726+ err = c .bw .Flush ()
727+ if err != nil {
728+ return n , writeErr (err )
729+ }
730+ }
731+
732+ // Start of next write in the buffer.
733+ i := c .bw .Buffered ()
734+
735+ p2 := p
736+ if len (p ) > c .bw .Available () {
737+ p2 = p [:c .bw .Available ()]
738+ }
739+
740+ n2 , err := c .bw .Write (p2 )
741+ if err != nil {
742+ return n , writeErr (err )
743+ }
744+
745+ keypos = fastXOR (h .maskKey , keypos , c .writeBuf [i :i + n2 ])
746+
747+ p = p [n2 :]
748+ n += n2
749+ }
750+ } else {
751+ n , err = c .bw .Write (p )
752+ if err != nil {
753+ return n , writeErr (err )
754+ }
713755 }
714756
715757 if fin {
716758 err = c .bw .Flush ()
717759 if err != nil {
718- return writeErr (err )
760+ return n , writeErr (err )
719761 }
720762 }
721763
722764 // We already finished writing, no need to potentially brick the connection if
723765 // the context expires.
724766 select {
725767 case <- c .closed :
726- return c .closeErr
768+ return n , c .closeErr
727769 case c .setWriteTimeout <- context .Background ():
728770 }
729771
730- return nil
772+ return n , nil
731773}
732774
733775func (c * Conn ) writePong (p []byte ) error {
@@ -842,3 +884,23 @@ func (c *Conn) ping(ctx context.Context) error {
842884 return nil
843885 }
844886}
887+
888+ type writerFunc func (p []byte ) (int , error )
889+
890+ func (f writerFunc ) Write (p []byte ) (int , error ) {
891+ return f (p )
892+ }
893+
894+ // extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
895+ // and stores it in c.writeBuf.
896+ func (c * Conn ) extractBufioWriterBuf (w io.Writer ) {
897+ c .bw .Reset (writerFunc (func (p2 []byte ) (int , error ) {
898+ c .writeBuf = p2 [:cap (p2 )]
899+ return len (p2 ), nil
900+ }))
901+
902+ c .bw .WriteByte (0 )
903+ c .bw .Flush ()
904+
905+ c .bw .Reset (w )
906+ }
0 commit comments