@@ -3,25 +3,57 @@ package main
33import (
44 "context"
55 "errors"
6- "io"
76 "io/ioutil"
87 "log"
98 "net/http"
109 "sync"
1110 "time"
1211
12+ "golang.org/x/time/rate"
13+
1314 "nhooyr.io/websocket"
1415)
1516
1617// chatServer enables broadcasting to a set of subscribers.
1718type chatServer struct {
18- registerOnce sync.Once
19- m http.ServeMux
20-
21- subscribersMu sync.RWMutex
19+ // subscriberMessageBuffer controls the max number
20+ // of messages that can be queued for a subscriber
21+ // before it is kicked.
22+ //
23+ // Defaults to 16.
24+ subscriberMessageBuffer int
25+
26+ // publishLimiter controls the rate limit applied to the publish endpoint.
27+ //
28+ // Defaults to one publish every 100ms with a burst of 8.
29+ publishLimiter * rate.Limiter
30+
31+ // logf controls where logs are sent.
32+ // Defaults to log.Printf.
33+ logf func (f string , v ... interface {})
34+
35+ // serveMux routes the various endpoints to the appropriate handler.
36+ serveMux http.ServeMux
37+
38+ subscribersMu sync.Mutex
2239 subscribers map [* subscriber ]struct {}
2340}
2441
42+ // newChatServer constructs a chatServer with the defaults.
43+ func newChatServer () * chatServer {
44+ cs := & chatServer {
45+ subscriberMessageBuffer : 16 ,
46+ logf : log .Printf ,
47+ subscribers : make (map [* subscriber ]struct {}),
48+ publishLimiter : rate .NewLimiter (rate .Every (time .Millisecond * 100 ), 8 ),
49+ }
50+ cs .serveMux .Handle ("/" , http .FileServer (http .Dir ("." )))
51+ cs .serveMux .HandleFunc ("/subscribe" , cs .subscribeHandler )
52+ cs .serveMux .HandleFunc ("/publish" , cs .publishHandler )
53+
54+ return cs
55+ }
56+
2557// subscriber represents a subscriber.
2658// Messages are sent on the msgs channel and if the client
2759// cannot keep up with the messages, closeSlow is called.
@@ -31,20 +63,15 @@ type subscriber struct {
3163}
3264
3365func (cs * chatServer ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
34- cs .registerOnce .Do (func () {
35- cs .m .Handle ("/" , http .FileServer (http .Dir ("." )))
36- cs .m .HandleFunc ("/subscribe" , cs .subscribeHandler )
37- cs .m .HandleFunc ("/publish" , cs .publishHandler )
38- })
39- cs .m .ServeHTTP (w , r )
66+ cs .serveMux .ServeHTTP (w , r )
4067}
4168
4269// subscribeHandler accepts the WebSocket connection and then subscribes
4370// it to all future messages.
4471func (cs * chatServer ) subscribeHandler (w http.ResponseWriter , r * http.Request ) {
4572 c , err := websocket .Accept (w , r , nil )
4673 if err != nil {
47- log . Print ( err )
74+ cs . logf ( "%v" , err )
4875 return
4976 }
5077 defer c .Close (websocket .StatusInternalError , "" )
@@ -58,7 +85,8 @@ func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
5885 return
5986 }
6087 if err != nil {
61- log .Print (err )
88+ cs .logf ("%v" , err )
89+ return
6290 }
6391}
6492
@@ -69,7 +97,7 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
6997 http .Error (w , http .StatusText (http .StatusMethodNotAllowed ), http .StatusMethodNotAllowed )
7098 return
7199 }
72- body := io . LimitReader ( r .Body , 8192 )
100+ body := http . MaxBytesReader ( w , r .Body , 8192 )
73101 msg , err := ioutil .ReadAll (body )
74102 if err != nil {
75103 http .Error (w , http .StatusText (http .StatusRequestEntityTooLarge ), http .StatusRequestEntityTooLarge )
@@ -93,7 +121,7 @@ func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
93121 ctx = c .CloseRead (ctx )
94122
95123 s := & subscriber {
96- msgs : make (chan []byte , 16 ),
124+ msgs : make (chan []byte , cs . subscriberMessageBuffer ),
97125 closeSlow : func () {
98126 c .Close (websocket .StatusPolicyViolation , "connection too slow to keep up with messages" )
99127 },
@@ -118,8 +146,10 @@ func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
118146// It never blocks and so messages to slow subscribers
119147// are dropped.
120148func (cs * chatServer ) publish (msg []byte ) {
121- cs .subscribersMu .RLock ()
122- defer cs .subscribersMu .RUnlock ()
149+ cs .subscribersMu .Lock ()
150+ defer cs .subscribersMu .Unlock ()
151+
152+ cs .publishLimiter .Wait (context .Background ())
123153
124154 for s := range cs .subscribers {
125155 select {
@@ -133,9 +163,6 @@ func (cs *chatServer) publish(msg []byte) {
133163// addSubscriber registers a subscriber.
134164func (cs * chatServer ) addSubscriber (s * subscriber ) {
135165 cs .subscribersMu .Lock ()
136- if cs .subscribers == nil {
137- cs .subscribers = make (map [* subscriber ]struct {})
138- }
139166 cs .subscribers [s ] = struct {}{}
140167 cs .subscribersMu .Unlock ()
141168}
0 commit comments