22package websocketproxy
33
44import (
5+ "context"
56 "fmt"
67 "io"
78 "log"
4748 // If nil, DefaultDialer is used.
4849 Dialer * websocket.Dialer
4950
50- // Done specifies a channel for which all proxied websocket connections
51+ // done specifies a channel for which all proxied websocket connections
5152 // can be closed on demand by closing the channel.
52- Done chan struct {}
53+ done chan struct {}
5354 }
5455
5556 websocketMsg struct {
@@ -186,6 +187,9 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
186187
187188 errClient := make (chan error , 1 )
188189 errBackend := make (chan error , 1 )
190+ if w .done == nil {
191+ w .done = make (chan struct {})
192+ }
189193
190194 replicateWebsocketConn := func (dst , src * websocket.Conn , errc chan error ) {
191195 websocketMsgRcverC := make (chan websocketMsg , 1 )
@@ -214,7 +218,7 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
214218 errc <- err
215219 break
216220 }
217- case <- w .Done :
221+ case <- w .done :
218222 m := websocket .FormatCloseMessage (websocket .CloseGoingAway , "websocketproxy: closing connection" )
219223 dst .WriteMessage (websocket .CloseMessage , m )
220224 break
@@ -234,8 +238,18 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
234238 if e , ok := err .(* websocket.CloseError ); ! ok || e .Code == websocket .CloseAbnormalClosure {
235239 log .Printf ("websocketproxy: Error when copying from client to backend: %v" , err )
236240 }
237- case <- w .Done :
241+ case <- w .done :
242+ }
243+ }
244+
245+ // Shutdown closes ws connections by closing the done channel they are subscribed to.
246+ func (w * WebsocketProxy ) Shutdown (ctx context.Context ) error {
247+ // TODO: support using context for control and return error when applicable
248+ // Currently implemented such that the method signature matches http.Server.Shutdown()
249+ if w .done != nil {
250+ close (w .done )
238251 }
252+ return nil
239253}
240254
241255func copyHeader (dst , src http.Header ) {
0 commit comments