Skip to content

Commit 02c6d29

Browse files
authored
Refactor into separate packages & add tests. (path-network#4)
1 parent 201d278 commit 02c6d29

File tree

14 files changed

+902
-324
lines changed

14 files changed

+902
-324
lines changed

.github/workflows/test-startup.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Go
1+
name: Test systemd
22

33
on:
44
push:
@@ -18,7 +18,7 @@ jobs:
1818
go-version: "1.21"
1919

2020
- name: Build
21-
run: go build -v ./...
21+
run: go build -v
2222

2323
- name: Install go-mmproxy
2424
run: |

.github/workflows/test.yml

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
name: Test
2+
3+
on:
4+
push:
5+
branches: ["main"]
6+
pull_request:
7+
branches: ["main"]
8+
9+
jobs:
10+
build:
11+
runs-on: ubuntu-latest
12+
steps:
13+
- uses: actions/checkout@v4
14+
15+
- name: Set up Go
16+
uses: actions/setup-go@v4
17+
with:
18+
go-version: "1.21"
19+
20+
- name: Build
21+
run: go build -v
22+
23+
- name: Prepare ip routes
24+
run: |
25+
sudo ip rule add from 127.0.0.1/8 iif lo table 123
26+
sudo ip route add local 0.0.0.0/0 dev lo table 123
27+
sudo ip -6 rule add from ::1/128 iif lo table 123
28+
sudo ip -6 route add local ::/0 dev lo table 123
29+
30+
- name: Test
31+
run: sudo go test -v -timeout 30s ./tests

buffers.go

Lines changed: 0 additions & 24 deletions
This file was deleted.

buffers/buffers.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Copyright 2019 Path Network, Inc. All rights reserved.
2+
// Copyright 2024 Konrad Zemek <konrad.zemek@gmail.com>
3+
// Use of this source code is governed by a BSD-style
4+
// license that can be found in the LICENSE file.
5+
6+
package buffers
7+
8+
import (
9+
"math"
10+
"sync"
11+
)
12+
13+
var buffers sync.Pool
14+
15+
func init() {
16+
buffers.New = func() any {
17+
slice := make([]byte, math.MaxUint16)
18+
return &slice
19+
}
20+
}
21+
22+
func Get() []byte {
23+
return *buffers.Get().(*[]byte)
24+
}
25+
26+
func Put(buf []byte) {
27+
buffers.Put(&buf)
28+
}

main.go

Lines changed: 72 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,58 @@
11
// Copyright 2019 Path Network, Inc. All rights reserved.
2+
// Copyright 2024 Konrad Zemek <konrad.zemek@gmail.com>
23
// Use of this source code is governed by a BSD-style
34
// license that can be found in the LICENSE file.
45

56
package main
67

78
import (
89
"bufio"
10+
"context"
911
"flag"
1012
"log/slog"
1113
"net"
1214
"net/netip"
1315
"os"
1416
"syscall"
1517
"time"
18+
19+
"github.com/kzemek/go-mmproxy/tcp"
20+
"github.com/kzemek/go-mmproxy/udp"
21+
"github.com/kzemek/go-mmproxy/utils"
1622
)
1723

18-
type options struct {
19-
Protocol string
20-
ListenAddrStr string
21-
TargetAddr4Str string
22-
TargetAddr6Str string
23-
ListenAddr netip.AddrPort
24-
TargetAddr4 netip.AddrPort
25-
TargetAddr6 netip.AddrPort
26-
Mark int
27-
Verbose int
28-
allowedSubnetsPath string
29-
AllowedSubnets []*net.IPNet
30-
Listeners int
31-
Logger *slog.Logger
32-
udpCloseAfter int
33-
UDPCloseAfter time.Duration
34-
}
24+
var protocolStr string
25+
var listenAddrStr string
26+
var targetAddr4Str string
27+
var targetAddr6Str string
28+
var allowedSubnetsPath string
29+
var udpCloseAfterInt int
30+
var listeners int
3531

36-
var Opts options
32+
var opts utils.Options
3733

3834
func init() {
39-
flag.StringVar(&Opts.Protocol, "p", "tcp", "Protocol that will be proxied: tcp, udp")
40-
flag.StringVar(&Opts.ListenAddrStr, "l", "0.0.0.0:8443", "Address the proxy listens on")
41-
flag.StringVar(&Opts.TargetAddr4Str, "4", "127.0.0.1:443", "Address to which IPv4 traffic will be forwarded to")
42-
flag.StringVar(&Opts.TargetAddr6Str, "6", "[::1]:443", "Address to which IPv6 traffic will be forwarded to")
43-
flag.IntVar(&Opts.Mark, "mark", 0, "The mark that will be set on outbound packets")
44-
flag.IntVar(&Opts.Verbose, "v", 0, `0 - no logging of individual connections
35+
flag.StringVar(&protocolStr, "p", "tcp", "Protocol that will be proxied: tcp, udp")
36+
flag.StringVar(&listenAddrStr, "l", "0.0.0.0:8443", "Address the proxy listens on")
37+
flag.StringVar(&targetAddr4Str, "4", "127.0.0.1:443", "Address to which IPv4 traffic will be forwarded to")
38+
flag.StringVar(&targetAddr6Str, "6", "[::1]:443", "Address to which IPv6 traffic will be forwarded to")
39+
flag.IntVar(&opts.Mark, "mark", 0, "The mark that will be set on outbound packets")
40+
flag.IntVar(&opts.Verbose, "v", 0, `0 - no logging of individual connections
4541
1 - log errors occurring in individual connections
4642
2 - log all state changes of individual connections`)
47-
flag.StringVar(&Opts.allowedSubnetsPath, "allowed-subnets", "",
43+
flag.StringVar(&allowedSubnetsPath, "allowed-subnets", "",
4844
"Path to a file that contains allowed subnets of the proxy servers")
49-
flag.IntVar(&Opts.Listeners, "listeners", 1,
45+
flag.IntVar(&listeners, "listeners", 1,
5046
"Number of listener sockets that will be opened for the listen address (Linux 3.9+)")
51-
flag.IntVar(&Opts.udpCloseAfter, "close-after", 60, "Number of seconds after which UDP socket will be cleaned up")
47+
flag.IntVar(&udpCloseAfterInt, "close-after", 60, "Number of seconds after which UDP socket will be cleaned up")
5248
}
5349

54-
func listen(listenerNum int, errors chan<- error) {
55-
logger := Opts.Logger.With(slog.Int("listenerNum", listenerNum),
56-
slog.String("protocol", Opts.Protocol), slog.String("listenAdr", Opts.ListenAddr.String()))
50+
func listen(ctx context.Context, listenerNum int, parentLogger *slog.Logger, listenErrors chan<- error) {
51+
logger := parentLogger.With(slog.Int("listenerNum", listenerNum),
52+
slog.String("protocol", protocolStr), slog.String("listenAdr", opts.ListenAddr.String()))
5753

5854
listenConfig := net.ListenConfig{}
59-
if Opts.Listeners > 1 {
55+
if listeners > 1 {
6056
listenConfig.Control = func(network, address string, c syscall.RawConn) error {
6157
return c.Control(func(fd uintptr) {
6258
soReusePort := 15
@@ -67,15 +63,15 @@ func listen(listenerNum int, errors chan<- error) {
6763
}
6864
}
6965

70-
if Opts.Protocol == "tcp" {
71-
tcpListen(&listenConfig, logger, errors)
66+
if opts.Protocol == utils.TCP {
67+
tcp.Listen(ctx, &listenConfig, &opts, logger, listenErrors)
7268
} else {
73-
udpListen(&listenConfig, logger, errors)
69+
udp.Listen(ctx, &listenConfig, &opts, logger, listenErrors)
7470
}
7571
}
7672

77-
func loadAllowedSubnets() error {
78-
file, err := os.Open(Opts.allowedSubnetsPath)
73+
func loadAllowedSubnets(logger *slog.Logger) error {
74+
file, err := os.Open(allowedSubnetsPath)
7975
if err != nil {
8076
return err
8177
}
@@ -84,12 +80,12 @@ func loadAllowedSubnets() error {
8480

8581
scanner := bufio.NewScanner(file)
8682
for scanner.Scan() {
87-
_, ipNet, err := net.ParseCIDR(scanner.Text())
83+
ipNet, err := netip.ParsePrefix(scanner.Text())
8884
if err != nil {
8985
return err
9086
}
91-
Opts.AllowedSubnets = append(Opts.AllowedSubnets, ipNet)
92-
Opts.Logger.Info("allowed subnet", slog.String("subnet", ipNet.String()))
87+
opts.AllowedSubnets = append(opts.AllowedSubnets, ipNet)
88+
logger.Info("allowed subnet", slog.String("subnet", ipNet.String()))
9389
}
9490

9591
return nil
@@ -98,72 +94,79 @@ func loadAllowedSubnets() error {
9894
func main() {
9995
flag.Parse()
10096
lvl := slog.LevelInfo
101-
if Opts.Verbose > 0 {
97+
if opts.Verbose > 0 {
10298
lvl = slog.LevelDebug
10399
}
104-
Opts.Logger = slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: lvl}))
105100

106-
if Opts.allowedSubnetsPath != "" {
107-
if err := loadAllowedSubnets(); err != nil {
108-
Opts.Logger.Error("failed to load allowed subnets file", "path", Opts.allowedSubnetsPath, "error", err)
101+
logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: lvl}))
102+
103+
if allowedSubnetsPath != "" {
104+
if err := loadAllowedSubnets(logger); err != nil {
105+
logger.Error("failed to load allowed subnets file", "path", allowedSubnetsPath, "error", err)
109106
}
110107
}
111108

112-
if Opts.Protocol != "tcp" && Opts.Protocol != "udp" {
113-
Opts.Logger.Error("--protocol has to be one of udp, tcp", slog.String("protocol", Opts.Protocol))
109+
if protocolStr == "tcp" {
110+
opts.Protocol = utils.TCP
111+
} else if protocolStr == "udp" {
112+
opts.Protocol = utils.UDP
113+
} else {
114+
logger.Error("--protocol has to be one of udp, tcp", slog.String("protocol", protocolStr))
114115
os.Exit(1)
115116
}
116117

117-
if Opts.Mark < 0 {
118-
Opts.Logger.Error("--mark has to be >= 0", slog.Int("mark", Opts.Mark))
118+
if opts.Mark < 0 {
119+
logger.Error("--mark has to be >= 0", slog.Int("mark", opts.Mark))
119120
os.Exit(1)
120121
}
121122

122-
if Opts.Verbose < 0 {
123-
Opts.Logger.Error("-v has to be >= 0", slog.Int("verbose", Opts.Verbose))
123+
if opts.Verbose < 0 {
124+
logger.Error("-v has to be >= 0", slog.Int("verbose", opts.Verbose))
124125
os.Exit(1)
125126
}
126127

127-
if Opts.Listeners < 1 {
128-
Opts.Logger.Error("--listeners has to be >= 1")
128+
if listeners < 1 {
129+
logger.Error("--listeners has to be >= 1")
129130
os.Exit(1)
130131
}
131132

132133
var err error
133-
if Opts.ListenAddr, err = parseHostPort(Opts.ListenAddrStr); err != nil {
134-
Opts.Logger.Error("listen address is malformed", "error", err)
134+
if opts.ListenAddr, err = utils.ParseHostPort(listenAddrStr); err != nil {
135+
logger.Error("listen address is malformed", "error", err)
135136
os.Exit(1)
136137
}
137138

138-
if Opts.TargetAddr4, err = netip.ParseAddrPort(Opts.TargetAddr4Str); err != nil {
139-
Opts.Logger.Error("ipv4 target address is malformed", "error", err)
139+
if opts.TargetAddr4, err = netip.ParseAddrPort(targetAddr4Str); err != nil {
140+
logger.Error("ipv4 target address is malformed", "error", err)
140141
os.Exit(1)
141142
}
142-
if !Opts.TargetAddr4.Addr().Is4() {
143-
Opts.Logger.Error("ipv4 target address is not IPv4")
143+
if !opts.TargetAddr4.Addr().Is4() {
144+
logger.Error("ipv4 target address is not IPv4")
144145
os.Exit(1)
145146
}
146147

147-
if Opts.TargetAddr6, err = netip.ParseAddrPort(Opts.TargetAddr6Str); err != nil {
148-
Opts.Logger.Error("ipv6 target address is malformed", "error", err)
148+
if opts.TargetAddr6, err = netip.ParseAddrPort(targetAddr6Str); err != nil {
149+
logger.Error("ipv6 target address is malformed", "error", err)
149150
os.Exit(1)
150151
}
151-
if !Opts.TargetAddr6.Addr().Is6() {
152-
Opts.Logger.Error("ipv6 target address is not IPv6")
152+
if !opts.TargetAddr6.Addr().Is6() {
153+
logger.Error("ipv6 target address is not IPv6")
153154
os.Exit(1)
154155
}
155156

156-
if Opts.udpCloseAfter < 0 {
157-
Opts.Logger.Error("--close-after has to be >= 0", slog.Int("close-after", Opts.udpCloseAfter))
157+
if udpCloseAfterInt < 0 {
158+
logger.Error("--close-after has to be >= 0", slog.Int("close-after", udpCloseAfterInt))
158159
os.Exit(1)
159160
}
160-
Opts.UDPCloseAfter = time.Duration(Opts.udpCloseAfter) * time.Second
161+
opts.UDPCloseAfter = time.Duration(udpCloseAfterInt) * time.Second
161162

162-
listenErrors := make(chan error, Opts.Listeners)
163-
for i := 0; i < Opts.Listeners; i++ {
164-
go listen(i, listenErrors)
163+
listenErrors := make(chan error, listeners)
164+
ctxs := make([]context.Context, listeners)
165+
for i := range ctxs {
166+
ctxs[i] = context.Background()
167+
go listen(ctxs[i], i, logger, listenErrors)
165168
}
166-
for i := 0; i < Opts.Listeners; i++ {
169+
for range ctxs {
167170
<-listenErrors
168171
}
169172
}

0 commit comments

Comments
 (0)