Skip to content

Commit 8108012

Browse files
committed
portfwd: support HostSocket
Signed-off-by: Norio Nomura <norio.nomura@gmail.com> portfwd: remove "unixgram" forwarding code because that does not work Signed-off-by: Norio Nomura <norio.nomura@gmail.com> portfwd: do not use `listenConfig` param on Unix domain sockets Signed-off-by: Norio Nomura <norio.nomura@gmail.com>
1 parent 331bd2d commit 8108012

File tree

6 files changed

+57
-13
lines changed

6 files changed

+57
-13
lines changed

pkg/limayaml/defaults.go

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -902,14 +902,6 @@ func FillPortForwardDefaults(rule *limatype.PortForward, instDir string, user li
902902
rule.GuestPortRange[1] = rule.GuestPort
903903
}
904904
}
905-
if rule.HostPortRange[0] == 0 && rule.HostPortRange[1] == 0 {
906-
if rule.HostPort == 0 {
907-
rule.HostPortRange = rule.GuestPortRange
908-
} else {
909-
rule.HostPortRange[0] = rule.HostPort
910-
rule.HostPortRange[1] = rule.HostPort
911-
}
912-
}
913905
if rule.GuestSocket != "" {
914906
if out, err := executeGuestTemplate(rule.GuestSocket, instDir, user, param); err == nil {
915907
rule.GuestSocket = out.String()
@@ -926,6 +918,13 @@ func FillPortForwardDefaults(rule *limatype.PortForward, instDir string, user li
926918
if !filepath.IsAbs(rule.HostSocket) {
927919
rule.HostSocket = filepath.Join(instDir, filenames.SocketDir, rule.HostSocket)
928920
}
921+
} else if rule.HostPortRange[0] == 0 && rule.HostPortRange[1] == 0 {
922+
if rule.HostPort == 0 {
923+
rule.HostPortRange = rule.GuestPortRange
924+
} else {
925+
rule.HostPortRange[0] = rule.HostPort
926+
rule.HostPortRange[1] = rule.HostPort
927+
}
929928
}
930929
}
931930

pkg/limayaml/defaults_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ func TestFillDefault(t *testing.T) {
268268
expect.PortForwards[2].HostPort = 8888
269269
expect.PortForwards[2].HostPortRange = [2]int{8888, 8888}
270270

271+
expect.PortForwards[3].HostPortRange = [2]int{0, 0}
271272
expect.PortForwards[3].GuestSocket = fmt.Sprintf("%s | %s | %s | %s", user.HomeDir, user.Uid, user.Username, y.Param["ONE"])
272273
expect.PortForwards[3].HostSocket = fmt.Sprintf("%s | %s | %s | %s | %s | %s", hostHome, instDir, instName, currentUser.Uid, currentUser.Username, y.Param["ONE"])
273274

pkg/limayaml/validate.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,10 @@ func Validate(y *limatype.LimaYAML, warn bool) error {
314314
if err := validatePort(fmt.Sprintf("%s.guestPortRange[%d]", field, j), rule.GuestPortRange[j]); err != nil {
315315
errs = errors.Join(errs, err)
316316
}
317-
if err := validatePort(fmt.Sprintf("%s.hostPortRange[%d]", field, j), rule.HostPortRange[j]); err != nil {
318-
errs = errors.Join(errs, err)
317+
if rule.HostSocket == "" {
318+
if err := validatePort(fmt.Sprintf("%s.hostPortRange[%d]", field, j), rule.HostPortRange[j]); err != nil {
319+
errs = errors.Join(errs, err)
320+
}
319321
}
320322
}
321323
if rule.GuestPortRange[0] > rule.GuestPortRange[1] {
@@ -324,9 +326,6 @@ func Validate(y *limatype.LimaYAML, warn bool) error {
324326
if rule.HostPortRange[0] > rule.HostPortRange[1] {
325327
errs = errors.Join(errs, fmt.Errorf("field `%s.hostPortRange[1]` must be greater than or equal to field `%s.hostPortRange[0]`", field, field))
326328
}
327-
if rule.GuestPortRange[1]-rule.GuestPortRange[0] != rule.HostPortRange[1]-rule.HostPortRange[0] {
328-
errs = errors.Join(errs, fmt.Errorf("field `%s.hostPortRange` must specify the same number of ports as field `%s.guestPortRange`", field, field))
329-
}
330329
if rule.GuestSocket != "" {
331330
if !path.IsAbs(rule.GuestSocket) {
332331
errs = errors.Join(errs, fmt.Errorf("field `%s.guestSocket` must be an absolute path, but is %q", field, rule.GuestSocket))
@@ -343,7 +342,10 @@ func Validate(y *limatype.LimaYAML, warn bool) error {
343342
if rule.GuestSocket == "" && rule.GuestPortRange[1]-rule.GuestPortRange[0] > 0 {
344343
errs = errors.Join(errs, fmt.Errorf("field `%s.hostSocket` can only be mapped from a single port or socket. not a range", field))
345344
}
345+
} else if rule.GuestPortRange[1]-rule.GuestPortRange[0] != rule.HostPortRange[1]-rule.HostPortRange[0] {
346+
errs = errors.Join(errs, fmt.Errorf("field `%s.hostPortRange` must specify the same number of ports as field `%s.guestPortRange`", field, field))
346347
}
348+
347349
if len(rule.HostSocket) >= osutil.UnixPathMax {
348350
errs = errors.Join(errs, fmt.Errorf("field `%s.hostSocket` must be less than UNIX_PATH_MAX=%d characters, but is %d",
349351
field, osutil.UnixPathMax, len(rule.HostSocket)))

pkg/portfwd/listener.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"errors"
99
"fmt"
1010
"net"
11+
"os"
12+
"path/filepath"
1113
"strings"
1214
"sync"
1315

@@ -146,3 +148,13 @@ func (p *ClosableListeners) forwardUDP(ctx context.Context, client *guestagentcl
146148
func key(protocol, hostAddress, guestAddress string) string {
147149
return fmt.Sprintf("%s-%s-%s", protocol, hostAddress, guestAddress)
148150
}
151+
152+
func prepareUnixSocket(hostSocket string) error {
153+
if err := os.RemoveAll(hostSocket); err != nil {
154+
return fmt.Errorf("can't clean up %q: %w", hostSocket, err)
155+
}
156+
if err := os.MkdirAll(filepath.Dir(hostSocket), 0o755); err != nil {
157+
return fmt.Errorf("can't create directory for local socket %q: %w", hostSocket, err)
158+
}
159+
return nil
160+
}

pkg/portfwd/listener_darwin.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,26 @@ import (
77
"context"
88
"fmt"
99
"net"
10+
"path/filepath"
1011
"strconv"
1112

1213
"github.com/sirupsen/logrus"
1314
)
1415

1516
func Listen(ctx context.Context, listenConfig net.ListenConfig, hostAddress string) (net.Listener, error) {
17+
if filepath.IsAbs(hostAddress) {
18+
// Handle Unix domain sockets
19+
if err := prepareUnixSocket(hostAddress); err != nil {
20+
return nil, err
21+
}
22+
var lc net.ListenConfig
23+
unixLis, err := lc.Listen(ctx, "unix", hostAddress)
24+
if err != nil {
25+
logrus.WithError(err).Errorf("failed to listen unix: %v", hostAddress)
26+
return nil, err
27+
}
28+
return unixLis, nil
29+
}
1630
localIPStr, localPortStr, _ := net.SplitHostPort(hostAddress)
1731
localIP := net.ParseIP(localIPStr)
1832
localPort, _ := strconv.Atoi(localPortStr)

pkg/portfwd/listener_others.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,25 @@ package portfwd
88
import (
99
"context"
1010
"net"
11+
"path/filepath"
12+
13+
"github.com/sirupsen/logrus"
1114
)
1215

1316
func Listen(ctx context.Context, listenConfig net.ListenConfig, hostAddress string) (net.Listener, error) {
17+
if filepath.IsAbs(hostAddress) {
18+
// Handle Unix domain sockets
19+
if err := prepareUnixSocket(hostAddress); err != nil {
20+
return nil, err
21+
}
22+
var lc net.ListenConfig
23+
unixLis, err := lc.Listen(ctx, "unix", hostAddress)
24+
if err != nil {
25+
logrus.WithError(err).Errorf("failed to listen unix: %v", hostAddress)
26+
return nil, err
27+
}
28+
return unixLis, nil
29+
}
1430
return listenConfig.Listen(ctx, "tcp", hostAddress)
1531
}
1632

0 commit comments

Comments
 (0)