@@ -2,7 +2,6 @@ package websocket
22
33import (
44 "context"
5- "errors"
65 "fmt"
76 "net/http"
87 "sync"
@@ -17,98 +16,92 @@ import (
1716// Grace is intended to be used in harmony with net/http.Server's Shutdown and Close methods.
1817// It's required as net/http's Shutdown and Close methods do not keep track of WebSocket
1918// connections.
19+ //
20+ // Make sure to Close or Shutdown the *http.Server first as you don't want to accept
21+ // any new connections while the existing websockets are being shut down.
2022type Grace struct {
21- mu sync.Mutex
22- closed bool
23- shuttingDown bool
24- conns map [* Conn ]struct {}
23+ handlersMu sync.Mutex
24+ closing bool
25+ handlers map [context.Context ]context.CancelFunc
2526}
2627
2728// Handler returns a handler that wraps around h to record
2829// all WebSocket connections accepted.
2930//
3031// Use Close or Shutdown to gracefully close recorded connections.
32+ // Make sure to Close or Shutdown the *http.Server first.
3133func (g * Grace ) Handler (h http.Handler ) http.Handler {
3234 return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
33- ctx := context .WithValue (r .Context (), gracefulContextKey {}, g )
35+ ctx , cancel := context .WithCancel (r .Context ())
36+ defer cancel ()
37+
3438 r = r .WithContext (ctx )
39+
40+ ok := g .add (w , ctx , cancel )
41+ if ! ok {
42+ return
43+ }
44+ defer g .del (ctx )
45+
3546 h .ServeHTTP (w , r )
3647 })
3748}
3849
39- func (g * Grace ) isShuttingdown () bool {
40- g .mu .Lock ()
41- defer g .mu .Unlock ()
42- return g .shuttingDown
43- }
44-
45- func graceFromRequest (r * http.Request ) * Grace {
46- g , _ := r .Context ().Value (gracefulContextKey {}).(* Grace )
47- return g
48- }
50+ func (g * Grace ) add (w http.ResponseWriter , ctx context.Context , cancel context.CancelFunc ) bool {
51+ g .handlersMu .Lock ()
52+ defer g .handlersMu .Unlock ()
4953
50- func (g * Grace ) addConn (c * Conn ) error {
51- g .mu .Lock ()
52- defer g .mu .Unlock ()
53- if g .closed {
54- c .Close (StatusGoingAway , "server shutting down" )
55- return errors .New ("server shutting down" )
54+ if g .closing {
55+ http .Error (w , "shutting down" , http .StatusServiceUnavailable )
56+ return false
5657 }
57- if g .conns == nil {
58- g .conns = make (map [* Conn ]struct {})
58+
59+ if g .handlers == nil {
60+ g .handlers = make (map [context.Context ]context.CancelFunc )
5961 }
60- g .conns [c ] = struct {}{}
61- c .g = g
62- return nil
63- }
62+ g .handlers [ctx ] = cancel
6463
65- func (g * Grace ) delConn (c * Conn ) {
66- g .mu .Lock ()
67- defer g .mu .Unlock ()
68- delete (g .conns , c )
64+ return true
6965}
7066
71- type gracefulContextKey struct {}
67+ func (g * Grace ) del (ctx context.Context ) {
68+ g .handlersMu .Lock ()
69+ defer g .handlersMu .Unlock ()
70+
71+ delete (g .handlers , ctx )
72+ }
7273
7374// Close prevents the acceptance of new connections with
7475// http.StatusServiceUnavailable and closes all accepted
7576// connections with StatusGoingAway.
77+ //
78+ // Make sure to Close or Shutdown the *http.Server first.
7679func (g * Grace ) Close () error {
77- g .mu .Lock ()
78- g .shuttingDown = true
79- g .closed = true
80- var wg sync.WaitGroup
81- for c := range g .conns {
82- wg .Add (1 )
83- go func (c * Conn ) {
84- defer wg .Done ()
85- c .Close (StatusGoingAway , "server shutting down" )
86- }(c )
87-
88- delete (g .conns , c )
80+ g .handlersMu .Lock ()
81+ for _ , cancel := range g .handlers {
82+ cancel ()
8983 }
90- g .mu .Unlock ()
84+ g .handlersMu .Unlock ()
9185
92- wg .Wait ()
86+ // Wait for all goroutines to exit.
87+ g .Shutdown (context .Background ())
9388
9489 return nil
9590}
9691
9792// Shutdown prevents the acceptance of new connections and waits until
9893// all connections close. If the context is cancelled before that, it
9994// calls Close to close all connections immediately.
95+ //
96+ // Make sure to Close or Shutdown the *http.Server first.
10097func (g * Grace ) Shutdown (ctx context.Context ) error {
10198 defer g .Close ()
10299
103- g .mu .Lock ()
104- g .shuttingDown = true
105- g .mu .Unlock ()
106-
107100 // Same poll period used by net/http.
108101 t := time .NewTicker (500 * time .Millisecond )
109102 defer t .Stop ()
110103 for {
111- if g .zeroConns () {
104+ if g .zeroHandlers () {
112105 return nil
113106 }
114107
@@ -120,8 +113,8 @@ func (g *Grace) Shutdown(ctx context.Context) error {
120113 }
121114}
122115
123- func (g * Grace ) zeroConns () bool {
124- g .mu .Lock ()
125- defer g .mu .Unlock ()
126- return len (g .conns ) == 0
116+ func (g * Grace ) zeroHandlers () bool {
117+ g .handlersMu .Lock ()
118+ defer g .handlersMu .Unlock ()
119+ return len (g .handlers ) == 0
127120}
0 commit comments