Skip to content

Commit 3c15688

Browse files
authored
Router Network API updates regarding tcp/listen. (#643)
1 parent 4057e63 commit 3c15688

File tree

2 files changed

+62
-30
lines changed

2 files changed

+62
-30
lines changed

cmd/arduino-router/network-api/network-api.go

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ func Register(router *msgpackrouter.Router) {
2222
_ = router.RegisterMethod("tcp/connect", tcpConnect)
2323

2424
_ = router.RegisterMethod("tcp/listen", tcpListen)
25-
_ = router.RegisterMethod("tcp/accept", tcpAccept)
25+
_ = router.RegisterMethod("tcp/closeListener", tcpCloseListener)
2626

27+
_ = router.RegisterMethod("tcp/accept", tcpAccept)
2728
_ = router.RegisterMethod("tcp/read", tcpRead)
2829
_ = router.RegisterMethod("tcp/write", tcpWrite)
2930
_ = router.RegisterMethod("tcp/close", tcpClose)
@@ -83,13 +84,19 @@ func tcpConnect(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (
8384
}
8485

8586
func tcpListen(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
86-
if len(params) != 1 {
87-
return nil, []any{1, "Invalid number of parameters, expected listen address"}
87+
if len(params) != 2 {
88+
return nil, []any{1, "Invalid number of parameters, expected listen address and port"}
8889
}
8990
listenAddr, ok := params[0].(string)
9091
if !ok {
9192
return nil, []any{1, "Invalid parameter type, expected string for listen address"}
9293
}
94+
listenPort, ok := msgpackrpc.ToUint(params[1])
95+
if !ok {
96+
return nil, []any{1, "Invalid parameter type, expected uint16 for listen port"}
97+
}
98+
99+
listenAddr = net.JoinHostPort(listenAddr, strconv.FormatUint(uint64(listenPort), 10))
93100

94101
listener, err := net.Listen("tcp", listenAddr)
95102
if err != nil {
@@ -143,33 +150,50 @@ func tcpClose(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_r
143150

144151
lock.Lock()
145152
conn, existsConn := liveConnections[id]
146-
listener, existsListener := liveListeners[id]
147153
if existsConn {
148154
delete(liveConnections, id)
149155
}
150-
if existsListener {
151-
delete(liveListeners, id)
152-
}
153156
lock.Unlock()
154157

155-
if !existsConn && !existsListener {
158+
if !existsConn {
156159
return nil, []any{2, fmt.Sprintf("Connection not found for ID: %d", id)}
157160
}
158161

159-
// Close the connection or listener if it exists
160-
// We do not return an error if the close operation fails, as it is not critical,
162+
// Close the connection if it exists
163+
// We do not return an error to the caller if the close operation fails, as it is not critical,
161164
// but we only log the error for debugging purposes.
162-
if existsConn {
163-
if err := conn.Close(); err != nil {
164-
return err.Error(), nil
165-
}
165+
if err := conn.Close(); err != nil {
166+
return err.Error(), nil
166167
}
168+
return "", nil
169+
}
170+
171+
func tcpCloseListener(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
172+
if len(params) != 1 {
173+
return nil, []any{1, "Invalid number of parameters, expected listener ID"}
174+
}
175+
id, ok := msgpackrpc.ToUint(params[0])
176+
if !ok {
177+
return nil, []any{1, "Invalid parameter type, expected int for listener ID"}
178+
}
179+
180+
lock.Lock()
181+
listener, existsListener := liveListeners[id]
167182
if existsListener {
168-
if err := listener.Close(); err != nil {
169-
return err.Error(), nil
170-
}
183+
delete(liveListeners, id)
184+
}
185+
lock.Unlock()
186+
187+
if !existsListener {
188+
return nil, []any{2, fmt.Sprintf("Listener not found for ID: %d", id)}
171189
}
172190

191+
// Close the listener if it exists
192+
// We do not return an error to the caller if the close operation fails, as it is not critical,
193+
// but we only log the error for debugging purposes.
194+
if err := listener.Close(); err != nil {
195+
return err.Error(), nil
196+
}
173197
return "", nil
174198
}
175199

cmd/arduino-router/network-api/network-api_test.go

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package networkapi
22

33
import (
4+
"fmt"
45
"sync"
56
"testing"
67

@@ -143,16 +144,14 @@ const testCert = "-----BEGIN CERTIFICATE-----\n" +
143144
func TestNetworkAPI(t *testing.T) {
144145
ctx := t.Context()
145146
var rpc *msgpackrpc.Connection
146-
listID, err := tcpListen(ctx, rpc, []any{"localhost:9999"})
147+
listID, err := tcpListen(ctx, rpc, []any{"localhost", 9999})
147148
require.Nil(t, err)
148149
require.Equal(t, uint(1), listID)
149150

150151
var wg sync.WaitGroup
151-
wg.Add(1)
152-
go func() {
152+
wg.Go(func() {
153153
connID, err := tcpConnect(ctx, rpc, []any{"localhost", uint16(9999)})
154154
require.Nil(t, err)
155-
require.Equal(t, uint(2), connID)
156155

157156
n, err := tcpWrite(ctx, rpc, []any{connID, []byte("Hello")})
158157
require.Nil(t, err)
@@ -163,17 +162,12 @@ func TestNetworkAPI(t *testing.T) {
163162
require.Equal(t, "", res)
164163

165164
res, err = tcpClose(ctx, rpc, []any{connID})
166-
require.Equal(t, []any{2, "Connection not found for ID: 2"}, err)
165+
require.Equal(t, []any{2, fmt.Sprintf("Connection not found for ID: %d", connID)}, err)
167166
require.Nil(t, res)
168-
169-
wg.Done()
170-
}()
171-
172-
wg.Wait()
167+
})
173168

174169
connID, err := tcpAccept(ctx, rpc, []any{listID})
175170
require.Nil(t, err)
176-
require.Equal(t, uint(3), connID)
177171

178172
buff, err := tcpRead(ctx, rpc, []any{connID, 3})
179173
require.Nil(t, err)
@@ -187,16 +181,28 @@ func TestNetworkAPI(t *testing.T) {
187181
require.Equal(t, []any{3, "Failed to read from connection: EOF"}, err)
188182
require.Nil(t, buff)
189183

190-
res, err := tcpClose(ctx, rpc, []any{connID})
184+
res, err := tcpCloseListener(ctx, rpc, []any{connID})
185+
require.Equal(t, []any{2, fmt.Sprintf("Listener not found for ID: %d", connID)}, err)
186+
require.Nil(t, res)
187+
188+
res, err = tcpClose(ctx, rpc, []any{connID})
191189
require.Nil(t, err)
192190
require.Equal(t, "", res)
193191

194192
res, err = tcpClose(ctx, rpc, []any{listID})
193+
require.Equal(t, []any{2, fmt.Sprintf("Connection not found for ID: %d", listID)}, err)
194+
require.Nil(t, res)
195+
196+
res, err = tcpCloseListener(ctx, rpc, []any{listID})
195197
require.Nil(t, err)
196198
require.Equal(t, "", res)
197199

198200
res, err = tcpClose(ctx, rpc, []any{listID})
199-
require.Equal(t, []any{2, "Connection not found for ID: 1"}, err)
201+
require.Equal(t, []any{2, fmt.Sprintf("Connection not found for ID: %d", listID)}, err)
202+
require.Nil(t, res)
203+
204+
res, err = tcpCloseListener(ctx, rpc, []any{listID})
205+
require.Equal(t, []any{2, fmt.Sprintf("Listener not found for ID: %d", listID)}, err)
200206
require.Nil(t, res)
201207

202208
// Test SSL connection
@@ -212,4 +218,6 @@ func TestNetworkAPI(t *testing.T) {
212218
connIDSSL, err = tcpConnectSSL(ctx, rpc, []any{"www.arduino.cc", uint16(443), testCert})
213219
require.Equal(t, []any{2, "Failed to connect to server: tls: failed to verify certificate: x509: certificate signed by unknown authority"}, err)
214220
require.Nil(t, connIDSSL)
221+
222+
wg.Wait()
215223
}

0 commit comments

Comments
 (0)