Skip to content

Commit b93d808

Browse files
committed
fix socket bug, enable __del__ and __enter__, and switch poolmanager to use tuple of host, port, and scheme
1 parent 9a7c8d7 commit b93d808

File tree

3 files changed

+145
-71
lines changed

3 files changed

+145
-71
lines changed

lightbug_http/client.mojo

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,35 @@
11
from collections import Dict
2+
from utils import StringSlice
23
from memory import UnsafePointer
34
from lightbug_http.net import default_buffer_size
45
from lightbug_http.http import HTTPRequest, HTTPResponse, encode
56
from lightbug_http.header import Headers, HeaderKey
67
from lightbug_http.net import create_connection, TCPConnection
78
from lightbug_http.io.bytes import Bytes
89
from lightbug_http.utils import ByteReader, logger
9-
from lightbug_http.pool_manager import PoolManager
10+
from lightbug_http.pool_manager import PoolManager, Scheme, PoolKey
11+
12+
13+
fn parse_host_and_port(source: String, is_tls: Bool) raises -> (String, UInt16):
14+
"""Parses the host and port from a given string.
15+
16+
Args:
17+
source: The host uri to parse.
18+
is_tls: A boolean indicating whether the connection is secure.
19+
20+
Returns:
21+
A tuple containing the host and port.
22+
"""
23+
var port: UInt16
24+
if source.count(":") != 1:
25+
port = 443 if is_tls else 80
26+
return source, port
27+
28+
var host: String
29+
var reader = ByteReader(source.as_bytes())
30+
host = StringSlice(unsafe_from_utf8=reader.read_until(ord(":")))
31+
port = atol(StringSlice(unsafe_from_utf8=reader.read_bytes()[1:]))
32+
return host^, port
1033

1134

1235
struct Client:
@@ -30,7 +53,7 @@ struct Client:
3053
self.allow_redirects = allow_redirects
3154
self._connections = PoolManager[TCPConnection](cached_connections)
3255

33-
fn do(mut self, owned req: HTTPRequest) raises -> HTTPResponse:
56+
fn do(mut self, owned request: HTTPRequest) raises -> HTTPResponse:
3457
"""The `do` method is responsible for sending an HTTP request to a server and receiving the corresponding response.
3558
3659
It performs the following steps:
@@ -43,81 +66,67 @@ struct Client:
4366
Note: The code assumes that the `HTTPRequest` object passed as an argument has a valid URI with a host and port specified.
4467
4568
Args:
46-
req: An `HTTPRequest` object representing the request to be sent.
69+
request: An `HTTPRequest` object representing the request to be sent.
4770
4871
Returns:
4972
The received response.
5073
5174
Raises:
5275
Error: If there is a failure in sending or receiving the message.
5376
"""
54-
if req.uri.host == "":
77+
if request.uri.host == "":
5578
raise Error("Client.do: Request failed because the host field is empty.")
56-
var is_tls = False
5779

58-
if req.uri.is_https():
80+
var is_tls = False
81+
var scheme = Scheme.HTTP
82+
if request.uri.is_https():
5983
is_tls = True
84+
scheme = Scheme.HTTPS
6085

61-
var host_str: String
62-
var port: Int
63-
if ":" in req.uri.host:
64-
var host_port: List[String]
65-
try:
66-
host_port = req.uri.host.split(":")
67-
except:
68-
raise Error("Client.do: Failed to split host and port.")
69-
host_str = host_port[0]
70-
port = atol(host_port[1])
71-
else:
72-
host_str = req.uri.host
73-
if is_tls:
74-
port = 443
75-
else:
76-
port = 80
77-
86+
host, port = parse_host_and_port(request.uri.host, is_tls)
87+
var pool_key = PoolKey(host, port, scheme)
7888
var cached_connection = False
7989
var conn: TCPConnection
8090
try:
81-
conn = self._connections.take(host_str)
91+
conn = self._connections.take(pool_key)
8292
cached_connection = True
8393
except e:
8494
if str(e) == "PoolManager.take: Key not found.":
85-
conn = create_connection(host_str, port)
95+
conn = create_connection(host, port)
8696
else:
8797
logger.error(e)
8898
raise Error("Client.do: Failed to create a connection to host.")
8999

90100
var bytes_sent: Int
91101
try:
92-
bytes_sent = conn.write(encode(req))
102+
bytes_sent = conn.write(encode(request))
93103
except e:
94104
# Maybe peer reset ungracefully, so try a fresh connection
95105
if str(e) == "SendError: Connection reset by peer.":
96106
logger.debug("Client.do: Connection reset by peer. Trying a fresh connection.")
97107
conn.teardown()
98108
if cached_connection:
99-
return self.do(req^)
109+
return self.do(request^)
100110
logger.error("Client.do: Failed to send message.")
101111
raise e
102112

103113
# TODO: What if the response is too large for the buffer? We should read until the end of the response. (@thatstoasty)
104114
var new_buf = Bytes(capacity=default_buffer_size)
105-
106115
try:
107116
_ = conn.read(new_buf)
108117
except e:
109118
if str(e) == "EOF":
110119
conn.teardown()
111120
if cached_connection:
112-
return self.do(req^)
121+
return self.do(request^)
113122
raise Error("Client.do: No response received from the server.")
114123
else:
115124
logger.error(e)
116125
raise Error("Client.do: Failed to read response from peer.")
117126

118-
var res: HTTPResponse
127+
var response: HTTPResponse
119128
try:
120-
res = HTTPResponse.from_bytes(new_buf, conn)
129+
response = HTTPResponse.from_bytes(new_buf, conn)
121130
except e:
122131
logger.error("Failed to parse a response...")
123132
try:
@@ -127,19 +136,19 @@ struct Client:
127136
raise e
128137

129138
# Redirects should not keep the connection alive, as redirects can send the client to a different server.
130-
if self.allow_redirects and res.is_redirect():
139+
if self.allow_redirects and response.is_redirect():
131140
conn.teardown()
132-
return self._handle_redirect(req^, res^)
141+
return self._handle_redirect(request^, response^)
133142
# Server told the client to close the connection, we can assume the server closed their side after sending the response.
134-
elif res.connection_close():
143+
elif response.connection_close():
135144
conn.teardown()
136145
# Otherwise, persist the connection by giving it back to the pool manager.
137146
else:
138-
self._connections.give(host_str, conn^)
139-
return res
147+
self._connections.give(pool_key, conn^)
148+
return response
140149

141150
fn _handle_redirect(
142-
mut self, owned original_req: HTTPRequest, owned original_response: HTTPResponse
151+
mut self, owned original_request: HTTPRequest, owned original_response: HTTPResponse
143152
) raises -> HTTPResponse:
144153
var new_uri: URI
145154
var new_location: String
@@ -150,9 +159,9 @@ struct Client:
150159

151160
if new_location and new_location.startswith("http"):
152161
new_uri = URI.parse(new_location)
153-
original_req.headers[HeaderKey.HOST] = new_uri.host
162+
original_request.headers[HeaderKey.HOST] = new_uri.host
154163
else:
155-
new_uri = original_req.uri
164+
new_uri = original_request.uri
156165
new_uri.path = new_location
157-
original_req.uri = new_uri
158-
return self.do(original_req^)
166+
original_request.uri = new_uri
167+
return self.do(original_request^)

lightbug_http/pool_manager.mojo

Lines changed: 81 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,40 +9,100 @@ from lightbug_http.utils import logger
99
from lightbug_http.owning_list import OwningList
1010

1111

12+
@value
13+
struct Scheme(Hashable, EqualityComparable, Representable, Stringable, Writable):
14+
var value: String
15+
alias HTTP = Self("http")
16+
alias HTTPS = Self("https")
17+
18+
fn __hash__(self) -> UInt:
19+
return hash(self.value)
20+
21+
fn __eq__(self, other: Self) -> Bool:
22+
return self.value == other.value
23+
24+
fn __ne__(self, other: Self) -> Bool:
25+
return self.value != other.value
26+
27+
fn write_to[W: Writer, //](self, mut writer: W) -> None:
28+
writer.write("Scheme(value=", repr(self.value), ")")
29+
30+
fn __repr__(self) -> String:
31+
return String.write(self)
32+
33+
fn __str__(self) -> String:
34+
return self.value.upper()
35+
36+
37+
@value
38+
struct PoolKey(Hashable, KeyElement):
39+
var host: String
40+
var port: UInt16
41+
var scheme: Scheme
42+
43+
fn __init__(out self, host: String, port: UInt16, scheme: Scheme):
44+
self.host = host
45+
self.port = port
46+
self.scheme = scheme
47+
48+
fn __hash__(self) -> UInt:
49+
# TODO: Very rudimentary hash. We probably need to actually have an actual hash function here.
50+
# Since Tuple doesn't have one.
51+
return hash(self.host) + hash(self.port) + hash(self.scheme)
52+
53+
fn __eq__(self, other: Self) -> Bool:
54+
return self.host == other.host and self.port == other.port and self.scheme == other.scheme
55+
56+
fn __ne__(self, other: Self) -> Bool:
57+
return self.host != other.host or self.port != other.port or self.scheme != other.scheme
58+
59+
fn __str__(self) -> String:
60+
var result = String()
61+
result.write(self.scheme.value, "://", self.host, ":", str(self.port))
62+
return result
63+
64+
fn __repr__(self) -> String:
65+
return String.write(self)
66+
67+
fn write_to[W: Writer, //](self, mut writer: W) -> None:
68+
writer.write(
69+
"PoolKey(", "scheme=", repr(self.scheme.value), ", host=", repr(self.host), ", port=", str(self.port), ")"
70+
)
71+
72+
1273
struct PoolManager[ConnectionType: Connection]():
1374
var _connections: OwningList[ConnectionType]
1475
var _capacity: Int
15-
var mapping: Dict[String, Int]
76+
var mapping: Dict[PoolKey, Int]
1677

1778
fn __init__(out self, capacity: Int = 10):
1879
self._connections = OwningList[ConnectionType](capacity=capacity)
1980
self._capacity = capacity
20-
self.mapping = Dict[String, Int]()
81+
self.mapping = Dict[PoolKey, Int]()
2182

2283
fn __del__(owned self):
2384
logger.debug(
2485
"PoolManager shutting down and closing remaining connections before destruction:", self._connections.size
2586
)
2687
self.clear()
2788

28-
fn give(mut self, host: String, owned value: ConnectionType) raises:
29-
if host in self.mapping:
30-
self._connections[self.mapping[host]] = value^
89+
fn give(mut self, key: PoolKey, owned value: ConnectionType) raises:
90+
if key in self.mapping:
91+
self._connections[self.mapping[key]] = value^
3192
return
3293

3394
if self._connections.size == self._capacity:
3495
raise Error("PoolManager.give: Cache is full.")
3596

36-
self._connections[self._connections.size] = value^
37-
self.mapping[host] = self._connections.size
38-
self._connections.size += 1
39-
logger.debug("Checked in connection for peer:", host + ", at index:", self._connections.size)
97+
self._connections.append(value^)
98+
self.mapping[key] = self._connections.size - 1
99+
logger.debug("Checked in connection for peer:", str(key) + ", at index:", self._connections.size)
40100

41-
fn take(mut self, host: String) raises -> ConnectionType:
101+
fn take(mut self, key: PoolKey) raises -> ConnectionType:
42102
var index: Int
43103
try:
44-
index = self.mapping[host]
45-
_ = self.mapping.pop(host)
104+
index = self.mapping[key]
105+
_ = self.mapping.pop(key)
46106
except:
47107
raise Error("PoolManager.take: Key not found.")
48108

@@ -52,7 +112,7 @@ struct PoolManager[ConnectionType: Connection]():
52112
if kv[].value > index:
53113
self.mapping[kv[].key] -= 1
54114

55-
logger.debug("Checked out connection for peer:", host + ", from index:", self._connections.size + 1)
115+
logger.debug("Checked out connection for peer:", str(key) + ", from index:", self._connections.size + 1)
56116
return connection^
57117

58118
fn clear(mut self):
@@ -65,14 +125,14 @@ struct PoolManager[ConnectionType: Connection]():
65125
logger.error("Failed to tear down connection. Error:", e)
66126
self.mapping.clear()
67127

68-
fn __contains__(self, host: String) -> Bool:
69-
return host in self.mapping
128+
fn __contains__(self, key: PoolKey) -> Bool:
129+
return key in self.mapping
70130

71-
fn __setitem__(mut self, host: String, owned value: ConnectionType) raises -> None:
72-
if host in self.mapping:
73-
self._connections[self.mapping[host]] = value^
131+
fn __setitem__(mut self, key: PoolKey, owned value: ConnectionType) raises -> None:
132+
if key in self.mapping:
133+
self._connections[self.mapping[key]] = value^
74134
else:
75-
self.give(host, value^)
135+
self.give(key, value^)
76136

77-
fn __getitem__(self, host: String) raises -> ref [self._connections] ConnectionType:
78-
return self._connections[self.mapping[host]]
137+
fn __getitem__(self, key: PoolKey) raises -> ref [self._connections] ConnectionType:
138+
return self._connections[self.mapping[key]]

lightbug_http/socket.mojo

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,16 @@ struct Socket[AddrType: Addr, address_family: Int = AF_INET](Representable, Stri
148148
self.fd = existing.fd
149149
self.socket_type = existing.socket_type
150150
self.protocol = existing.protocol
151+
151152
self._local_address = existing._local_address^
153+
existing._local_address = AddrType()
152154
self._remote_address = existing._remote_address^
155+
existing._remote_address = AddrType()
156+
153157
self._closed = existing._closed
158+
existing._closed = True
154159
self._connected = existing._connected
160+
existing._connected = False
155161

156162
fn teardown(mut self) raises:
157163
"""Close the socket and free the file descriptor."""
@@ -169,17 +175,16 @@ struct Socket[AddrType: Addr, address_family: Int = AF_INET](Representable, Stri
169175
raise e
170176

171177
# TODO: Removed until we can determine why __del__ bugs out in the client flow, but not server flow?
172-
# fn __enter__(owned self) -> Self:
173-
# return self^
178+
fn __enter__(owned self) -> Self:
179+
return self^
174180

175181
# TODO: Seems to be bugged if this is included. Mojo tries to delete a mystical 0 fd socket that was never initialized?
176-
# fn __del__(owned self):
177-
# """Close the socket when the object is deleted."""
178-
# logger.info("In socket del", self)
179-
# try:
180-
# self.teardown()
181-
# except e:
182-
# logger.debug("Socket.__del__: Failed to close socket during deletion:", e)
182+
fn __del__(owned self):
183+
"""Close the socket when the object is deleted."""
184+
try:
185+
self.teardown()
186+
except e:
187+
logger.debug("Socket.__del__: Failed to close socket during deletion:", e)
183188

184189
fn __str__(self) -> String:
185190
return String.write(self)

0 commit comments

Comments
 (0)