Skip to content

Commit f01812d

Browse files
committed
Refactor redirect handling logic and add unit tests
1 parent c77c52d commit f01812d

File tree

2 files changed

+70
-59
lines changed

2 files changed

+70
-59
lines changed

redirecthandler/redirecthandler.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ func (r *RedirectHandler) WithRedirectHandling(client *http.Client) {
4747

4848
// checkRedirect implements the redirect handling logic.
4949
func (r *RedirectHandler) checkRedirect(req *http.Request, via []*http.Request) error {
50-
defer r.clearRedirectHistory(req) // Ensure redirect history is always cleared to prevent memory leaks
50+
51+
// Ensure redirect history is always cleared to prevent memory leaks
52+
defer r.clearRedirectHistory(req)
5153

5254
// Non-idempotent methods handling
5355
if req.Method == http.MethodPost || req.Method == http.MethodPatch {

redirecthandler/redirecthandler_test.go

Lines changed: 67 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -14,70 +14,79 @@ import (
1414
// It covers various scenarios including redirect loop detection, maximum redirects limit,
1515
// resolving relative redirects, cross-domain security measures, and handling of 303 See Other response.
1616
func TestRedirectHandler_CheckRedirect(t *testing.T) {
17-
mockLogger := mocklogger.NewMockLogger()
18-
19-
// Set the mock logger to capture logs at all levels
20-
mockLogger.SetLevel(logger.LogLevelDebug)
21-
22-
redirectHandler := NewRedirectHandler(mockLogger, 10)
17+
redirectHandler := NewRedirectHandler(nil, 10) // Logger is not needed for these tests
2318

2419
reqURL, _ := url.Parse("http://example.com")
2520
req := &http.Request{URL: reqURL, Method: http.MethodPost}
26-
resp := &http.Response{
27-
Status: "303 See Other",
28-
StatusCode: http.StatusSeeOther,
29-
Header: http.Header{"Location": []string{"http://example.com/new"}},
30-
}
31-
32-
t.Run("Redirect Loop Detection", func(t *testing.T) {
33-
redirectHandler.VisitedURLs = map[string]int{"http://example.com": 1}
34-
err := redirectHandler.checkRedirect(req, []*http.Request{{}, {}})
35-
assert.Equal(t, http.ErrUseLastResponse, err)
36-
// Verify that a warning log for redirect loop was recorded
37-
assert.Contains(t, mockLogger.Calls[0].Arguments.String(0), "Detected redirect loop")
38-
})
39-
40-
t.Run("Maximum Redirects Reached", func(t *testing.T) {
41-
redirectHandler.VisitedURLs = map[string]int{}
42-
redirectHandler.MaxRedirects = 1
43-
err := redirectHandler.checkRedirect(req, []*http.Request{{}, {}})
44-
assert.Equal(t, http.ErrUseLastResponse, err)
45-
// Verify that a warning log for max redirects was recorded
46-
assert.Contains(t, mockLogger.Calls[1].Arguments.String(0), "Stopped after maximum redirects")
47-
})
4821

49-
t.Run("Resolve Relative Redirects", func(t *testing.T) {
50-
redirectHandler.MaxRedirects = 10
51-
err := redirectHandler.checkRedirect(req, []*http.Request{{}, {}})
52-
assert.Nil(t, err)
53-
assert.Equal(t, "http://example.com/new", req.URL.String())
54-
})
22+
// Test cases
23+
tests := []struct {
24+
name string
25+
prepare func() *http.Response // Function to prepare the response for each test case
26+
expectedErr error
27+
expectedURL string
28+
}{
29+
{
30+
name: "Redirect Loop Detection",
31+
prepare: func() *http.Response {
32+
redirectHandler.VisitedURLs = map[string]int{"http://example.com": 1}
33+
return nil
34+
},
35+
expectedErr: http.ErrUseLastResponse,
36+
},
37+
{
38+
name: "Maximum Redirects Reached",
39+
prepare: func() *http.Response {
40+
redirectHandler.VisitedURLs = map[string]int{}
41+
redirectHandler.MaxRedirects = 1
42+
return nil
43+
},
44+
expectedErr: http.ErrUseLastResponse,
45+
},
46+
{
47+
name: "Resolve Relative Redirects",
48+
prepare: func() *http.Response {
49+
redirectHandler.MaxRedirects = 10
50+
return &http.Response{
51+
StatusCode: http.StatusSeeOther,
52+
Header: http.Header{"Location": []string{"http://example.com/new"}},
53+
}
54+
},
55+
expectedURL: "http://example.com/new",
56+
},
57+
{
58+
name: "Cross-Domain Security Measures",
59+
prepare: func() *http.Response {
60+
return &http.Response{
61+
Header: http.Header{"Location": []string{"http://anotherdomain.com/new"}},
62+
}
63+
},
64+
expectedErr: nil,
65+
},
66+
{
67+
name: "Handling 303 See Other",
68+
prepare: func() *http.Response {
69+
return &http.Response{
70+
StatusCode: http.StatusSeeOther,
71+
Header: http.Header{"Location": []string{"http://example.com/new"}},
72+
}
73+
},
74+
expectedErr: nil,
75+
expectedURL: "http://example.com/new",
76+
},
77+
}
5578

56-
t.Run("Cross-Domain Security Measures", func(t *testing.T) {
57-
reqURL, _ = url.Parse("http://example.com")
58-
req = &http.Request{URL: reqURL, Method: http.MethodPost}
59-
resp.Header.Set("Location", "http://anotherdomain.com/new")
60-
err := redirectHandler.checkRedirect(req, []*http.Request{{}, {}})
61-
assert.Nil(t, err)
62-
// Ensure sensitive headers are removed and corresponding log is recorded
63-
assert.Empty(t, req.Header.Get("Authorization"))
64-
assert.Contains(t, mockLogger.Calls[2].Arguments.String(0), "Removed sensitive header")
65-
})
79+
for _, tc := range tests {
80+
t.Run(tc.name, func(t *testing.T) {
81+
resp := tc.prepare()
82+
err := redirectHandler.checkRedirect(req, []*http.Request{{Response: resp}, {}})
6683

67-
t.Run("Handling 303 See Other", func(t *testing.T) {
68-
reqURL, _ = url.Parse("http://example.com")
69-
req = &http.Request{URL: reqURL, Method: http.MethodPost}
70-
resp.Header.Set("Location", "http://example.com/new")
71-
err := redirectHandler.checkRedirect(req, []*http.Request{{}, {}})
72-
assert.Nil(t, err)
73-
assert.Equal(t, http.MethodGet, req.Method)
74-
// Ensure no body, no GetBody, correct ContentLength, no Content-Type header, and a log is recorded
75-
assert.Nil(t, req.Body)
76-
assert.Nil(t, req.GetBody)
77-
assert.Equal(t, int64(0), req.ContentLength)
78-
assert.Empty(t, req.Header.Get("Content-Type"))
79-
assert.Contains(t, mockLogger.Calls[3].Arguments.String(0), "Changed request method to GET")
80-
})
84+
assert.Equal(t, tc.expectedErr, err)
85+
if tc.expectedURL != "" {
86+
assert.Equal(t, tc.expectedURL, req.URL.String())
87+
}
88+
})
89+
}
8190
}
8291

8392
// TestRedirectHandler_ResolveRedirectURL tests the resolveRedirectURL method of the RedirectHandler.

0 commit comments

Comments
 (0)