@@ -10,7 +10,6 @@ import (
1010 "errors"
1111 "fmt"
1212 "io"
13- "sync"
1413 "time"
1514
1615 "github.com/klauspost/compress/flate"
@@ -71,7 +70,7 @@ type msgWriterState struct {
7170 c * Conn
7271
7372 mu * mu
74- writeMu sync. Mutex
73+ writeMu * mu
7574
7675 ctx context.Context
7776 opcode opcode
@@ -83,8 +82,9 @@ type msgWriterState struct {
8382
8483func newMsgWriterState (c * Conn ) * msgWriterState {
8584 mw := & msgWriterState {
86- c : c ,
87- mu : newMu (c ),
85+ c : c ,
86+ mu : newMu (c ),
87+ writeMu : newMu (c ),
8888 }
8989 return mw
9090}
@@ -155,12 +155,15 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {
155155
156156// Write writes the given bytes to the WebSocket connection.
157157func (mw * msgWriterState ) Write (p []byte ) (_ int , err error ) {
158- mw .writeMu .Lock ()
159- defer mw .writeMu .Unlock ()
158+ err = mw .writeMu .lock (mw .ctx )
159+ if err != nil {
160+ return 0 , fmt .Errorf ("failed to write: %w" , err )
161+ }
162+ defer mw .writeMu .unlock ()
160163
161164 defer func () {
162- err = fmt .Errorf ("failed to write: %w" , err )
163165 if err != nil {
166+ err = fmt .Errorf ("failed to write: %w" , err )
164167 mw .c .close (err )
165168 }
166169 }()
@@ -198,8 +201,11 @@ func (mw *msgWriterState) write(p []byte) (int, error) {
198201func (mw * msgWriterState ) Close () (err error ) {
199202 defer errd .Wrap (& err , "failed to close writer" )
200203
201- mw .writeMu .Lock ()
202- defer mw .writeMu .Unlock ()
204+ err = mw .writeMu .lock (mw .ctx )
205+ if err != nil {
206+ return err
207+ }
208+ defer mw .writeMu .unlock ()
203209
204210 _ , err = mw .c .writeFrame (mw .ctx , true , mw .flate , mw .opcode , nil )
205211 if err != nil {
@@ -219,7 +225,7 @@ func (mw *msgWriterState) close() {
219225 putBufioWriter (mw .c .bw )
220226 }
221227
222- mw .writeMu .Lock ()
228+ mw .writeMu .forceLock ()
223229 mw .dict .close ()
224230}
225231
@@ -250,7 +256,8 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
250256
251257 defer func () {
252258 if err != nil {
253- c .close (fmt .Errorf ("failed to write frame: %w" , err ))
259+ err = fmt .Errorf ("failed to write frame: %w" , err )
260+ c .close (err )
254261 }
255262 }()
256263
0 commit comments