@@ -6,14 +6,15 @@ import (
66 "bytes"
77 "crypto/sha1"
88 "encoding/base64"
9+ "errors"
10+ "fmt"
911 "io"
1012 "net/http"
1113 "net/textproto"
1214 "net/url"
15+ "strconv"
1316 "strings"
1417
15- "golang.org/x/xerrors"
16-
1718 "nhooyr.io/websocket/internal/errd"
1819)
1920
@@ -85,7 +86,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
8586
8687 hj , ok := w .(http.Hijacker )
8788 if ! ok {
88- err = xerrors .New ("http.ResponseWriter does not implement http.Hijacker" )
89+ err = errors .New ("http.ResponseWriter does not implement http.Hijacker" )
8990 http .Error (w , http .StatusText (http .StatusNotImplemented ), http .StatusNotImplemented )
9091 return nil , err
9192 }
@@ -110,7 +111,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
110111
111112 netConn , brw , err := hj .Hijack ()
112113 if err != nil {
113- err = xerrors .Errorf ("failed to hijack connection: %w" , err )
114+ err = fmt .Errorf ("failed to hijack connection: %w" , err )
114115 http .Error (w , http .StatusText (http .StatusInternalServerError ), http .StatusInternalServerError )
115116 return nil , err
116117 }
@@ -133,32 +134,32 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
133134
134135func verifyClientRequest (w http.ResponseWriter , r * http.Request ) (errCode int , _ error ) {
135136 if ! r .ProtoAtLeast (1 , 1 ) {
136- return http .StatusUpgradeRequired , xerrors .Errorf ("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q" , r .Proto )
137+ return http .StatusUpgradeRequired , fmt .Errorf ("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q" , r .Proto )
137138 }
138139
139140 if ! headerContainsToken (r .Header , "Connection" , "Upgrade" ) {
140141 w .Header ().Set ("Connection" , "Upgrade" )
141142 w .Header ().Set ("Upgrade" , "websocket" )
142- return http .StatusUpgradeRequired , xerrors .Errorf ("WebSocket protocol violation: Connection header %q does not contain Upgrade" , r .Header .Get ("Connection" ))
143+ return http .StatusUpgradeRequired , fmt .Errorf ("WebSocket protocol violation: Connection header %q does not contain Upgrade" , r .Header .Get ("Connection" ))
143144 }
144145
145146 if ! headerContainsToken (r .Header , "Upgrade" , "websocket" ) {
146147 w .Header ().Set ("Connection" , "Upgrade" )
147148 w .Header ().Set ("Upgrade" , "websocket" )
148- return http .StatusUpgradeRequired , xerrors .Errorf ("WebSocket protocol violation: Upgrade header %q does not contain websocket" , r .Header .Get ("Upgrade" ))
149+ return http .StatusUpgradeRequired , fmt .Errorf ("WebSocket protocol violation: Upgrade header %q does not contain websocket" , r .Header .Get ("Upgrade" ))
149150 }
150151
151152 if r .Method != "GET" {
152- return http .StatusMethodNotAllowed , xerrors .Errorf ("WebSocket protocol violation: handshake request method is not GET but %q" , r .Method )
153+ return http .StatusMethodNotAllowed , fmt .Errorf ("WebSocket protocol violation: handshake request method is not GET but %q" , r .Method )
153154 }
154155
155156 if r .Header .Get ("Sec-WebSocket-Version" ) != "13" {
156157 w .Header ().Set ("Sec-WebSocket-Version" , "13" )
157- return http .StatusBadRequest , xerrors .Errorf ("unsupported WebSocket protocol version (only 13 is supported): %q" , r .Header .Get ("Sec-WebSocket-Version" ))
158+ return http .StatusBadRequest , fmt .Errorf ("unsupported WebSocket protocol version (only 13 is supported): %q" , r .Header .Get ("Sec-WebSocket-Version" ))
158159 }
159160
160161 if r .Header .Get ("Sec-WebSocket-Key" ) == "" {
161- return http .StatusBadRequest , xerrors .New ("WebSocket protocol violation: missing Sec-WebSocket-Key" )
162+ return http .StatusBadRequest , errors .New ("WebSocket protocol violation: missing Sec-WebSocket-Key" )
162163 }
163164
164165 return 0 , nil
@@ -169,10 +170,10 @@ func authenticateOrigin(r *http.Request) error {
169170 if origin != "" {
170171 u , err := url .Parse (origin )
171172 if err != nil {
172- return xerrors .Errorf ("failed to parse Origin header %q: %w" , origin , err )
173+ return fmt .Errorf ("failed to parse Origin header %q: %w" , origin , err )
173174 }
174175 if ! strings .EqualFold (u .Host , r .Host ) {
175- return xerrors .Errorf ("request Origin %q is not authorized for Host %q" , origin , r .Host )
176+ return fmt .Errorf ("request Origin %q is not authorized for Host %q" , origin , r .Host )
176177 }
177178 }
178179 return nil
@@ -208,6 +209,7 @@ func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionM
208209
209210func acceptDeflate (w http.ResponseWriter , ext websocketExtension , mode CompressionMode ) (* compressionOptions , error ) {
210211 copts := mode .opts ()
212+ copts .serverMaxWindowBits = 8
211213
212214 for _ , p := range ext .params {
213215 switch p {
@@ -219,11 +221,31 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
219221 continue
220222 }
221223
222- if strings .HasPrefix (p , "client_max_window_bits" ) || strings .HasPrefix (p , "server_max_window_bits" ) {
224+ if strings .HasPrefix (p , "client_max_window_bits" ) {
225+ continue
226+
227+ // bits, ok := parseExtensionParameter(p, 15)
228+ // if !ok || bits < 8 || bits > 16 {
229+ // err := fmt.Errorf("invalid client_max_window_bits: %q", p)
230+ // http.Error(w, err.Error(), http.StatusBadRequest)
231+ // return nil, err
232+ // }
233+ // copts.clientMaxWindowBits = bits
234+ // continue
235+ }
236+
237+ if false && strings .HasPrefix (p , "server_max_window_bits" ) {
238+ // We always send back 8 but make sure to validate.
239+ bits , ok := parseExtensionParameter (p , 0 )
240+ if ! ok || bits < 8 || bits > 16 {
241+ err := fmt .Errorf ("invalid server_max_window_bits: %q" , p )
242+ http .Error (w , err .Error (), http .StatusBadRequest )
243+ return nil , err
244+ }
223245 continue
224246 }
225247
226- err := xerrors .Errorf ("unsupported permessage-deflate parameter: %q" , p )
248+ err := fmt .Errorf ("unsupported permessage-deflate parameter: %q" , p )
227249 http .Error (w , err .Error (), http .StatusBadRequest )
228250 return nil , err
229251 }
@@ -233,6 +255,21 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
233255 return copts , nil
234256}
235257
258+ // parseExtensionParameter parses the value in the extension parameter p.
259+ // It falls back to defaultVal if there is no value.
260+ // If defaultVal == 0, then ok == false if there is no value.
261+ func parseExtensionParameter (p string , defaultVal int ) (int , bool ) {
262+ ps := strings .Split (p , "=" )
263+ if len (ps ) == 1 {
264+ if defaultVal > 0 {
265+ return defaultVal , true
266+ }
267+ return 0 , false
268+ }
269+ i , e := strconv .Atoi (strings .Trim (ps [1 ], `"` ))
270+ return i , e == nil
271+ }
272+
236273func acceptWebkitDeflate (w http.ResponseWriter , ext websocketExtension , mode CompressionMode ) (* compressionOptions , error ) {
237274 copts := mode .opts ()
238275 // The peer must explicitly request it.
@@ -253,7 +290,7 @@ func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode Com
253290 //
254291 // Either way, we're only implementing this for webkit which never sends the max_window_bits
255292 // parameter so we don't need to worry about it.
256- err := xerrors .Errorf ("unsupported x-webkit-deflate-frame parameter: %q" , p )
293+ err := fmt .Errorf ("unsupported x-webkit-deflate-frame parameter: %q" , p )
257294 http .Error (w , err .Error (), http .StatusBadRequest )
258295 return nil , err
259296 }
0 commit comments