diff --git a/client.go b/client.go index 00917ea3..ff5b3d48 100644 --- a/client.go +++ b/client.go @@ -125,6 +125,11 @@ type Dialer struct { // If Jar is nil, cookies are not sent in requests and ignored // in responses. Jar http.CookieJar + + // MaxErrorBodySize specifies the maximum size of the error response buffer + // in the case of a bad hanshake. If zero, a defaut max size buffer of 1024 bytes + // is used. + MaxErrorBodySize int } // Dial creates a new client connection by calling DialContext with a background context. @@ -364,9 +369,13 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h // Before closing the network connection on return from this // function, slurp up some of the response to aid application // debugging. - buf := make([]byte, 1024) - n, _ := io.ReadFull(resp.Body, buf) - resp.Body = io.NopCloser(bytes.NewReader(buf[:n])) + bufSize := 1024 + if d.MaxErrorBodySize > 0 { + bufSize = d.MaxErrorBodySize + } + limReader := io.LimitReader(resp.Body, int64(bufSize)) + buf, _ := io.ReadAll(limReader) + resp.Body = io.NopCloser(bytes.NewReader(buf)) return nil, resp, ErrBadHandshake } diff --git a/client_server_test.go b/client_server_test.go index e4546aea..d00ca0e4 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -575,34 +575,49 @@ func TestHandshake(t *testing.T) { func TestRespOnBadHandshake(t *testing.T) { const expectedStatus = http.StatusGone const expectedBody = "This is the response body." + const maxErrorResponseSize = 4096 + tests := []struct { + name string + body []byte + lenMax int + }{ + {"SmallerThanLimit", []byte(expectedBody), 1024}, // default value when MaxErrorBodySize is not set + {"LargerThanLimit", make([]byte, maxErrorResponseSize+100), maxErrorResponseSize}, + {"ExactLimit", make([]byte, maxErrorResponseSize), maxErrorResponseSize}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(expectedStatus) + _, _ = w.Write(tt.body) + })) + defer s.Close() + + ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil) + if err == nil { + ws.Close() + t.Fatalf("Dial: nil") + } - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(expectedStatus) - _, _ = io.WriteString(w, expectedBody) - })) - defer s.Close() - - ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil) - if err == nil { - ws.Close() - t.Fatalf("Dial: nil") - } + if resp == nil { + t.Fatalf("resp=nil, err=%v", err) + } - if resp == nil { - t.Fatalf("resp=nil, err=%v", err) - } + if resp.StatusCode != expectedStatus { + t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) + } - if resp.StatusCode != expectedStatus { - t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) - } + p, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ReadFull(resp.Body) returned error %v", err) + } - p, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("ReadFull(resp.Body) returned error %v", err) - } + if len(p) > tt.lenMax { + t.Fatalf("body size=%d, want <= %d", len(p), tt.lenMax) + } + }) - if string(p) != expectedBody { - t.Errorf("resp.Body=%s, want %s", p, expectedBody) } }