Skip to content

Commit df2d206

Browse files
authored
fix: websocket origin check (fixes serial monitor on Windows) (#39)
* Fix websocket origin check * Test allowed origins on startup.
1 parent d8a3605 commit df2d206

File tree

2 files changed

+90
-10
lines changed

2 files changed

+90
-10
lines changed

internal/api/handlers/monitor.go

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -83,24 +83,53 @@ func monitorStream(mon net.Conn, ws *websocket.Conn) {
8383
}()
8484
}
8585

86+
func splitOrigin(origin string) (scheme, host, port string, err error) {
87+
parts := strings.SplitN(origin, "://", 2)
88+
if len(parts) != 2 {
89+
return "", "", "", fmt.Errorf("invalid origin format: %s", origin)
90+
}
91+
scheme = parts[0]
92+
hostPort := parts[1]
93+
hostParts := strings.SplitN(hostPort, ":", 2)
94+
host = hostParts[0]
95+
if len(hostParts) == 2 {
96+
port = hostParts[1]
97+
} else {
98+
port = "*"
99+
}
100+
return scheme, host, port, nil
101+
}
102+
86103
func checkOrigin(origin string, allowedOrigins []string) bool {
104+
scheme, host, port, err := splitOrigin(origin)
105+
if err != nil {
106+
slog.Error("WebSocket origin check failed", slog.String("origin", origin), slog.String("error", err.Error()))
107+
return false
108+
}
87109
for _, allowed := range allowedOrigins {
88-
if strings.HasSuffix(allowed, "*") {
89-
// String ends with *, match the prefix
90-
if strings.HasPrefix(origin, strings.TrimSuffix(allowed, "*")) {
91-
return true
92-
}
93-
} else {
94-
// Exact match
95-
if allowed == origin {
96-
return true
97-
}
110+
allowedScheme, allowedHost, allowedPort, err := splitOrigin(allowed)
111+
if err != nil {
112+
panic(err)
113+
}
114+
if allowedScheme != scheme {
115+
continue
98116
}
117+
if allowedHost != host && allowedHost != "*" {
118+
continue
119+
}
120+
if allowedPort != port && allowedPort != "*" {
121+
continue
122+
}
123+
return true
99124
}
125+
slog.Error("WebSocket origin check failed", slog.String("origin", origin))
100126
return false
101127
}
102128

103129
func HandleMonitorWS(allowedOrigins []string) http.HandlerFunc {
130+
// Do a dry-run of checkorigin, so it can panic if misconfigured now, not on first request
131+
_ = checkOrigin("http://example.com:8000", allowedOrigins)
132+
104133
upgrader := websocket.Upgrader{
105134
ReadBufferSize: 1024,
106135
WriteBufferSize: 1024,
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// This file is part of arduino-app-cli.
2+
//
3+
// Copyright 2025 ARDUINO SA (http://www.arduino.cc/)
4+
//
5+
// This software is released under the GNU General Public License version 3,
6+
// which covers the main part of arduino-app-cli.
7+
// The terms of this license can be found at:
8+
// https://www.gnu.org/licenses/gpl-3.0.en.html
9+
//
10+
// You can be released from the requirements of the above licenses by purchasing
11+
// a commercial license. Buying such a license is mandatory if you want to
12+
// modify or otherwise use the software for commercial activities involving the
13+
// Arduino software without disclosing the source code of your own applications.
14+
// To purchase a commercial license, send an email to license@arduino.cc.
15+
16+
package handlers
17+
18+
import (
19+
"testing"
20+
21+
"github.com/stretchr/testify/require"
22+
)
23+
24+
func TestCheckOrigin(t *testing.T) {
25+
origins := []string{
26+
"wails://wails",
27+
"wails://wails.localhost:*",
28+
"http://wails.localhost:*",
29+
"http://localhost:*",
30+
"https://localhost:*",
31+
"http://example.com:7000",
32+
"https://*:443",
33+
}
34+
35+
allow := func(origin string) {
36+
require.True(t, checkOrigin(origin, origins), "Expected origin %s to be allowed", origin)
37+
}
38+
deny := func(origin string) {
39+
require.False(t, checkOrigin(origin, origins), "Expected origin %s to be denied", origin)
40+
}
41+
allow("wails://wails")
42+
allow("wails://wails:8000")
43+
allow("http://wails.localhost")
44+
allow("http://example.com:7000")
45+
allow("https://blah.com:443")
46+
deny("wails://evil.com")
47+
deny("https://wails.localhost:8000")
48+
deny("http://example.com:8000")
49+
deny("http://blah.com:443")
50+
deny("https://blah.com:8080")
51+
}

0 commit comments

Comments
 (0)