@@ -10,7 +10,6 @@ import (
1010 "errors"
1111 "fmt"
1212 "io"
13- "sync"
1413 "time"
1514
1615 "github.com/klauspost/compress/flate"
@@ -71,7 +70,7 @@ type msgWriterState struct {
7170 c * Conn
7271
7372 mu * mu
74- writeMu sync. Mutex
73+ writeMu * mu
7574
7675 ctx context.Context
7776 opcode opcode
@@ -83,8 +82,9 @@ type msgWriterState struct {
8382
8483func newMsgWriterState (c * Conn ) * msgWriterState {
8584 mw := & msgWriterState {
86- c : c ,
87- mu : newMu (c ),
85+ c : c ,
86+ mu : newMu (c ),
87+ writeMu : newMu (c ),
8888 }
8989 return mw
9090}
@@ -155,10 +155,18 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {
155155
156156// Write writes the given bytes to the WebSocket connection.
157157func (mw * msgWriterState ) Write (p []byte ) (_ int , err error ) {
158- defer errd .Wrap (& err , "failed to write" )
158+ err = mw .writeMu .lock (mw .ctx )
159+ if err != nil {
160+ return 0 , fmt .Errorf ("failed to write: %w" , err )
161+ }
162+ defer mw .writeMu .unlock ()
159163
160- mw .writeMu .Lock ()
161- defer mw .writeMu .Unlock ()
164+ defer func () {
165+ if err != nil {
166+ err = fmt .Errorf ("failed to write: %w" , err )
167+ mw .c .close (err )
168+ }
169+ }()
162170
163171 if mw .c .flate () {
164172 // Only enables flate if the length crosses the
@@ -193,8 +201,11 @@ func (mw *msgWriterState) write(p []byte) (int, error) {
193201func (mw * msgWriterState ) Close () (err error ) {
194202 defer errd .Wrap (& err , "failed to close writer" )
195203
196- mw .writeMu .Lock ()
197- defer mw .writeMu .Unlock ()
204+ err = mw .writeMu .lock (mw .ctx )
205+ if err != nil {
206+ return err
207+ }
208+ defer mw .writeMu .unlock ()
198209
199210 _ , err = mw .c .writeFrame (mw .ctx , true , mw .flate , mw .opcode , nil )
200211 if err != nil {
@@ -214,7 +225,7 @@ func (mw *msgWriterState) close() {
214225 putBufioWriter (mw .c .bw )
215226 }
216227
217- mw .writeMu .Lock ()
228+ mw .writeMu .forceLock ()
218229 mw .dict .close ()
219230}
220231
@@ -230,8 +241,8 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
230241}
231242
232243// frame handles all writes to the connection.
233- func (c * Conn ) writeFrame (ctx context.Context , fin bool , flate bool , opcode opcode , p []byte ) (int , error ) {
234- err : = c .writeFrameMu .lock (ctx )
244+ func (c * Conn ) writeFrame (ctx context.Context , fin bool , flate bool , opcode opcode , p []byte ) (_ int , err error ) {
245+ err = c .writeFrameMu .lock (ctx )
235246 if err != nil {
236247 return 0 , err
237248 }
@@ -243,6 +254,13 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
243254 case c .writeTimeout <- ctx :
244255 }
245256
257+ defer func () {
258+ if err != nil {
259+ err = fmt .Errorf ("failed to write frame: %w" , err )
260+ c .close (err )
261+ }
262+ }()
263+
246264 c .writeHeader .fin = fin
247265 c .writeHeader .opcode = opcode
248266 c .writeHeader .payloadLength = int64 (len (p ))
0 commit comments