From ceff91b96bfb8025ee71b746290c315a977d47a8 Mon Sep 17 00:00:00 2001 From: Benoit <90827157+Yurhigz@users.noreply.github.com> Date: Wed, 15 Oct 2025 23:01:04 +0200 Subject: [PATCH 1/4] Limit error response size in WebSocket handshake Implement a limit on the error response size for WebSocket handshakes. Signed-off-by: Benoit <90827157+Yurhigz@users.noreply.github.com> --- client.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 00917ea3..f35f46d2 100644 --- a/client.go +++ b/client.go @@ -172,6 +172,7 @@ var nilDialer = *DefaultDialer // non-nil *http.Response so that callers can handle redirects, authentication, // etcetera. The response body may not contain the entire response and does not // need to be closed by the application. +var maxErrorResponseSize = 4096 func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { if d == nil { d = &nilDialer @@ -364,9 +365,14 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h // Before closing the network connection on return from this // function, slurp up some of the response to aid application // debugging. - buf := make([]byte, 1024) - n, _ := io.ReadFull(resp.Body, buf) - resp.Body = io.NopCloser(bytes.NewReader(buf[:n])) + + // Mon implémentation avec une maxErrorResponseSize + limReader := io.LimitReader(resp.Body, int64(maxErrorResponseSize)) + buf, err := io.ReadAll(limReader) + if err != nil && err != io.EOF { + buf = []byte{} + } + resp.Body = io.NopCloser(bytes.NewReader(buf)) return nil, resp, ErrBadHandshake } From ffd1fbe8a4b621ef619d41e9ba6a258070fbaf05 Mon Sep 17 00:00:00 2001 From: "Benoit.G" Date: Sat, 18 Oct 2025 11:47:36 +0200 Subject: [PATCH 2/4] Enhancement: Patch to increase err resp size --- client.go | 2 +- client_server_test.go | 72 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index f35f46d2..c584de9a 100644 --- a/client.go +++ b/client.go @@ -173,6 +173,7 @@ var nilDialer = *DefaultDialer // etcetera. The response body may not contain the entire response and does not // need to be closed by the application. var maxErrorResponseSize = 4096 + func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { if d == nil { d = &nilDialer @@ -366,7 +367,6 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h // function, slurp up some of the response to aid application // debugging. - // Mon implémentation avec une maxErrorResponseSize limReader := io.LimitReader(resp.Body, int64(maxErrorResponseSize)) buf, err := io.ReadAll(limReader) if err != nil && err != io.EOF { diff --git a/client_server_test.go b/client_server_test.go index e4546aea..52c4f765 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -573,8 +573,10 @@ func TestHandshake(t *testing.T) { } func TestRespOnBadHandshake(t *testing.T) { + // Test Body smaller than maxErrorResponseSize. const expectedStatus = http.StatusGone const expectedBody = "This is the response body." + const maxErrorResponseSize = 4096 s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(expectedStatus) @@ -604,6 +606,76 @@ func TestRespOnBadHandshake(t *testing.T) { if string(p) != expectedBody { t.Errorf("resp.Body=%s, want %s", p, expectedBody) } + + // Test Body larger than maxErrorResponseSize. + t.Run("ErrorResponseSizeLimited", func(t *testing.T) { + largeBody := make([]byte, maxErrorResponseSize+100) + for i := range largeBody { + largeBody[i] = 'a' + } + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write(largeBody) + })) + defer s.Close() + + ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil) + if err == nil { + ws.Close() + t.Fatalf("Dial: expected error, got nil") + } + + if resp == nil { + t.Fatalf("resp=nil, err=%v", err) + } + + p, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ReadAll(resp.Body) returned error %v", err) + } + + resp.Body.Close() + + if len(p) > maxErrorResponseSize { + t.Fatalf("body size=%d, want <= %d", len(p), maxErrorResponseSize) + } + }) + + // Test Body exactly maxErrorResponseSize. + t.Run("ErrorResponseSizeExactLimit", func(t *testing.T) { + limitedBody := make([]byte, maxErrorResponseSize) + for i := range limitedBody { + limitedBody[i] = 'a' + } + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) + w.Write(limitedBody) + })) + defer s.Close() + + ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil) + if err == nil { + ws.Close() + t.Fatalf("Dial: expected error, got nil") + } + + if resp == nil { + t.Fatalf("resp=nil, err=%v", err) + } + + p, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ReadAll(resp.Body) returned error %v", err) + } + + resp.Body.Close() + + if len(p) != maxErrorResponseSize { + t.Fatalf("body size=%d, want %d", len(p), maxErrorResponseSize) + } + }) } type testLogWriter struct { From df1e3272d15254ecb496d9bca7c68d10f059bb34 Mon Sep 17 00:00:00 2001 From: "Benoit.G" Date: Sat, 25 Oct 2025 23:38:53 +0200 Subject: [PATCH 3/4] Hotfix: MaxErrorBodySize Struct field --- client.go | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/client.go b/client.go index c584de9a..ff5b3d48 100644 --- a/client.go +++ b/client.go @@ -125,6 +125,11 @@ type Dialer struct { // If Jar is nil, cookies are not sent in requests and ignored // in responses. Jar http.CookieJar + + // MaxErrorBodySize specifies the maximum size of the error response buffer + // in the case of a bad hanshake. If zero, a defaut max size buffer of 1024 bytes + // is used. + MaxErrorBodySize int } // Dial creates a new client connection by calling DialContext with a background context. @@ -172,8 +177,6 @@ var nilDialer = *DefaultDialer // non-nil *http.Response so that callers can handle redirects, authentication, // etcetera. The response body may not contain the entire response and does not // need to be closed by the application. -var maxErrorResponseSize = 4096 - func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { if d == nil { d = &nilDialer @@ -366,12 +369,12 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h // Before closing the network connection on return from this // function, slurp up some of the response to aid application // debugging. - - limReader := io.LimitReader(resp.Body, int64(maxErrorResponseSize)) - buf, err := io.ReadAll(limReader) - if err != nil && err != io.EOF { - buf = []byte{} + bufSize := 1024 + if d.MaxErrorBodySize > 0 { + bufSize = d.MaxErrorBodySize } + limReader := io.LimitReader(resp.Body, int64(bufSize)) + buf, _ := io.ReadAll(limReader) resp.Body = io.NopCloser(bytes.NewReader(buf)) return nil, resp, ErrBadHandshake } From 6bf72129056f3607204c2d3642cadf42fccb4960 Mon Sep 17 00:00:00 2001 From: "Benoit.G" Date: Sun, 26 Oct 2025 00:16:51 +0200 Subject: [PATCH 4/4] Tests adapted and refactored --- client_server_test.go | 131 ++++++++++++------------------------------ 1 file changed, 37 insertions(+), 94 deletions(-) diff --git a/client_server_test.go b/client_server_test.go index 52c4f765..d00ca0e4 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -573,109 +573,52 @@ func TestHandshake(t *testing.T) { } func TestRespOnBadHandshake(t *testing.T) { - // Test Body smaller than maxErrorResponseSize. const expectedStatus = http.StatusGone const expectedBody = "This is the response body." const maxErrorResponseSize = 4096 + tests := []struct { + name string + body []byte + lenMax int + }{ + {"SmallerThanLimit", []byte(expectedBody), 1024}, // default value when MaxErrorBodySize is not set + {"LargerThanLimit", make([]byte, maxErrorResponseSize+100), maxErrorResponseSize}, + {"ExactLimit", make([]byte, maxErrorResponseSize), maxErrorResponseSize}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(expectedStatus) + _, _ = w.Write(tt.body) + })) + defer s.Close() + + ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil) + if err == nil { + ws.Close() + t.Fatalf("Dial: nil") + } - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(expectedStatus) - _, _ = io.WriteString(w, expectedBody) - })) - defer s.Close() - - ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil) - if err == nil { - ws.Close() - t.Fatalf("Dial: nil") - } + if resp == nil { + t.Fatalf("resp=nil, err=%v", err) + } - if resp == nil { - t.Fatalf("resp=nil, err=%v", err) - } + if resp.StatusCode != expectedStatus { + t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) + } - if resp.StatusCode != expectedStatus { - t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) - } + p, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ReadFull(resp.Body) returned error %v", err) + } - p, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("ReadFull(resp.Body) returned error %v", err) - } + if len(p) > tt.lenMax { + t.Fatalf("body size=%d, want <= %d", len(p), tt.lenMax) + } + }) - if string(p) != expectedBody { - t.Errorf("resp.Body=%s, want %s", p, expectedBody) } - - // Test Body larger than maxErrorResponseSize. - t.Run("ErrorResponseSizeLimited", func(t *testing.T) { - largeBody := make([]byte, maxErrorResponseSize+100) - for i := range largeBody { - largeBody[i] = 'a' - } - - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - w.Write(largeBody) - })) - defer s.Close() - - ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil) - if err == nil { - ws.Close() - t.Fatalf("Dial: expected error, got nil") - } - - if resp == nil { - t.Fatalf("resp=nil, err=%v", err) - } - - p, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("ReadAll(resp.Body) returned error %v", err) - } - - resp.Body.Close() - - if len(p) > maxErrorResponseSize { - t.Fatalf("body size=%d, want <= %d", len(p), maxErrorResponseSize) - } - }) - - // Test Body exactly maxErrorResponseSize. - t.Run("ErrorResponseSizeExactLimit", func(t *testing.T) { - limitedBody := make([]byte, maxErrorResponseSize) - for i := range limitedBody { - limitedBody[i] = 'a' - } - - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadGateway) - w.Write(limitedBody) - })) - defer s.Close() - - ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil) - if err == nil { - ws.Close() - t.Fatalf("Dial: expected error, got nil") - } - - if resp == nil { - t.Fatalf("resp=nil, err=%v", err) - } - - p, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("ReadAll(resp.Body) returned error %v", err) - } - - resp.Body.Close() - - if len(p) != maxErrorResponseSize { - t.Fatalf("body size=%d, want %d", len(p), maxErrorResponseSize) - } - }) } type testLogWriter struct {