Skip to content

Commit 99150cc

Browse files
committed
feat: support middleware
1 parent ac7bd4d commit 99150cc

File tree

1 file changed

+153
-0
lines changed

1 file changed

+153
-0
lines changed

lightbug_http/middleware.mojo

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
struct Context:
2+
var request: Request
3+
var params: Dict[String, AnyType]
4+
5+
func __init__(request: Request):
6+
self.request = request
7+
self.params = Dict[String, AnyType]()
8+
9+
trait Middleware:
10+
var next: Middleware
11+
12+
func call(context: Context) -> Response:
13+
...
14+
15+
struct ErrorMiddleware(Middleware):
16+
func call(context: Context) -> Response:
17+
do:
18+
return next.call(context: context)
19+
catch e: Exception:
20+
return InternalServerError()
21+
22+
struct LoggerMiddleware(Middleware):
23+
func call(context: Context) -> Response:
24+
print("Request: \(context.request)")
25+
return next.call(context: context)
26+
27+
struct RouterMiddleware(Middleware):
28+
var routes: Dict[String, Middleware]
29+
30+
func __init__():
31+
self.routes = Dict[String, Middleware]()
32+
33+
func add(route: String, middleware: Middleware):
34+
routes[route] = middleware
35+
36+
func call(context: Context) -> Response:
37+
# TODO: create a more advanced router
38+
39+
var route = context.request.path
40+
if middleware = routes[route]:
41+
return middleware.call(context: context)
42+
else:
43+
return NotFound()
44+
45+
struct StaticMiddleware(Middleware):
46+
var path: String
47+
48+
funct __init__(path: String):
49+
self.path = path
50+
51+
func call(context: Context) -> Response:
52+
var file = File(path: path + context.request.path)
53+
if file.exists:
54+
return FileResponse(file: file)
55+
else:
56+
return next.call(context: context)
57+
58+
struct CorsMiddleware(Middleware):
59+
func call(context: Context) -> Response:
60+
var response = next.call(context: context)
61+
response.headers["Access-Control-Allow-Origin"] = "*"
62+
return response
63+
64+
struct CompressionMiddleware(Middleware):
65+
func call(context: Context) -> Response:
66+
var response = next.call(context: context)
67+
response.body = compress(response.body)
68+
return response
69+
70+
struct SessionMiddleware(Middleware):
71+
var session: Session
72+
73+
func call(context: Context) -> Response:
74+
var request = context.request
75+
var response = context.response
76+
var session = session.load(request)
77+
context.params["session"] = session
78+
response = next.call(context: context)
79+
session.save(response)
80+
return response
81+
82+
struct BasicAuthMiddleware(Middleware):
83+
var username: String
84+
var password: String
85+
86+
func __init__(username: String, password: String):
87+
self.username = username
88+
self.password = password
89+
90+
func call(context: Context) -> Response:
91+
var request = context.request
92+
var auth = request.headers["Authorization"]
93+
if auth == "Basic \(username):\(password)":
94+
return next.call(context: context)
95+
else:
96+
return Unauthorized()
97+
98+
struct MiddlewareChain:
99+
var middlewares: Array[Middleware]
100+
101+
func __init__():
102+
self.middlewares = Array[Middleware]()
103+
104+
func add(middleware: Middleware):
105+
if middlewares.count == 0:
106+
middlewares.append(middleware)
107+
else:
108+
var last = middlewares[middlewares.count - 1]
109+
last.next = middleware
110+
middlewares.append(middleware)
111+
112+
func execute(request: Request) -> Response:
113+
var context = Context(request: request, response: response)
114+
if middlewares.count > 0:
115+
return middlewares[0].call(context: context)
116+
else:
117+
return NotFound()
118+
119+
fn OK(body: Bytes) -> HTTPResponse:
120+
return OK(body, String("text/plain"))
121+
122+
fn OK(body: Bytes, content_type: String) -> HTTPResponse:
123+
return HTTPResponse(
124+
ResponseHeader(True, 200, String("OK").as_bytes(), content_type.as_bytes()),
125+
body,
126+
)
127+
128+
fn NotFound(body: Bytes) -> HTTPResponse:
129+
return NotFoundResponse(body, String("text/plain"))
130+
131+
fn NotFound(body: Bytes, content_type: String) -> HTTPResponse:
132+
return HTTPResponse(
133+
ResponseHeader(True, 404, String("Not Found").as_bytes(), content_type.as_bytes()),
134+
body,
135+
)
136+
137+
fn InternalServerError(body: Bytes) -> HTTPResponse:
138+
return InternalServerErrorResponse(body, String("text/plain"))
139+
140+
fn InternalServerError(body: Bytes, content_type: String) -> HTTPResponse:
141+
return HTTPResponse(
142+
ResponseHeader(True, 500, String("Internal Server Error").as_bytes(), content_type.as_bytes()),
143+
body,
144+
)
145+
146+
fn Unauthorized(body: Bytes) -> HTTPResponse:
147+
return UnauthorizedResponse(body, String("text/plain"))
148+
149+
fn Unauthorized(body: Bytes, content_type: String) -> HTTPResponse:
150+
return HTTPResponse(
151+
ResponseHeader(True, 401, String("Unauthorized").as_bytes(), content_type.as_bytes()),
152+
body,
153+
)

0 commit comments

Comments
 (0)