@@ -2,6 +2,7 @@ package main
22
33import (
44 "context"
5+ "errors"
56 "io"
67 "io/ioutil"
78 "log"
@@ -12,21 +13,36 @@ import (
1213 "nhooyr.io/websocket"
1314)
1415
16+ // chatServer enables broadcasting to a set of subscribers.
1517type chatServer struct {
1618 subscribersMu sync.RWMutex
17- subscribers map [chan []byte ]struct {}
19+ subscribers map [chan <- []byte ]struct {}
1820}
1921
22+ // subscribeHandler accepts the WebSocket connection and then subscribes
23+ // it to all future messages.
2024func (cs * chatServer ) subscribeHandler (w http.ResponseWriter , r * http.Request ) {
2125 c , err := websocket .Accept (w , r , nil )
2226 if err != nil {
2327 log .Print (err )
2428 return
2529 }
2630
27- cs .subscribe (r .Context (), c )
31+ err = cs .subscribe (r .Context (), c )
32+ if errors .Is (err , context .Canceled ) {
33+ return
34+ }
35+ if websocket .CloseStatus (err ) == websocket .StatusNormalClosure ||
36+ websocket .CloseStatus (err ) == websocket .StatusGoingAway {
37+ return
38+ }
39+ if err != nil {
40+ log .Print (err )
41+ }
2842}
2943
44+ // publishHandler reads the request body with a limit of 8192 bytes and then publishes
45+ // the received message.
3046func (cs * chatServer ) publishHandler (w http.ResponseWriter , r * http.Request ) {
3147 if r .Method != "POST" {
3248 http .Error (w , http .StatusText (http .StatusMethodNotAllowed ), http .StatusMethodNotAllowed )
@@ -35,12 +51,44 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
3551 body := io .LimitReader (r .Body , 8192 )
3652 msg , err := ioutil .ReadAll (body )
3753 if err != nil {
54+ http .Error (w , http .StatusText (http .StatusRequestEntityTooLarge ), http .StatusRequestEntityTooLarge )
3855 return
3956 }
4057
4158 cs .publish (msg )
4259}
4360
61+ // subscribe subscribes the given WebSocket to all broadcast messages.
62+ // It creates a msgs chan with a buffer of 16 to give some room to slower
63+ // connections and then registers it. It then listens for all messages
64+ // and writes them to the WebSocket. If the context is cancelled or
65+ // an error occurs, it returns and deletes the subscription.
66+ //
67+ // It uses CloseRead to keep reading from the connection to process control
68+ // messages and cancel the context if the connection drops.
69+ func (cs * chatServer ) subscribe (ctx context.Context , c * websocket.Conn ) error {
70+ ctx = c .CloseRead (ctx )
71+
72+ msgs := make (chan []byte , 16 )
73+ cs .addSubscriber (msgs )
74+ defer cs .deleteSubscriber (msgs )
75+
76+ for {
77+ select {
78+ case msg := <- msgs :
79+ err := writeTimeout (ctx , time .Second * 5 , c , msg )
80+ if err != nil {
81+ return err
82+ }
83+ case <- ctx .Done ():
84+ return ctx .Err ()
85+ }
86+ }
87+ }
88+
89+ // publish publishes the msg to all subscribers.
90+ // It never blocks and so messages to slow subscribers
91+ // are dropped.
4492func (cs * chatServer ) publish (msg []byte ) {
4593 cs .subscribersMu .RLock ()
4694 defer cs .subscribersMu .RUnlock ()
@@ -53,41 +101,24 @@ func (cs *chatServer) publish(msg []byte) {
53101 }
54102}
55103
56- func (cs * chatServer ) addSubscriber (msgs chan []byte ) {
104+ // addSubscriber registers a subscriber with a channel
105+ // on which to send messages.
106+ func (cs * chatServer ) addSubscriber (msgs chan <- []byte ) {
57107 cs .subscribersMu .Lock ()
58108 if cs .subscribers == nil {
59- cs .subscribers = make (map [chan []byte ]struct {})
109+ cs .subscribers = make (map [chan <- []byte ]struct {})
60110 }
61111 cs .subscribers [msgs ] = struct {}{}
62112 cs .subscribersMu .Unlock ()
63113}
64114
115+ // deleteSubscriber deletes the subscriber with the given msgs channel.
65116func (cs * chatServer ) deleteSubscriber (msgs chan []byte ) {
66117 cs .subscribersMu .Lock ()
67118 delete (cs .subscribers , msgs )
68119 cs .subscribersMu .Unlock ()
69120}
70121
71- func (cs * chatServer ) subscribe (ctx context.Context , c * websocket.Conn ) error {
72- ctx = c .CloseRead (ctx )
73-
74- msgs := make (chan []byte , 16 )
75- cs .addSubscriber (msgs )
76- defer cs .deleteSubscriber (msgs )
77-
78- for {
79- select {
80- case msg := <- msgs :
81- err := writeTimeout (ctx , time .Second * 5 , c , msg )
82- if err != nil {
83- return err
84- }
85- case <- ctx .Done ():
86- return ctx .Err ()
87- }
88- }
89- }
90-
91122func writeTimeout (ctx context.Context , timeout time.Duration , c * websocket.Conn , msg []byte ) error {
92123 ctx , cancel := context .WithTimeout (ctx , timeout )
93124 defer cancel ()
0 commit comments