@@ -109,7 +109,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
109109
110110 if ! c .flate () {
111111 defer c .msgWriter .mu .unlock ()
112- return c .writeFrame (ctx , true , false , c .msgWriter .opcode , p )
112+ return c .writeFrame (true , ctx , true , false , c .msgWriter .opcode , p )
113113 }
114114
115115 n , err := mw .Write (p )
@@ -159,6 +159,7 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) {
159159 defer func () {
160160 if err != nil {
161161 err = fmt .Errorf ("failed to write: %w" , err )
162+ mw .writeMu .unlock ()
162163 mw .c .close (err )
163164 }
164165 }()
@@ -179,7 +180,7 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) {
179180}
180181
181182func (mw * msgWriter ) write (p []byte ) (int , error ) {
182- n , err := mw .c .writeFrame (mw .ctx , false , mw .flate , mw .opcode , p )
183+ n , err := mw .c .writeFrame (true , mw .ctx , false , mw .flate , mw .opcode , p )
183184 if err != nil {
184185 return n , fmt .Errorf ("failed to write data frame: %w" , err )
185186 }
@@ -191,25 +192,25 @@ func (mw *msgWriter) write(p []byte) (int, error) {
191192func (mw * msgWriter ) Close () (err error ) {
192193 defer errd .Wrap (& err , "failed to close writer" )
193194
194- if mw .closed {
195- return errors .New ("writer already closed" )
196- }
197- mw .closed = true
198-
199195 err = mw .writeMu .lock (mw .ctx )
200196 if err != nil {
201197 return err
202198 }
203199 defer mw .writeMu .unlock ()
204200
201+ if mw .closed {
202+ return errors .New ("writer already closed" )
203+ }
204+ mw .closed = true
205+
205206 if mw .flate {
206207 err = mw .flateWriter .Flush ()
207208 if err != nil {
208209 return fmt .Errorf ("failed to flush flate: %w" , err )
209210 }
210211 }
211212
212- _ , err = mw .c .writeFrame (mw .ctx , true , mw .flate , mw .opcode , nil )
213+ _ , err = mw .c .writeFrame (true , mw .ctx , true , mw .flate , mw .opcode , nil )
213214 if err != nil {
214215 return fmt .Errorf ("failed to write fin frame: %w" , err )
215216 }
@@ -235,15 +236,15 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
235236 ctx , cancel := context .WithTimeout (ctx , time .Second * 5 )
236237 defer cancel ()
237238
238- _ , err := c .writeFrame (ctx , true , false , opcode , p )
239+ _ , err := c .writeFrame (false , ctx , true , false , opcode , p )
239240 if err != nil {
240241 return fmt .Errorf ("failed to write control frame %v: %w" , opcode , err )
241242 }
242243 return nil
243244}
244245
245246// frame handles all writes to the connection.
246- func (c * Conn ) writeFrame (ctx context.Context , fin bool , flate bool , opcode opcode , p []byte ) (_ int , err error ) {
247+ func (c * Conn ) writeFrame (msgWriter bool , ctx context.Context , fin bool , flate bool , opcode opcode , p []byte ) (_ int , err error ) {
247248 err = c .writeFrameMu .lock (ctx )
248249 if err != nil {
249250 return 0 , err
@@ -283,6 +284,10 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
283284 err = ctx .Err ()
284285 default :
285286 }
287+ c .writeFrameMu .unlock ()
288+ if msgWriter {
289+ c .msgWriter .writeMu .unlock ()
290+ }
286291 c .close (err )
287292 err = fmt .Errorf ("failed to write frame: %w" , err )
288293 }
0 commit comments