Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,32 @@ func Test_BaseController(t *testing.T) {
type testCase struct {
title string
method string
mws []Middleware
mws []func(w http.ResponseWriter) func (http.Handler) http.Handler
out string
}

cases := []testCase{
{
title: "register middleware for HTTP method",
method: http.MethodGet,
mws: []Middleware{middlewareOne},
out: "/mw1 before next/final handler/mw1 after next",
mws: []func(w http.ResponseWriter) func(http.Handler) http.Handler{
middlewareOne,
},
out: "/mw1 prepare/mw1 before next/final handler/mw1 after next",
},
{
title: "add middleware to existing chain",
method: http.MethodGet,
mws: []Middleware{middlewareTwo, middlewareThree},
out: "/mw1 before next/mw2 before next/mw3 before next/final handler/mw3 after next/mw2 after next/mw1 after next",
mws: []func(w http.ResponseWriter) func(http.Handler) http.Handler{
middlewareTwo, middlewareThree,
},
out: "/mw2 prepare/mw3 prepare/mw1 before next/mw2 before next/mw3 before next/final handler" +
"/mw3 after next/mw2 after next/mw1 after next",
},
{
title: "get an empty middleware chain (by default)",
method: http.MethodPost,
mws: []func(w http.ResponseWriter) func(http.Handler) http.Handler{},
out: "/final handler",
},
}
Expand All @@ -40,10 +46,14 @@ func Test_BaseController(t *testing.T) {
controller := NewBaseController()
for _, tc := range cases {
t.Run(tc.title, func(t *testing.T) {
w := httptest.NewRecorder()
if len(tc.mws) > 0 {
controller.AddMiddleware(tc.method, tc.mws...)
var mws []Middleware
for _, mw := range tc.mws {
mws = append(mws, mw(w))
}
controller.AddMiddleware(tc.method, mws...)
}
w := httptest.NewRecorder()
controller.Middleware(tc.method).Then(handlerFinal).ServeHTTP(w, nil)
if w.Body.String() != tc.out {
t.Errorf("handler output is expected to be %q but was %q", tc.out, w.Body.String())
Expand Down
7 changes: 6 additions & 1 deletion middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@ func (mw Middleware) Use(middlewares ...Middleware) Middleware {
for _, next := range middlewares {
mw = func(curr, next Middleware) Middleware {
return func(handler http.Handler) http.Handler {
return curr(next(handler))
var nextHandler http.Handler
currHandler := curr(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextHandler.ServeHTTP(w, r)
}))
nextHandler = next(handler)
return currHandler
}
}(mw, next)
}
Expand Down
149 changes: 91 additions & 58 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,37 @@ import (
)

var (
middlewareOne = func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("/mw1 before next"))
next.ServeHTTP(w, r)
w.Write([]byte("/mw1 after next"))
})
middlewareOne = func(w http.ResponseWriter) func (http.Handler) http.Handler {
w.Write([]byte("/mw1 prepare"))
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("/mw1 before next"))
next.ServeHTTP(w, r)
w.Write([]byte("/mw1 after next"))
})
}
}

middlewareTwo = func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("/mw2 before next"))
next.ServeHTTP(w, r)
w.Write([]byte("/mw2 after next"))
})
middlewareTwo = func(w http.ResponseWriter) func (http.Handler) http.Handler {
w.Write([]byte("/mw2 prepare"))
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("/mw2 before next"))
next.ServeHTTP(w, r)
w.Write([]byte("/mw2 after next"))
})
}
}

middlewareThree = func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("/mw3 before next"))
next.ServeHTTP(w, r)
w.Write([]byte("/mw3 after next"))
})
middlewareThree = func(w http.ResponseWriter) func (http.Handler) http.Handler {
w.Write([]byte("/mw3 prepare"))
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("/mw3 before next"))
next.ServeHTTP(w, r)
w.Write([]byte("/mw3 after next"))
})
}
}

middlewareFuncOne = func(w http.ResponseWriter, r *http.Request, next http.Handler) {
Expand All @@ -43,10 +52,13 @@ var (
w.Write([]byte("/mw func2 after next"))
}

middlewareBreak Middleware = func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("/skip the rest"))
})
middlewareBreak = func(w http.ResponseWriter) Middleware {
w.Write([]byte("/skip the rest prepare"))
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("/skip the rest"))
})
}
}

handlerOne = func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -65,37 +77,48 @@ var (
func Test_Middleware(t *testing.T) {
type testCase struct {
title string
handler http.Handler
handler func (w http.ResponseWriter) http.Handler
out string
}

cases := []testCase{
{
title: "build handler with single middleware (one call of Use() func with single argument)",
handler: New().Use(middlewareOne).Then(handlerFinal),
out: "/mw1 before next/final handler/mw1 after next",
handler: func(w http.ResponseWriter) http.Handler {
return New().Use(middlewareOne(w)).Then(handlerFinal)
},
out: "/mw1 prepare/mw1 before next/final handler/mw1 after next",
},
{
title: "build handler passing middleware to the constructor (call New() with arguments)",
handler: New(middlewareOne, middlewareTwo).Use(middlewareThree).Then(handlerFinal),
out: "/mw1 before next/mw2 before next/mw3 before next/final handler/mw3 after next/mw2 after next/mw1 after next",
handler: func(w http.ResponseWriter) http.Handler {
return New(middlewareOne(w), middlewareTwo(w)).Use(middlewareThree(w)).Then(handlerFinal)
},
out: "/mw1 prepare/mw2 prepare/mw3 prepare/mw1 before next/mw2 before next/mw3 before next" +
"/final handler/mw3 after next/mw2 after next/mw1 after next",
},
{
title: "build handler with multiple middleware (adding one middleware per Use())",
handler: New().Use(middlewareOne).Use(middlewareTwo).Use(middlewareThree).Then(handlerFinal),
out: "/mw1 before next/mw2 before next/mw3 before next/final handler/mw3 after next/mw2 after next/mw1 after next",
handler: func(w http.ResponseWriter) http.Handler {
return New().Use(middlewareOne(w)).Use(middlewareTwo(w)).Use(middlewareThree(w)).Then(handlerFinal)
},
out: "/mw1 prepare/mw2 prepare/mw3 prepare/mw1 before next/mw2 before next/mw3 before next" +
"/final handler/mw3 after next/mw2 after next/mw1 after next",
},
{
title: "build handler with combination of single/plural calls of Use()",
handler: New().Use(middlewareOne).Use(middlewareTwo, middlewareThree).Then(handlerFinal),
out: "/mw1 before next/mw2 before next/mw3 before next/final handler/mw3 after next/mw2 after next/mw1 after next",
handler: func(w http.ResponseWriter) http.Handler {
return New().Use(middlewareOne(w)).Use(middlewareTwo(w), middlewareThree(w)).Then(handlerFinal)
},
out: "/mw1 prepare/mw2 prepare/mw3 prepare/mw1 before next/mw2 before next/mw3 before next" +
"/final handler/mw3 after next/mw2 after next/mw1 after next",
},
}

for _, tc := range cases {
t.Run(tc.title, func(t *testing.T) {
w := httptest.NewRecorder()
tc.handler.ServeHTTP(w, nil)
tc.handler(w).ServeHTTP(w, nil)
if w.Body.String() != tc.out {
t.Errorf("the output %q is expected to be %q", w.Body.String(), tc.out)
}
Expand All @@ -111,52 +134,62 @@ func Test_Chain(t *testing.T) {

type testCase struct {
title string
args []interface{}
args func (http.ResponseWriter) []interface{}
out string
panic bool
}

cases := []testCase{
{
title: "building handler with unsupported argument types should panic",
args: []interface{}{
middlewareOne,
middlewareTwo,
true,
middlewareThree,
handlerFinal,
args: func(w http.ResponseWriter) []interface{} {
return []interface{}{
middlewareOne(w),
middlewareTwo(w),
true,
middlewareThree(w),
handlerFinal,
}
},
panic: true,
},
{
title: "middleware should have control over the \"next\" handlers",
args: []interface{}{
middlewareOne,
middlewareTwo,
middlewareBreak,
middlewareThree,
handlerFinal,
args: func(w http.ResponseWriter) []interface{} {
return []interface{}{
middlewareOne(w),
middlewareTwo(w),
middlewareBreak(w),
middlewareThree(w),
handlerFinal,
}
},
out: "/mw1 before next/mw2 before next/skip the rest/mw2 after next/mw1 after next",
out: "/mw1 prepare/mw2 prepare/skip the rest prepare/mw3 prepare/mw1 before next/mw2 before next" +
"/skip the rest/mw2 after next/mw1 after next",
},
{
title: "calling function without any arguments should build a middleware with only blobHandler",
args: func(w http.ResponseWriter) []interface{} {
return []interface{}{}
},
out: "/blob handler",
},
{
title: "building handler with all kind of supported arguments should be successful",
args: []interface{}{
middlewareOne,
Middleware(middlewareTwo),
middlewareFuncOne,
MiddlewareFunc(middlewareFuncTwo),
handlerOne,
http.HandlerFunc(handlerTwo),
middlewareThree,
handlerFinal,
args: func(w http.ResponseWriter) []interface{} {
return []interface{}{
middlewareOne(w),
Middleware(middlewareTwo(w)),
middlewareFuncOne,
MiddlewareFunc(middlewareFuncTwo),
handlerOne,
http.HandlerFunc(handlerTwo),
middlewareThree(w),
handlerFinal,
}
},
out: "/mw1 before next/mw2 before next/mw func1 before next/mw func2 before next" +
"/first handler/second handler/mw3 before next/final handler/blob handler" +
out: "/mw1 prepare/mw2 prepare/mw3 prepare/mw1 before next/mw2 before next/mw func1 before next" +
"/mw func2 before next/first handler/second handler/mw3 before next/final handler/blob handler" +
"/mw3 after next/mw func2 after next/mw func1 after next/mw2 after next/mw1 after next",
},
}
Expand All @@ -175,7 +208,7 @@ func Test_Chain(t *testing.T) {
}
}()
w := httptest.NewRecorder()
Chain(tc.args...).ServeHTTP(w, nil)
Chain(tc.args(w)...).ServeHTTP(w, nil)
if w.Body.String() != tc.out {
t.Errorf("out %v expected to be %v", w.Body.String(), tc.out)
}
Expand Down