Skip to content

Commit e98392b

Browse files
authored
Merge pull request #177 from dackroyd/invalidate-sql-rawbytes-memory
Invalidate memory scanned into sql.RawBytes
2 parents 6c8a572 + d5879ee commit e98392b

File tree

4 files changed

+486
-0
lines changed

4 files changed

+486
-0
lines changed

rows.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
package sqlmock
22

33
import (
4+
"bytes"
45
"database/sql/driver"
56
"encoding/csv"
67
"fmt"
78
"io"
89
"strings"
910
)
1011

12+
const invalidate = "☠☠☠ MEMORY OVERWRITTEN ☠☠☠ "
13+
1114
// CSVColumnParser is a function which converts trimmed csv
1215
// column string to a []byte representation. Currently
1316
// transforms NULL to nil
@@ -23,13 +26,15 @@ type rowSets struct {
2326
sets []*Rows
2427
pos int
2528
ex *ExpectedQuery
29+
raw [][]byte
2630
}
2731

2832
func (rs *rowSets) Columns() []string {
2933
return rs.sets[rs.pos].cols
3034
}
3135

3236
func (rs *rowSets) Close() error {
37+
rs.invalidateRaw()
3338
rs.ex.rowsWereClosed = true
3439
return rs.sets[rs.pos].closeErr
3540
}
@@ -38,11 +43,17 @@ func (rs *rowSets) Close() error {
3843
func (rs *rowSets) Next(dest []driver.Value) error {
3944
r := rs.sets[rs.pos]
4045
r.pos++
46+
rs.invalidateRaw()
4147
if r.pos > len(r.rows) {
4248
return io.EOF // per interface spec
4349
}
4450

4551
for i, col := range r.rows[r.pos-1] {
52+
if b, ok := rawBytes(col); ok {
53+
rs.raw = append(rs.raw, b)
54+
dest[i] = b
55+
continue
56+
}
4657
dest[i] = col
4758
}
4859

@@ -80,6 +91,30 @@ func (rs *rowSets) empty() bool {
8091
return true
8192
}
8293

94+
func rawBytes(col driver.Value) (_ []byte, ok bool) {
95+
val, ok := col.([]byte)
96+
if !ok || len(val) == 0 {
97+
return nil, false
98+
}
99+
// Copy the bytes from the mocked row into a shared raw buffer, which we'll replace the content of later
100+
// This allows scanning into sql.RawBytes to correctly become invalid on subsequent calls to Next(), Scan() or Close()
101+
b := make([]byte, len(val))
102+
copy(b, val)
103+
return b, true
104+
}
105+
106+
// Bytes that could have been scanned as sql.RawBytes are only valid until the next call to Next, Scan or Close.
107+
// If those occur, we must replace their content to simulate the shared memory to expose misuse of sql.RawBytes
108+
func (rs *rowSets) invalidateRaw() {
109+
// Replace the content of slices previously returned
110+
b := []byte(invalidate)
111+
for _, r := range rs.raw {
112+
copy(r, bytes.Repeat(b, len(r)/len(b)+1))
113+
}
114+
// Start with new slices for the next scan
115+
rs.raw = nil
116+
}
117+
83118
// Rows is a mocked collection of rows to
84119
// return for Query result
85120
type Rows struct {

rows_go13_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// +build go1.3
2+
3+
package sqlmock
4+
5+
import (
6+
"database/sql"
7+
"testing"
8+
)
9+
10+
func TestQueryRowBytesNotInvalidatedByNext_stringIntoRawBytes(t *testing.T) {
11+
t.Parallel()
12+
rows := NewRows([]string{"raw"}).
13+
AddRow(`one binary value with some text!`).
14+
AddRow(`two binary value with even more text than the first one`)
15+
scan := func(rs *sql.Rows) ([]byte, error) {
16+
var raw sql.RawBytes
17+
return raw, rs.Scan(&raw)
18+
}
19+
want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)}
20+
queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
21+
}
22+
23+
func TestQueryRowBytesNotInvalidatedByClose_stringIntoRawBytes(t *testing.T) {
24+
t.Parallel()
25+
rows := NewRows([]string{"raw"}).AddRow(`one binary value with some text!`)
26+
scan := func(rs *sql.Rows) ([]byte, error) {
27+
var raw sql.RawBytes
28+
return raw, rs.Scan(&raw)
29+
}
30+
queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`))
31+
}

rows_go18_test.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
package sqlmock
44

55
import (
6+
"database/sql"
7+
"encoding/json"
68
"fmt"
79
"testing"
810
)
@@ -90,3 +92,114 @@ func TestQueryMultiRows(t *testing.T) {
9092
t.Errorf("there were unfulfilled expectations: %s", err)
9193
}
9294
}
95+
96+
func TestQueryRowBytesInvalidatedByNext_jsonRawMessageIntoRawBytes(t *testing.T) {
97+
t.Parallel()
98+
replace := []byte(invalid)
99+
rows := NewRows([]string{"raw"}).
100+
AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)).
101+
AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`))
102+
scan := func(rs *sql.Rows) ([]byte, error) {
103+
var raw sql.RawBytes
104+
return raw, rs.Scan(&raw)
105+
}
106+
want := []struct {
107+
Initial []byte
108+
Replaced []byte
109+
}{
110+
{Initial: []byte(`{"thing": "one", "thing2": "two"}`), Replaced: replace[:len(replace)-6]},
111+
{Initial: []byte(`{"that": "foo", "this": "bar"}`), Replaced: replace[:len(replace)-9]},
112+
}
113+
queryRowBytesInvalidatedByNext(t, rows, scan, want)
114+
}
115+
116+
func TestQueryRowBytesNotInvalidatedByNext_jsonRawMessageIntoBytes(t *testing.T) {
117+
t.Parallel()
118+
rows := NewRows([]string{"raw"}).
119+
AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)).
120+
AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`))
121+
scan := func(rs *sql.Rows) ([]byte, error) {
122+
var b []byte
123+
return b, rs.Scan(&b)
124+
}
125+
want := [][]byte{[]byte(`{"thing": "one", "thing2": "two"}`), []byte(`{"that": "foo", "this": "bar"}`)}
126+
queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
127+
}
128+
129+
func TestQueryRowBytesNotInvalidatedByNext_bytesIntoCustomBytes(t *testing.T) {
130+
t.Parallel()
131+
rows := NewRows([]string{"raw"}).
132+
AddRow([]byte(`one binary value with some text!`)).
133+
AddRow([]byte(`two binary value with even more text than the first one`))
134+
scan := func(rs *sql.Rows) ([]byte, error) {
135+
type customBytes []byte
136+
var b customBytes
137+
return b, rs.Scan(&b)
138+
}
139+
want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)}
140+
queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
141+
}
142+
143+
func TestQueryRowBytesNotInvalidatedByNext_jsonRawMessageIntoCustomBytes(t *testing.T) {
144+
t.Parallel()
145+
rows := NewRows([]string{"raw"}).
146+
AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)).
147+
AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`))
148+
scan := func(rs *sql.Rows) ([]byte, error) {
149+
type customBytes []byte
150+
var b customBytes
151+
return b, rs.Scan(&b)
152+
}
153+
want := [][]byte{[]byte(`{"thing": "one", "thing2": "two"}`), []byte(`{"that": "foo", "this": "bar"}`)}
154+
queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
155+
}
156+
157+
func TestQueryRowBytesNotInvalidatedByClose_bytesIntoCustomBytes(t *testing.T) {
158+
t.Parallel()
159+
rows := NewRows([]string{"raw"}).AddRow([]byte(`one binary value with some text!`))
160+
scan := func(rs *sql.Rows) ([]byte, error) {
161+
type customBytes []byte
162+
var b customBytes
163+
return b, rs.Scan(&b)
164+
}
165+
queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`))
166+
}
167+
168+
func TestQueryRowBytesInvalidatedByClose_jsonRawMessageIntoRawBytes(t *testing.T) {
169+
t.Parallel()
170+
replace := []byte(invalid)
171+
rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`))
172+
scan := func(rs *sql.Rows) ([]byte, error) {
173+
var raw sql.RawBytes
174+
return raw, rs.Scan(&raw)
175+
}
176+
want := struct {
177+
Initial []byte
178+
Replaced []byte
179+
}{
180+
Initial: []byte(`{"thing": "one", "thing2": "two"}`),
181+
Replaced: replace[:len(replace)-6],
182+
}
183+
queryRowBytesInvalidatedByClose(t, rows, scan, want)
184+
}
185+
186+
func TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoBytes(t *testing.T) {
187+
t.Parallel()
188+
rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`))
189+
scan := func(rs *sql.Rows) ([]byte, error) {
190+
var b []byte
191+
return b, rs.Scan(&b)
192+
}
193+
queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`{"thing": "one", "thing2": "two"}`))
194+
}
195+
196+
func TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoCustomBytes(t *testing.T) {
197+
t.Parallel()
198+
rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`))
199+
scan := func(rs *sql.Rows) ([]byte, error) {
200+
type customBytes []byte
201+
var b customBytes
202+
return b, rs.Scan(&b)
203+
}
204+
queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`{"thing": "one", "thing2": "two"}`))
205+
}

0 commit comments

Comments
 (0)