1- from lightbug_http.http import HTTPRequest, HTTPResponse
1+ from lightbug_http.http import *
2+ from lightbug_http.service import HTTPService
23
4+ @value
35struct Context :
4- var request : Request
5- var params : Dict[String, AnyType ]
6+ var request : HTTPRequest
7+ var params : Dict[String, String ]
68
7- fn __init__ (self , request : Request ):
9+ fn __init__ (inout self , request : HTTPRequest ):
810 self .request = request
9- self .params = Dict[String, AnyType ]()
11+ self .params = Dict[String, String ]()
1012
1113trait Middleware :
12- var next : Middleware
13-
14- fn call (self , context : Context) -> Response:
14+ fn call (self , context : Context) -> HTTPResponse:
1515 ...
1616
1717struct ErrorMiddleware (Middleware ):
18- fn call (self , context : Context) -> Response:
18+ var next : Middleware
19+
20+ fn call (inout self , context : Context) -> HTTPResponse:
1921 try :
20- return next .call(context: context )
22+ return next .call(context)
2123 catch e: Exception :
2224 return InternalServerError()
2325
2426struct LoggerMiddleware (Middleware ):
25- fn call (self , context : Context) -> Response:
26- print (" Request: \(context.request)" )
27- return next .call(context: context)
27+ var next : Middleware
28+
29+ fn call (self , context : Context) -> HTTPResponse:
30+ print (f " Request: { context.request} " )
31+ return next .call(context)
2832
2933struct StaticMiddleware (Middleware ):
34+ var next : Middleware
3035 var path : String
3136
32- fnt __init__ (self , path: String):
37+ fn __init__ (self , path : String):
3338 self .path = path
3439
35- fn call (self , context : Context) -> Response :
36- if context.request.path == " /" :
40+ fn call (self , context : Context) -> HTTPResponse :
41+ if context.request.uri(). path() == " /" :
3742 var file = File(path: path + " index.html" )
3843 else :
39- var file = File(path: path + context.request.path)
44+ var file = File(path: path + context.request.uri(). path() )
4045
4146 if file .exists:
4247 var html : String
4348 with open (file , " r" ) as f:
4449 html = f.read()
4550 return OK(html.as_bytes(), " text/html" )
4651 else :
47- return next .call(context: context )
52+ return next .call(context)
4853
4954struct CorsMiddleware (Middleware ):
55+ var next : Middleware
5056 var allow_origin : String
5157
5258 fn __init__ (self , allow_origin : String):
5359 self .allow_origin = allow_origin
5460
55- fn call (self , context : Context) -> Response :
56- if context.request.method == " OPTIONS" :
57- var response = next .call(context: context )
61+ fn call (self , context : Context) -> HTTPResponse :
62+ if context.request.header. method() == " OPTIONS" :
63+ var response = next .call(context)
5864 response.headers[" Access-Control-Allow-Origin" ] = allow_origin
5965 response.headers[" Access-Control-Allow-Methods" ] = " GET, POST, PUT, DELETE, OPTIONS"
6066 response.headers[" Access-Control-Allow-Headers" ] = " Content-Type, Authorization"
6167 return response
6268
6369 if context.request.origin == allow_origin:
64- return next .call(context: context )
70+ return next .call(context)
6571 else :
6672 return Unauthorized()
6773
6874struct CompressionMiddleware (Middleware ):
69- fn call (self , context : Context) -> Response:
70- var response = next .call(context: context)
75+ var next : Middleware
76+ fn call (self , context : Context) -> HTTPResponse:
77+ var response = next .call(context)
7178 response.body = compress(response.body)
7279 return response
7380
7481 fn compress (self , body : Bytes) -> Bytes:
7582 # TODO : implement compression
7683 return body
7784
78-
7985struct RouterMiddleware (Middleware ):
86+ var next : Middleware
8087 var routes : Dict[String, Middleware]
8188
82- fn __init__ (self ):
89+ fn __init__ (inout self ):
8390 self .routes = Dict[String, Middleware]()
8491
8592 fn add (self , method : String, route : String, middleware : Middleware):
86- routes[method + " :" + route] = middleware
93+ self . routes[method + " :" + route] = middleware
8794
88- fn call (self , context : Context) -> Response :
95+ fn call (self , context : Context) -> HTTPResponse :
8996 # TODO : create a more advanced router
90- var method = context.request.method
91- var route = context.request.path
92- if middleware = routes[method + " :" + route]:
93- return middleware.call(context: context)
97+ var method = context.request.header.method()
98+ var route = context.request.uri().path()
99+ var middleware = self .routes.find(method + " :" + route)
100+ if middleware:
101+ return middleware.value().call(context)
94102 else :
95- return next .call(context: context )
103+ return next .call(context)
96104
97105struct BasicAuthMiddleware (Middleware ):
106+ var next : Middleware
98107 var username : String
99108 var password : String
100109
101110 fn __init__ (self , username : String, password : String):
102111 self .username = username
103112 self .password = password
104113
105- fn call (self , context : Context) -> Response :
114+ fn call (self , context : Context) -> HTTPResponse :
106115 var request = context.request
107116 var auth = request.headers[" Authorization" ]
108- if auth == " Basic \( username):\( password) " :
117+ if auth == f " Basic { username} : { password} " :
109118 context.params[" username" ] = username
110- return next .call(context: context )
119+ return next .call(context)
111120 else :
112- return Unauthorized()
121+ return Unauthorized(" Requires Basic Authentication " )
113122
114123# always add at the end of the middleware chain
115124struct NotFoundMiddleware (Middleware ):
116- fn call (self , context : Context) -> Response :
117- return NotFound()
125+ fn call (self , context : Context) -> HTTPResponse :
126+ return NotFound(String( " Not Found " ).as_bytes() )
118127
119- struct MiddlewareChain (HttpService ):
120- var middlewares : Array [Middleware]
128+ struct MiddlewareChain (HTTPService ):
129+ var middlewares : List [Middleware]
121130
122- fn __init__ (self ):
131+ fn __init__ (inout self ):
123132 self .middlewares = Array[Middleware]()
124133
125134 fn add (self , middleware : Middleware):
126- if middlewares.count == 0 :
127- middlewares.append(middleware)
135+ if self . middlewares.count == 0 :
136+ self . middlewares.append(middleware)
128137 else :
129- var last = middlewares[middlewares.count - 1 ]
138+ var last = self . middlewares[middlewares.count - 1 ]
130139 last.next = middleware
131- middlewares.append(middleware)
140+ self . middlewares.append(middleware)
132141
133- fn func (self , request : Request) -> Response:
134- self .add(NotFoundMiddleware())
135- var context = Context(request: request, response: response)
136- return middlewares[0 ].call(context: context)
142+ fn func (self , req : HTTPRequest) raises -> HTTPResponse:
143+ var context = Context(request)
144+ return self .middlewares[0 ].call(context)
137145
138146fn OK (body : Bytes) -> HTTPResponse:
139147 return OK(body, String(" text/plain" ))
@@ -145,7 +153,7 @@ fn OK(body: Bytes, content_type: String) -> HTTPResponse:
145153 )
146154
147155fn NotFound (body : Bytes) -> HTTPResponse:
148- return NotFoundResponse (body, String(" text/plain" ))
156+ return NotFound (body, String(" text/plain" ))
149157
150158fn NotFound (body : Bytes, content_type : String) -> HTTPResponse:
151159 return HTTPResponse(
@@ -166,7 +174,10 @@ fn Unauthorized(body: Bytes) -> HTTPResponse:
166174 return UnauthorizedResponse(body, String(" text/plain" ))
167175
168176fn Unauthorized (body : Bytes, content_type : String) -> HTTPResponse:
177+ var header = ResponseHeader(True , 401 , String(" Unauthorized" ).as_bytes(), content_type.as_bytes())
178+ header.headers[" WWW-Authenticate" ] = " Basic realm=\" Login Required\" "
179+
169180 return HTTPResponse(
170- ResponseHeader( True , 401 , String( " Unauthorized " ).as_bytes(), content_type.as_bytes()) ,
181+ header ,
171182 body,
172183 )
0 commit comments