Skip to content

Commit c8e287c

Browse files
giulio93lucaroncalucarin91
authored
Add SSH Port Forwarding (#677)
Co-authored-by: giulio93 <pilotto.giulio@gmail.com> Co-authored-by: Luca Ronca <luca.ronca07@gmail.com> Co-authored-by: lucarin91 <lucarin@protonmail.com>
1 parent ad4df6b commit c8e287c

File tree

4 files changed

+214
-38
lines changed

4 files changed

+214
-38
lines changed

pkg/board/remote/adb/adb.go

Lines changed: 2 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ import (
88
"io"
99
"io/fs"
1010
"log/slog"
11-
"math/rand/v2"
12-
"net"
1311
"os"
1412
"os/user"
1513
"path"
@@ -20,6 +18,7 @@ import (
2018
"github.com/arduino/go-paths-helper"
2119

2220
"github.com/bcmi-labs/orchestrator/pkg/board/remote"
21+
"github.com/bcmi-labs/orchestrator/pkg/board/remote/common"
2322
)
2423

2524
const username = "arduino"
@@ -58,7 +57,7 @@ func FromHost(host string, adbPath string) (*ADBConnection, error) {
5857
}
5958

6059
func (a *ADBConnection) Forward(ctx context.Context, remotePort int) (int, error) {
61-
hostAvailablePort, err := getAvailablePort()
60+
hostAvailablePort, err := common.GetAvailablePort()
6261
if err != nil {
6362
return 0, fmt.Errorf("failed to find an available port: %w", err)
6463
}
@@ -311,35 +310,3 @@ func FindAdbPath() string {
311310

312311
return adbPath
313312
}
314-
315-
func isPortAvailable(port int) bool {
316-
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
317-
if err != nil {
318-
return false
319-
}
320-
listener.Close()
321-
return true
322-
}
323-
324-
func getRandomPort() int {
325-
port := 1000 + rand.IntN(9000) // nolint:gosec
326-
return port
327-
}
328-
329-
const forwardPortAttempts = 10
330-
331-
func getAvailablePort() (int, error) {
332-
tried := make(map[int]any, forwardPortAttempts)
333-
for len(tried) < forwardPortAttempts {
334-
port := getRandomPort()
335-
if _, seen := tried[port]; seen {
336-
continue
337-
}
338-
tried[port] = struct{}{}
339-
340-
if isPortAvailable(port) {
341-
return port, nil
342-
}
343-
}
344-
return 0, fmt.Errorf("no available port found in range 1000-9999 after %d attempts", forwardPortAttempts)
345-
}

pkg/board/remote/common/common.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package common
2+
3+
import (
4+
"fmt"
5+
"math/rand/v2"
6+
"net"
7+
)
8+
9+
func isPortAvailable(port int) bool {
10+
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
11+
if err != nil {
12+
return false
13+
}
14+
listener.Close()
15+
return true
16+
}
17+
18+
func getRandomPort() int {
19+
port := 1000 + rand.IntN(9000) // nolint:gosec
20+
return port
21+
}
22+
23+
const forwardPortAttempts = 10
24+
25+
func GetAvailablePort() (int, error) {
26+
tried := make(map[int]any, forwardPortAttempts)
27+
for len(tried) < forwardPortAttempts {
28+
port := getRandomPort()
29+
if _, seen := tried[port]; seen {
30+
continue
31+
}
32+
tried[port] = struct{}{}
33+
34+
if isPortAvailable(port) {
35+
return port, nil
36+
}
37+
}
38+
return 0, fmt.Errorf("no available port found in range 1000-9999 after %d attempts", forwardPortAttempts)
39+
}

pkg/board/remote/remote_test.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
package remote_test
22

33
import (
4+
"context"
5+
"fmt"
6+
47
"io"
8+
"os/exec"
9+
"strconv"
510
"strings"
611
"testing"
712

813
"github.com/stretchr/testify/assert"
914
"github.com/stretchr/testify/require"
1015

16+
"github.com/bcmi-labs/orchestrator/cmd/feedback"
1117
"github.com/bcmi-labs/orchestrator/pkg/board/remote"
1218
"github.com/bcmi-labs/orchestrator/pkg/board/remote/adb"
1319
"github.com/bcmi-labs/orchestrator/pkg/board/remote/local"
@@ -162,4 +168,80 @@ func TestSSHShell(t *testing.T) {
162168
})
163169
}
164170
}
171+
172+
}
173+
174+
func TestSSHForwarder(t *testing.T) {
175+
name, _, sshPort := testtools.StartAdbDContainer(t)
176+
t.Cleanup(func() { testtools.StopAdbDContainer(t, name) })
177+
178+
conn, err := ssh.FromHost("arduino", "arduino", fmt.Sprintf("%s:%s", "localhost", sshPort))
179+
require.NoError(t, err)
180+
181+
t.Run("Forward ADB", func(t *testing.T) {
182+
ctx, cancel := context.WithCancel(t.Context())
183+
defer cancel()
184+
185+
forwardPort, err := conn.Forward(ctx, 5555)
186+
if err != nil {
187+
t.Errorf("Forward failed: %v", err)
188+
}
189+
if forwardPort <= 0 || forwardPort > 65535 {
190+
t.Fatalf("invalid port: %d", forwardPort)
191+
}
192+
adb_forwarded_endpoint := fmt.Sprintf("localhost:%s", strconv.Itoa(forwardPort))
193+
194+
out, err := exec.Command("adb", "connect", adb_forwarded_endpoint).CombinedOutput()
195+
require.NoError(t, err, "adb connect output: %q", out)
196+
197+
cmd := exec.Command("adb", "-s", adb_forwarded_endpoint, "shell", "echo", "Hello, World!")
198+
out, err = cmd.CombinedOutput()
199+
require.NoError(t, err, "command output: %q", out)
200+
feedback.Printf("Command output:\n%s\n", string(out))
201+
require.NotNil(t, string(out))
202+
})
203+
}
204+
205+
func TestSSHKillForwarder(t *testing.T) {
206+
name, _, sshPort := testtools.StartAdbDContainer(t)
207+
t.Cleanup(func() { testtools.StopAdbDContainer(t, name) })
208+
209+
conn, err := ssh.FromHost("arduino", "arduino", fmt.Sprintf("%s:%s", "localhost", sshPort))
210+
require.NoError(t, err)
211+
212+
t.Run("KillAllForwards", func(t *testing.T) {
213+
ctx, cancel := context.WithCancel(t.Context())
214+
defer cancel()
215+
216+
forwardPort, err := conn.Forward(ctx, 5555)
217+
if err != nil {
218+
t.Errorf("Forward failed: %v", err)
219+
}
220+
if forwardPort <= 0 || forwardPort > 65535 {
221+
t.Fatalf("invalid port: %d", forwardPort)
222+
}
223+
adb_forwarded_endpoint := fmt.Sprintf("localhost:%s", strconv.Itoa(forwardPort))
224+
225+
out, err := exec.Command("adb", "connect", adb_forwarded_endpoint).CombinedOutput()
226+
require.NoError(t, err, "adb connect output: %q", out)
227+
228+
cmd := exec.Command("adb", "-s", adb_forwarded_endpoint, "shell", "echo", "Hello, World!")
229+
out, err = cmd.CombinedOutput()
230+
require.NoError(t, err, "command output: %q", out)
231+
feedback.Printf("Command output:\n%s\n", string(out))
232+
require.NotNil(t, string(out))
233+
234+
err = conn.ForwardKillAll(t.Context())
235+
require.NoError(t, err)
236+
out, err = exec.Command("adb", "disconnect", adb_forwarded_endpoint).CombinedOutput()
237+
require.NoError(t, err, "adb disconnect output: %q", out)
238+
239+
out, err = exec.Command("adb", "connect", adb_forwarded_endpoint).CombinedOutput()
240+
require.NoError(t, err, "adb connect output: %q", out)
241+
242+
cmd = exec.Command("adb", "-s", adb_forwarded_endpoint, "shell", "echo", "Hello, World!")
243+
out, err = cmd.CombinedOutput()
244+
require.Error(t, err, "command output: %q", out)
245+
feedback.Printf("Command output:\n%s\n", string(out))
246+
})
165247
}

pkg/board/remote/ssh/ssh.go

Lines changed: 91 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,31 @@ package ssh
33
import (
44
"bytes"
55
"context"
6+
"errors"
67
"fmt"
78
"io"
89
"io/fs"
910
"log"
11+
"log/slog"
12+
"net"
1013
"path"
1114
"strings"
15+
"sync"
1216

1317
"golang.org/x/crypto/ssh"
1418

1519
"github.com/bcmi-labs/orchestrator/pkg/board/remote"
20+
"github.com/bcmi-labs/orchestrator/pkg/board/remote/common"
1621
)
1722

23+
var ErrAuthFailed = errors.New("ssh authentication failed")
24+
1825
type SSHConnection struct {
1926
client *ssh.Client
27+
wg sync.WaitGroup
28+
29+
mu sync.Mutex
30+
Listeners []net.Listener
2031
}
2132

2233
// Ensures SSHConnection implements the RemoteConn interface at compile time.
@@ -32,7 +43,13 @@ func FromHost(user, password, address string) (*SSHConnection, error) {
3243
HostKeyCallback: ssh.InsecureIgnoreHostKey(), // nolint:gosec
3344
})
3445
if err != nil {
35-
log.Fatalf("Failed to dial: %s", err)
46+
msg := err.Error()
47+
if strings.Contains(msg, "unable to authenticate") ||
48+
strings.Contains(msg, "no supported methods remain") ||
49+
strings.Contains(msg, "permission denied") {
50+
return nil, ErrAuthFailed
51+
}
52+
return nil, fmt.Errorf("failed to dial SSH: %w", err)
3653
}
3754

3855
return &SSHConnection{
@@ -41,11 +58,82 @@ func FromHost(user, password, address string) (*SSHConnection, error) {
4158
}
4259

4360
func (a *SSHConnection) Forward(ctx context.Context, remotePort int) (int, error) {
44-
panic("`Forward` is not implemented for SSHConnection")
61+
62+
// Get a random available port for remote connection
63+
randomPort, err := common.GetAvailablePort()
64+
if err != nil {
65+
log.Printf("failed to get available port: %v", err)
66+
return 0, err
67+
}
68+
// Listen locally on remote port
69+
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", "localhost", randomPort))
70+
if err != nil {
71+
return 0, err
72+
}
73+
74+
a.mu.Lock()
75+
a.Listeners = append(a.Listeners, listener)
76+
a.mu.Unlock()
77+
78+
a.wg.Add(1)
79+
go func() {
80+
defer listener.Close()
81+
defer a.wg.Done()
82+
83+
for {
84+
localConn, err := listener.Accept()
85+
if err != nil {
86+
if !errors.Is(err, net.ErrClosed) {
87+
slog.Warn("failed to accept local connection:", slog.Any("error", err))
88+
}
89+
return
90+
}
91+
92+
go func(localConn net.Conn, remotePort int) {
93+
defer localConn.Close()
94+
95+
// TODO: the kill operation should forcefully terminate the connection that was already estabish
96+
97+
// Open remote connection through SSH
98+
remoteConn, err := a.client.Dial("tcp", fmt.Sprintf("localhost:%d", remotePort))
99+
if err != nil {
100+
slog.Warn("failed to dial remote host:", slog.Any("error", err))
101+
return
102+
103+
}
104+
defer remoteConn.Close()
105+
106+
// Bidirectional copy
107+
var wg sync.WaitGroup
108+
wg.Go(func() { copyAndLog(remoteConn, localConn) })
109+
wg.Go(func() { copyAndLog(localConn, remoteConn) })
110+
wg.Wait()
111+
}(localConn, remotePort)
112+
}
113+
}()
114+
115+
return randomPort, nil
116+
117+
}
118+
119+
func copyAndLog(dst io.Writer, src io.Reader) {
120+
_, err := io.Copy(dst, src)
121+
if err != nil {
122+
slog.Warn("failed to copy connection", slog.Any("error", err))
123+
}
45124
}
46125

47126
func (a *SSHConnection) ForwardKillAll(ctx context.Context) error {
48-
panic("`ForwardKillAll` is not implemented for SSHConnection")
127+
a.mu.Lock()
128+
defer a.mu.Unlock()
129+
for _, listener := range a.Listeners {
130+
if err := listener.Close(); err != nil {
131+
return err
132+
}
133+
}
134+
a.wg.Wait()
135+
a.Listeners = make([]net.Listener, 0)
136+
return nil
49137
}
50138

51139
func (a *SSHConnection) List(path string) ([]remote.FileInfo, error) {

0 commit comments

Comments
 (0)