@@ -97,82 +97,106 @@ func CloseStatus(err error) StatusCode {
9797//
9898// Close will unblock all goroutines interacting with the connection once
9999// complete.
100- func (c * Conn ) Close (code StatusCode , reason string ) error {
101- defer c .wg .Wait ()
102- return c .closeHandshake (code , reason )
100+ func (c * Conn ) Close (code StatusCode , reason string ) (err error ) {
101+ defer errd .Wrap (& err , "failed to close WebSocket" )
102+
103+ if ! c .casClosing () {
104+ err = c .waitGoroutines ()
105+ if err != nil {
106+ return err
107+ }
108+ return net .ErrClosed
109+ }
110+ defer func () {
111+ if errors .Is (err , net .ErrClosed ) {
112+ err = nil
113+ }
114+ }()
115+
116+ err = c .closeHandshake (code , reason )
117+
118+ err2 := c .close ()
119+ if err == nil && err2 != nil {
120+ err = err2
121+ }
122+
123+ err2 = c .waitGoroutines ()
124+ if err == nil && err2 != nil {
125+ err = err2
126+ }
127+
128+ return err
103129}
104130
105131// CloseNow closes the WebSocket connection without attempting a close handshake.
106132// Use when you do not want the overhead of the close handshake.
107133func (c * Conn ) CloseNow () (err error ) {
108- defer c .wg .Wait ()
109134 defer errd .Wrap (& err , "failed to close WebSocket" )
110135
111- if c .isClosed () {
136+ if ! c .casClosing () {
137+ err = c .waitGoroutines ()
138+ if err != nil {
139+ return err
140+ }
112141 return net .ErrClosed
113142 }
143+ defer func () {
144+ if errors .Is (err , net .ErrClosed ) {
145+ err = nil
146+ }
147+ }()
114148
115- c .close (nil )
116- c .closeMu .Lock ()
117- defer c .closeMu .Unlock ()
118- return c .closeErr
119- }
120-
121- func (c * Conn ) closeHandshake (code StatusCode , reason string ) (err error ) {
122- defer errd .Wrap (& err , "failed to close WebSocket" )
123-
124- writeErr := c .writeClose (code , reason )
125- closeHandshakeErr := c .waitCloseHandshake ()
149+ err = c .close ()
126150
127- if writeErr != nil {
128- return writeErr
151+ err2 := c .waitGoroutines ()
152+ if err == nil && err2 != nil {
153+ err = err2
129154 }
155+ return err
156+ }
130157
131- if CloseStatus (closeHandshakeErr ) == - 1 && ! errors .Is (net .ErrClosed , closeHandshakeErr ) {
132- return closeHandshakeErr
158+ func (c * Conn ) closeHandshake (code StatusCode , reason string ) error {
159+ err := c .writeClose (code , reason )
160+ if err != nil {
161+ return err
133162 }
134163
164+ err = c .waitCloseHandshake ()
165+ if CloseStatus (err ) != code {
166+ return err
167+ }
135168 return nil
136169}
137170
138171func (c * Conn ) writeClose (code StatusCode , reason string ) error {
139- c .closeMu .Lock ()
140- wroteClose := c .wroteClose
141- c .wroteClose = true
142- c .closeMu .Unlock ()
143- if wroteClose {
144- return net .ErrClosed
145- }
146-
147172 ce := CloseError {
148173 Code : code ,
149174 Reason : reason ,
150175 }
151176
152177 var p []byte
153- var marshalErr error
178+ var err error
154179 if ce .Code != StatusNoStatusRcvd {
155- p , marshalErr = ce .bytes ()
156- }
157-
158- writeErr := c .writeControl (context .Background (), opClose , p )
159- if CloseStatus (writeErr ) != - 1 {
160- // Not a real error if it's due to a close frame being received.
161- writeErr = nil
180+ p , err = ce .bytes ()
181+ if err != nil {
182+ return err
183+ }
162184 }
163185
164- // We do this after in case there was an error writing the close frame.
165- c . setCloseErr ( fmt . Errorf ( "sent close frame: %w" , ce ) )
186+ ctx , cancel := context . WithTimeout ( context . Background (), time . Second * 5 )
187+ defer cancel ( )
166188
167- if marshalErr != nil {
168- return marshalErr
189+ err = c .writeControl (ctx , opClose , p )
190+ // If the connection closed as we're writing we ignore the error as we might
191+ // have written the close frame, the peer responded and then someone else read it
192+ // and closed the connection.
193+ if err != nil && ! errors .Is (err , net .ErrClosed ) {
194+ return err
169195 }
170- return writeErr
196+ return nil
171197}
172198
173199func (c * Conn ) waitCloseHandshake () error {
174- defer c .close (nil )
175-
176200 ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
177201 defer cancel ()
178202
@@ -208,6 +232,36 @@ func (c *Conn) waitCloseHandshake() error {
208232 }
209233}
210234
235+ func (c * Conn ) waitGoroutines () error {
236+ t := time .NewTimer (time .Second * 15 )
237+ defer t .Stop ()
238+
239+ select {
240+ case <- c .timeoutLoopDone :
241+ case <- t .C :
242+ return errors .New ("failed to wait for timeoutLoop goroutine to exit" )
243+ }
244+
245+ c .closeReadMu .Lock ()
246+ ctx := c .closeReadCtx
247+ c .closeReadMu .Unlock ()
248+ if ctx != nil {
249+ select {
250+ case <- ctx .Done ():
251+ case <- t .C :
252+ return errors .New ("failed to wait for close read goroutine to exit" )
253+ }
254+ }
255+
256+ select {
257+ case <- c .closed :
258+ case <- t .C :
259+ return errors .New ("failed to wait for connection to be closed" )
260+ }
261+
262+ return nil
263+ }
264+
211265func parseClosePayload (p []byte ) (CloseError , error ) {
212266 if len (p ) == 0 {
213267 return CloseError {
@@ -278,16 +332,14 @@ func (ce CloseError) bytesErr() ([]byte, error) {
278332 return buf , nil
279333}
280334
281- func (c * Conn ) setCloseErr ( err error ) {
335+ func (c * Conn ) casClosing () bool {
282336 c .closeMu .Lock ()
283- c .setCloseErrLocked (err )
284- c .closeMu .Unlock ()
285- }
286-
287- func (c * Conn ) setCloseErrLocked (err error ) {
288- if c .closeErr == nil && err != nil {
289- c .closeErr = fmt .Errorf ("WebSocket closed: %w" , err )
337+ defer c .closeMu .Unlock ()
338+ if ! c .closing {
339+ c .closing = true
340+ return true
290341 }
342+ return false
291343}
292344
293345func (c * Conn ) isClosed () bool {
0 commit comments