Skip to content

Commit bde712a

Browse files
Refactor SocketWrapper: use shared_ptr for sock_fd to handle automatic socket close
1 parent aad1686 commit bde712a

File tree

4 files changed

+93
-109
lines changed

4 files changed

+93
-109
lines changed

libraries/Ethernet/examples/AdvancedChatServer/AdvancedChatServer.ino

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,11 @@
3535

3636
// The IP address will be dependent on your local network.
3737
// gateway and subnet are optional:
38-
IPAddress ip(192, 168, 2, 9); // Un IP libero nella subnet 192.168.2.x
39-
IPAddress myDns(192, 168, 2, 1); // Di solito si usa il gateway come DNS
40-
IPAddress gateway(192, 168, 2, 1); // IP del bridge/router
41-
IPAddress subnet(255, 255, 255, 0); // Subnet mask della rete
38+
IPAddress ip(192, 168, 1, 177);
39+
IPAddress myDns(192, 168, 1, 1);
40+
IPAddress gateway(192, 168, 1, 1);
41+
IPAddress subnet(255, 255, 255, 0);
42+
4243

4344
// telnet defaults to port 23
4445
ZephyrServer server(23);
@@ -105,7 +106,6 @@ void loop() {
105106
// read bytes from a client
106107
byte buffer[80];
107108
int count = clients[i].read(buffer, 80);
108-
Serial.println(count);
109109
// write the bytes to all other connected clients
110110
for (byte j=0; j < 8; j++) {
111111
if (j != i && clients[j].connected()) {

libraries/SocketWrapper/SocketWrapper.h

Lines changed: 79 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,38 @@
77
#endif
88

99
#include <zephyr/net/socket.h>
10+
#include <memory>
11+
#include <cstring>
1012

1113
class ZephyrSocketWrapper {
1214
protected:
13-
int sock_fd;
15+
std::shared_ptr<int> sock_fd;
1416
bool is_ssl = false;
1517
int ssl_sock_temp_char = -1;
1618

19+
// custom deleter for shared_ptr to close automatically the socket
20+
static auto socket_deleter() {
21+
return [](int* fd) {
22+
if (fd && *fd != -1) {
23+
::close(*fd);
24+
delete fd;
25+
}
26+
};
27+
}
28+
1729
public:
18-
ZephyrSocketWrapper() : sock_fd(-1) {
19-
}
30+
ZephyrSocketWrapper() = default;
2031

21-
ZephyrSocketWrapper(int sock_fd) : sock_fd(sock_fd) {
22-
}
32+
ZephyrSocketWrapper(int fd)
33+
: sock_fd(std::shared_ptr<int>(new int(fd), socket_deleter())) {}
2334

24-
~ZephyrSocketWrapper() {
25-
// close();
26-
}
35+
~ZephyrSocketWrapper() = default; // socket close managed by shared_ptr
2736

2837
bool connect(const char *host, uint16_t port) {
2938

3039
// Resolve address
31-
struct addrinfo hints = {0};
32-
struct addrinfo *res = nullptr;
33-
bool rv = true;
40+
struct addrinfo hints;
41+
struct addrinfo *res;
3442

3543
hints.ai_family = AF_INET;
3644
hints.ai_socktype = SOCK_STREAM;
@@ -49,31 +57,26 @@ class ZephyrSocketWrapper {
4957
}
5058

5159
if (ret != 0) {
52-
rv = false;
53-
goto exit;
60+
if (res) freeaddrinfo(res);
61+
return false;
5462
}
5563

56-
sock_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
57-
if (sock_fd < 0) {
58-
rv = false;
59-
60-
goto exit;
64+
int raw_sock_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
65+
if (raw_sock_fd < 0) {
66+
if (res) freeaddrinfo(res);
67+
return false;
6168
}
6269

63-
if (::connect(sock_fd, res->ai_addr, res->ai_addrlen) < 0) {
64-
::close(sock_fd);
65-
sock_fd = -1;
66-
rv = false;
67-
goto exit;
68-
}
70+
sock_fd = std::shared_ptr<int>(new int(raw_sock_fd), socket_deleter());
6971

70-
exit:
71-
if (res != nullptr) {
72-
freeaddrinfo(res);
73-
res = nullptr;
72+
if (::connect(*sock_fd, res->ai_addr, res->ai_addrlen) < 0) {
73+
sock_fd.reset();
74+
if (res) freeaddrinfo(res);
75+
return false;
7476
}
7577

76-
return rv;
78+
if (res) freeaddrinfo(res);
79+
return true;
7780
}
7881

7982
bool connect(IPAddress host, uint16_t port) {
@@ -85,14 +88,13 @@ class ZephyrSocketWrapper {
8588
addr.sin_port = htons(port);
8689
inet_pton(AF_INET, _host, &addr.sin_addr);
8790

88-
sock_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
89-
if (sock_fd < 0) {
90-
return false;
91-
}
91+
int raw_sock_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
92+
if (raw_sock_fd < 0) return false;
93+
94+
sock_fd = std::shared_ptr<int>(new int(raw_sock_fd), socket_deleter());
9295

93-
if (::connect(sock_fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
94-
::close(sock_fd);
95-
sock_fd = -1;
96+
if (::connect(*sock_fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
97+
sock_fd.reset();
9698
return false;
9799
}
98100

@@ -111,7 +113,6 @@ class ZephyrSocketWrapper {
111113

112114
int resolve_attempts = 100;
113115
int ret;
114-
bool rv = false;
115116

116117
sec_tag_t sec_tag_opt[] = {
117118
CA_CERTIFICATE_TAG,
@@ -133,7 +134,8 @@ class ZephyrSocketWrapper {
133134
}
134135

135136
if (ret != 0) {
136-
goto exit;
137+
if (res) freeaddrinfo(res);
138+
return false;
137139
}
138140

139141
if (ca_certificate_pem != nullptr) {
@@ -144,35 +146,29 @@ class ZephyrSocketWrapper {
144146
}
145147
}
146148

147-
sock_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TLS_1_2);
148-
if (sock_fd < 0) {
149-
goto exit;
150-
}
149+
int raw_sock_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TLS_1_2);
150+
if (raw_sock_fd < 0) {
151+
if (res) freeaddrinfo(res);
152+
return false;
153+
}
151154

152-
if (setsockopt(sock_fd, SOL_TLS, TLS_HOSTNAME, host, strlen(host)) ||
153-
setsockopt(sock_fd, SOL_TLS, TLS_SEC_TAG_LIST, sec_tag_opt, sizeof(sec_tag_opt)) ||
154-
setsockopt(sock_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout_opt, sizeof(timeout_opt))) {
155-
goto exit;
156-
}
155+
sock_fd = std::shared_ptr<int>(new int(raw_sock_fd), socket_deleter());
157156

158-
if (::connect(sock_fd, res->ai_addr, res->ai_addrlen) < 0) {
157+
if (setsockopt(*sock_fd, SOL_TLS, TLS_HOSTNAME, host, strlen(host)) ||
158+
setsockopt(*sock_fd, SOL_TLS, TLS_SEC_TAG_LIST, sec_tag_opt, sizeof(sec_tag_opt)) ||
159+
setsockopt(*sock_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout_opt, sizeof(timeout_opt))) {
159160
goto exit;
160161
}
161162

162-
rv = true;
163-
is_ssl = true;
164-
165-
exit:
166-
if (res != nullptr) {
167-
freeaddrinfo(res);
168-
res = nullptr;
163+
if (::connect(*sock_fd, res->ai_addr, res->ai_addrlen) < 0) {
164+
sock_fd.reset();
165+
if (res) freeaddrinfo(res);
166+
return false;
169167
}
168+
is_ssl = true;
170169

171-
if (!rv && sock_fd >= 0) {
172-
::close(sock_fd);
173-
sock_fd = -1;
174-
}
175-
return rv;
170+
if (res) freeaddrinfo(res);
171+
return true;
176172
}
177173
#endif
178174

@@ -189,9 +185,9 @@ class ZephyrSocketWrapper {
189185
if (ssl_sock_temp_char != -1) {
190186
return 1;
191187
}
192-
count = ::recv(sock_fd, &ssl_sock_temp_char, 1, MSG_DONTWAIT);
188+
count = ::recv(*sock_fd, &ssl_sock_temp_char, 1, MSG_DONTWAIT);
193189
} else {
194-
zsock_ioctl(sock_fd, ZFD_IOCTL_FIONREAD, &count);
190+
zsock_ioctl(*sock_fd, ZFD_IOCTL_FIONREAD, &count);
195191
}
196192
if (count <= 0) {
197193
delay(1);
@@ -201,31 +197,25 @@ class ZephyrSocketWrapper {
201197
}
202198

203199
int recv(uint8_t *buffer, size_t size, int flags = MSG_DONTWAIT) {
204-
if (sock_fd == -1) {
205-
return -1;
206-
}
200+
if (!sock_fd) return -1;
201+
207202
// TODO: see available()
208203
if (ssl_sock_temp_char != -1) {
209-
int ret = ::recv(sock_fd, &buffer[1], size - 1, flags);
204+
int ret = ::recv(*sock_fd, &buffer[1], size - 1, flags);
210205
buffer[0] = ssl_sock_temp_char;
211206
ssl_sock_temp_char = -1;
212207
return ret + 1;
213208
}
214-
return ::recv(sock_fd, buffer, size, flags);
209+
return ::recv(*sock_fd, buffer, size, flags);
215210
}
216211

217212
int send(const uint8_t *buffer, size_t size) {
218-
if (sock_fd == -1) {
219-
return -1;
220-
}
221-
return ::send(sock_fd, buffer, size, 0);
213+
if (!sock_fd) return -1;
214+
return ::send(*sock_fd, buffer, size, 0);
222215
}
223216

224217
void close() {
225-
if (sock_fd != -1) {
226-
::close(sock_fd);
227-
sock_fd = -1;
228-
}
218+
sock_fd.reset();
229219
}
230220

231221
bool bind(uint16_t port) {
@@ -234,54 +224,46 @@ class ZephyrSocketWrapper {
234224
addr.sin_port = htons(port);
235225
addr.sin_addr.s_addr = INADDR_ANY;
236226

237-
sock_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
238-
if (sock_fd < 0) {
239-
return false;
240-
}
227+
int raw_sock_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
228+
if (raw_sock_fd < 0) return false;
241229

242-
zsock_ioctl(sock_fd, ZFD_IOCTL_FIONBIO);
230+
sock_fd = std::shared_ptr<int>(new int(raw_sock_fd), socket_deleter());
243231

244-
if (::bind(sock_fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
245-
::close(sock_fd);
246-
sock_fd = -1;
232+
zsock_ioctl(*sock_fd, ZFD_IOCTL_FIONBIO);
233+
234+
if (::bind(*sock_fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
235+
sock_fd.reset();
247236
return false;
248237
}
249238

250239
return true;
251240
}
252241

253242
bool listen(int backlog = 5) {
254-
if (sock_fd == -1) {
255-
return false;
256-
}
243+
if (!sock_fd) return false;
257244

258-
if (::listen(sock_fd, backlog) < 0) {
259-
::close(sock_fd);
260-
sock_fd = -1;
245+
if (::listen(*sock_fd, backlog) < 0) {
246+
sock_fd.reset();
261247
return false;
262248
}
263249

264250
return true;
265251
}
266252

267253
int accept() {
268-
if (sock_fd == -1) {
269-
return -1;
270-
}
254+
if (!sock_fd) return false;
271255

272-
return ::accept(sock_fd, nullptr, nullptr);
256+
return ::accept(*sock_fd, nullptr, nullptr);
273257
}
274258

275259
String remoteIP() {
276-
if (sock_fd == -1) {
277-
return {};
278-
}
260+
if (!sock_fd) return {};
279261

280262
struct sockaddr_storage addr;
281263
socklen_t addr_len = sizeof(addr);
282264
char ip_str[INET6_ADDRSTRLEN] = {0};
283265

284-
if (::getpeername(sock_fd, (struct sockaddr *)&addr, &addr_len) == 0) {
266+
if (::getpeername(*sock_fd, (struct sockaddr *)&addr, &addr_len) == 0) {
285267
if (addr.ss_family == AF_INET) {
286268
struct sockaddr_in *s = (struct sockaddr_in *)&addr;
287269
::inet_ntop(AF_INET, &s->sin_addr, ip_str, sizeof(ip_str));

libraries/SocketWrapper/ZephyrClient.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class ZephyrClient : public arduino::Client, ZephyrSocketWrapper {
1111

1212
protected:
1313
void setSocket(int sock) {
14-
sock_fd = sock;
14+
sock_fd = std::shared_ptr<int>(new int(sock), ZephyrSocketWrapper::socket_deleter());
1515
_connected = true;
1616
}
1717

@@ -42,10 +42,10 @@ class ZephyrClient : public arduino::Client, ZephyrSocketWrapper {
4242
#endif
4343

4444
uint8_t connected() override {
45-
if (sock_fd == -1) return false;
45+
if (!sock_fd || *sock_fd == -1) return false;
4646

4747
uint8_t buf;
48-
int ret = ::recv(sock_fd, &buf, 1, MSG_PEEK | MSG_DONTWAIT);
48+
int ret = ::recv(*sock_fd, &buf, 1, MSG_PEEK | MSG_DONTWAIT);
4949
if (ret == 0) {
5050
stop();
5151
return false;
@@ -102,7 +102,7 @@ class ZephyrClient : public arduino::Client, ZephyrSocketWrapper {
102102
}
103103

104104
operator bool() {
105-
return sock_fd != -1;
105+
return sock_fd && *sock_fd != -1;
106106
}
107107

108108
String remoteIP() {

libraries/SocketWrapper/ZephyrServer.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,15 @@ class ZephyrServer : public arduino::Server, ZephyrSocketWrapper {
3636
}
3737

3838
explicit operator bool() {
39-
return sock_fd != -1;
39+
return sock_fd && *sock_fd != -1;
4040
}
4141

4242
ZephyrClient accept(uint8_t *status = nullptr) {
4343
ZephyrClient client;
44-
int sock = ZephyrSocketWrapper::accept();
45-
client.setSocket(sock);
44+
int client_fd = ZephyrSocketWrapper::accept();
45+
if (client_fd >= 0) {
46+
client.setSocket(client_fd);
47+
}
4648
return client;
4749
}
4850

0 commit comments

Comments
 (0)