Skip to content

Commit 8b04d0a

Browse files
committed
Convert SSRF tests to inline expectations tests
1 parent 6e4dbe8 commit 8b04d0a

File tree

3 files changed

+34
-32
lines changed

3 files changed

+34
-32
lines changed
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
query: experimental/CWE-918/SSRF.ql
2-
postprocess: utils/test/PrettyPrintModels.ql
2+
postprocess:
3+
- utils/test/PrettyPrintModels.ql
4+
- utils/test/InlineExpectationsTestQuery.ql

go/ql/test/experimental/CWE-918/builtin.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ import (
1616
)
1717

1818
func handler(w http.ResponseWriter, req *http.Request) {
19-
target := req.FormValue("target")
19+
target := req.FormValue("target") // $ Source
2020

2121
// BAD: `target` is controlled by the attacker
22-
_, err := http.Get("https://" + target + ".example.com/data/")
22+
_, err := http.Get("https://" + target + ".example.com/data/") // $ Alert
2323
if err != nil {
2424
// error handling
2525
}
@@ -80,12 +80,12 @@ func test() {
8080

8181
// x net websocket dial bad
8282
http.HandleFunc("/ex2", func(w http.ResponseWriter, r *http.Request) {
83-
untrustedInput := r.Referer()
83+
untrustedInput := r.Referer() // $ Source
8484

8585
origin := "http://localhost/"
8686

8787
// bad as input is directly passed to dial function
88-
ws, _ := websocket.Dial(untrustedInput, "", origin) // SSRF
88+
ws, _ := websocket.Dial(untrustedInput, "", origin) // $ Alert
8989
var msg = make([]byte, 512)
9090
var n int
9191
n, _ = ws.Read(msg)
@@ -94,12 +94,12 @@ func test() {
9494

9595
// x net websocket dialConfig bad
9696
http.HandleFunc("/ex3", func(w http.ResponseWriter, r *http.Request) {
97-
untrustedInput := r.Referer()
97+
untrustedInput := r.Referer() // $ Source
9898

9999
origin := "http://localhost/"
100100
// bad as input is directly used
101-
config, _ := websocket.NewConfig(untrustedInput, origin) // SSRF
102-
ws2, _ := websocket.DialConfig(config)
101+
config, _ := websocket.NewConfig(untrustedInput, origin) // $ Sink
102+
ws2, _ := websocket.DialConfig(config) // $ Alert
103103
var msg = make([]byte, 512)
104104
var n int
105105
n, _ = ws2.Read(msg)
@@ -108,10 +108,10 @@ func test() {
108108

109109
// gorilla websocket Dialer.Dial bad
110110
http.HandleFunc("/ex6", func(w http.ResponseWriter, r *http.Request) {
111-
untrustedInput := r.Referer()
111+
untrustedInput := r.Referer() // $ Source
112112

113113
dialer := gorilla.Dialer{}
114-
dialer.Dial(untrustedInput, r.Header) //SSRF
114+
dialer.Dial(untrustedInput, r.Header) // $ Alert
115115
})
116116

117117
// gorilla websocket Dialer.Dial good
@@ -126,10 +126,10 @@ func test() {
126126

127127
// gorilla websocket Dialer.DialContext bad
128128
http.HandleFunc("/ex8", func(w http.ResponseWriter, r *http.Request) {
129-
untrustedInput := r.Referer()
129+
untrustedInput := r.Referer() // $ Source
130130

131131
dialer := gorilla.Dialer{}
132-
dialer.DialContext(context.TODO(), untrustedInput, r.Header) //SSRF
132+
dialer.DialContext(context.TODO(), untrustedInput, r.Header) // $ Alert
133133
})
134134

135135
// gorilla websocket Dialer.DialContext good

go/ql/test/experimental/CWE-918/new-tests.go

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,31 +23,31 @@ func HandlerGin(c *gin.Context) {
2323
safe string `binding:"alphanum"`
2424
}
2525

26-
err := c.ShouldBindJSON(&body)
26+
err := c.ShouldBindJSON(&body) // $ Source
2727

2828
http.Get(fmt.Sprintf("http://example.com/%d", body.integer)) // OK
2929
http.Get(fmt.Sprintf("http://example.com/%v", body.float)) // OK
3030
http.Get(fmt.Sprintf("http://example.com/%v", body.boolean)) // OK
31-
http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // SSRF
32-
http.Get(fmt.Sprintf("http://example.com/%s", body.safe)) // SSRF
31+
http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // $ Alert
32+
http.Get(fmt.Sprintf("http://example.com/%s", body.safe)) // $ Alert
3333

3434
if err == nil {
35-
http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // SSRF
35+
http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // $ Alert
3636
http.Get(fmt.Sprintf("http://example.com/%s", body.safe)) // OK
3737
}
3838

39-
taintedParam := c.Param("id")
39+
taintedParam := c.Param("id") // $ Source
4040

4141
validate := validator.New()
4242
err = validate.Var(taintedParam, "alpha")
4343
if err == nil {
4444
http.Get("http://example.com/" + taintedParam) // OK
4545
}
4646

47-
http.Get("http://example.com/" + taintedParam) //SSRF
47+
http.Get("http://example.com/" + taintedParam) // $ Alert
4848

49-
taintedQuery := c.Query("id")
50-
http.Get("http://example.com/" + taintedQuery) //SSRF
49+
taintedQuery := c.Query("id") // $ Source
50+
http.Get("http://example.com/" + taintedQuery) // $ Alert
5151
}
5252

5353
func HandlerHttp(req *http.Request) {
@@ -59,41 +59,41 @@ func HandlerHttp(req *http.Request) {
5959
word string
6060
safe string `validate:"alphanum"`
6161
}
62-
reqBody, _ := ioutil.ReadAll(req.Body)
62+
reqBody, _ := ioutil.ReadAll(req.Body) // $ Source
6363
json.Unmarshal(reqBody, &body)
6464

6565
http.Get(fmt.Sprintf("http://example.com/%d", body.integer)) // OK
6666
http.Get(fmt.Sprintf("http://example.com/%v", body.float)) // OK
6767
http.Get(fmt.Sprintf("http://example.com/%v", body.boolean)) // OK
68-
http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // SSRF
69-
http.Get(fmt.Sprintf("http://example.com/%s", body.safe)) // SSRF
68+
http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // $ Alert
69+
http.Get(fmt.Sprintf("http://example.com/%s", body.safe)) // $ Alert
7070

7171
validate := validator.New()
7272
err := validate.Struct(body)
7373
if err == nil {
74-
http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // SSRF
74+
http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // $ Alert
7575
http.Get(fmt.Sprintf("http://example.com/%s", body.safe)) // OK
7676
}
7777

78-
taintedQuery := req.URL.Query().Get("param1")
79-
http.Get("http://example.com/" + taintedQuery) // SSRF
78+
taintedQuery := req.URL.Query().Get("param1") // $ Source
79+
http.Get("http://example.com/" + taintedQuery) // $ Alert
8080

81-
taintedParam := strings.TrimPrefix(req.URL.Path, "/example-path/")
82-
http.Get("http://example.com/" + taintedParam) // SSRF
81+
taintedParam := strings.TrimPrefix(req.URL.Path, "/example-path/") // $ Source
82+
http.Get("http://example.com/" + taintedParam) // $ Alert
8383
}
8484

8585
func HandlerMux(r *http.Request) {
86-
vars := mux.Vars(r)
86+
vars := mux.Vars(r) // $ Source
8787
taintedParam := vars["id"]
88-
http.Get("http://example.com/" + taintedParam) // SSRF
88+
http.Get("http://example.com/" + taintedParam) // $ Alert
8989

9090
numericID, _ := strconv.Atoi(taintedParam)
9191
http.Get(fmt.Sprintf("http://example.com/%d", numericID)) // OK
9292
}
9393

9494
func HandlerChi(r *http.Request) {
95-
taintedParam := chi.URLParam(r, "articleID")
96-
http.Get("http://example.com/" + taintedParam) // SSRF
95+
taintedParam := chi.URLParam(r, "articleID") // $ Source
96+
http.Get("http://example.com/" + taintedParam) // $ Alert
9797

9898
b, _ := strconv.ParseBool(taintedParam)
9999
http.Get(fmt.Sprintf("http://example.com/%t", b)) // OK

0 commit comments

Comments
 (0)