Skip to content

Commit d5879ee

Browse files
committed
Invalidate memory scanned into sql.RawBytes
The intention of sql.RawBytes is for it to hold memory owned by the database. When used, it's content is only valid until the `Next`, `Scan` or `Close` is called on the `Rows` To ensure that we meet this behaviour, when `[]byte` is used in a column, it's value is copied to a buffer that we keep track of for later invalidation. By doing this, incorrect use of `sql.RawBytes` values is exposed in tests that use go-sqlmock. Without this, when a real database is used and it's driver does share memory, then those issues would not be exposed until runtime (and in non-obvious ways)
1 parent 6c8a572 commit d5879ee

File tree

4 files changed

+486
-0
lines changed

4 files changed

+486
-0
lines changed

rows.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
package sqlmock
22

33
import (
4+
"bytes"
45
"database/sql/driver"
56
"encoding/csv"
67
"fmt"
78
"io"
89
"strings"
910
)
1011

12+
const invalidate = "☠☠☠ MEMORY OVERWRITTEN ☠☠☠ "
13+
1114
// CSVColumnParser is a function which converts trimmed csv
1215
// column string to a []byte representation. Currently
1316
// transforms NULL to nil
@@ -23,13 +26,15 @@ type rowSets struct {
2326
sets []*Rows
2427
pos int
2528
ex *ExpectedQuery
29+
raw [][]byte
2630
}
2731

2832
func (rs *rowSets) Columns() []string {
2933
return rs.sets[rs.pos].cols
3034
}
3135

3236
func (rs *rowSets) Close() error {
37+
rs.invalidateRaw()
3338
rs.ex.rowsWereClosed = true
3439
return rs.sets[rs.pos].closeErr
3540
}
@@ -38,11 +43,17 @@ func (rs *rowSets) Close() error {
3843
func (rs *rowSets) Next(dest []driver.Value) error {
3944
r := rs.sets[rs.pos]
4045
r.pos++
46+
rs.invalidateRaw()
4147
if r.pos > len(r.rows) {
4248
return io.EOF // per interface spec
4349
}
4450

4551
for i, col := range r.rows[r.pos-1] {
52+
if b, ok := rawBytes(col); ok {
53+
rs.raw = append(rs.raw, b)
54+
dest[i] = b
55+
continue
56+
}
4657
dest[i] = col
4758
}
4859

@@ -80,6 +91,30 @@ func (rs *rowSets) empty() bool {
8091
return true
8192
}
8293

94+
func rawBytes(col driver.Value) (_ []byte, ok bool) {
95+
val, ok := col.([]byte)
96+
if !ok || len(val) == 0 {
97+
return nil, false
98+
}
99+
// Copy the bytes from the mocked row into a shared raw buffer, which we'll replace the content of later
100+
// This allows scanning into sql.RawBytes to correctly become invalid on subsequent calls to Next(), Scan() or Close()
101+
b := make([]byte, len(val))
102+
copy(b, val)
103+
return b, true
104+
}
105+
106+
// Bytes that could have been scanned as sql.RawBytes are only valid until the next call to Next, Scan or Close.
107+
// If those occur, we must replace their content to simulate the shared memory to expose misuse of sql.RawBytes
108+
func (rs *rowSets) invalidateRaw() {
109+
// Replace the content of slices previously returned
110+
b := []byte(invalidate)
111+
for _, r := range rs.raw {
112+
copy(r, bytes.Repeat(b, len(r)/len(b)+1))
113+
}
114+
// Start with new slices for the next scan
115+
rs.raw = nil
116+
}
117+
83118
// Rows is a mocked collection of rows to
84119
// return for Query result
85120
type Rows struct {

rows_go13_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// +build go1.3
2+
3+
package sqlmock
4+
5+
import (
6+
"database/sql"
7+
"testing"
8+
)
9+
10+
func TestQueryRowBytesNotInvalidatedByNext_stringIntoRawBytes(t *testing.T) {
11+
t.Parallel()
12+
rows := NewRows([]string{"raw"}).
13+
AddRow(`one binary value with some text!`).
14+
AddRow(`two binary value with even more text than the first one`)
15+
scan := func(rs *sql.Rows) ([]byte, error) {
16+
var raw sql.RawBytes
17+
return raw, rs.Scan(&raw)
18+
}
19+
want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)}
20+
queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
21+
}
22+
23+
func TestQueryRowBytesNotInvalidatedByClose_stringIntoRawBytes(t *testing.T) {
24+
t.Parallel()
25+
rows := NewRows([]string{"raw"}).AddRow(`one binary value with some text!`)
26+
scan := func(rs *sql.Rows) ([]byte, error) {
27+
var raw sql.RawBytes
28+
return raw, rs.Scan(&raw)
29+
}
30+
queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`))
31+
}

rows_go18_test.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
package sqlmock
44

55
import (
6+
"database/sql"
7+
"encoding/json"
68
"fmt"
79
"testing"
810
)
@@ -90,3 +92,114 @@ func TestQueryMultiRows(t *testing.T) {
9092
t.Errorf("there were unfulfilled expectations: %s", err)
9193
}
9294
}
95+
96+
func TestQueryRowBytesInvalidatedByNext_jsonRawMessageIntoRawBytes(t *testing.T) {
97+
t.Parallel()
98+
replace := []byte(invalid)
99+
rows := NewRows([]string{"raw"}).
100+
AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)).
101+
AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`))
102+
scan := func(rs *sql.Rows) ([]byte, error) {
103+
var raw sql.RawBytes
104+
return raw, rs.Scan(&raw)
105+
}
106+
want := []struct {
107+
Initial []byte
108+
Replaced []byte
109+
}{
110+
{Initial: []byte(`{"thing": "one", "thing2": "two"}`), Replaced: replace[:len(replace)-6]},
111+
{Initial: []byte(`{"that": "foo", "this": "bar"}`), Replaced: replace[:len(replace)-9]},
112+
}
113+
queryRowBytesInvalidatedByNext(t, rows, scan, want)
114+
}
115+
116+
func TestQueryRowBytesNotInvalidatedByNext_jsonRawMessageIntoBytes(t *testing.T) {
117+
t.Parallel()
118+
rows := NewRows([]string{"raw"}).
119+
AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)).
120+
AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`))
121+
scan := func(rs *sql.Rows) ([]byte, error) {
122+
var b []byte
123+
return b, rs.Scan(&b)
124+
}
125+
want := [][]byte{[]byte(`{"thing": "one", "thing2": "two"}`), []byte(`{"that": "foo", "this": "bar"}`)}
126+
queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
127+
}
128+
129+
func TestQueryRowBytesNotInvalidatedByNext_bytesIntoCustomBytes(t *testing.T) {
130+
t.Parallel()
131+
rows := NewRows([]string{"raw"}).
132+
AddRow([]byte(`one binary value with some text!`)).
133+
AddRow([]byte(`two binary value with even more text than the first one`))
134+
scan := func(rs *sql.Rows) ([]byte, error) {
135+
type customBytes []byte
136+
var b customBytes
137+
return b, rs.Scan(&b)
138+
}
139+
want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)}
140+
queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
141+
}
142+
143+
func TestQueryRowBytesNotInvalidatedByNext_jsonRawMessageIntoCustomBytes(t *testing.T) {
144+
t.Parallel()
145+
rows := NewRows([]string{"raw"}).
146+
AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)).
147+
AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`))
148+
scan := func(rs *sql.Rows) ([]byte, error) {
149+
type customBytes []byte
150+
var b customBytes
151+
return b, rs.Scan(&b)
152+
}
153+
want := [][]byte{[]byte(`{"thing": "one", "thing2": "two"}`), []byte(`{"that": "foo", "this": "bar"}`)}
154+
queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
155+
}
156+
157+
func TestQueryRowBytesNotInvalidatedByClose_bytesIntoCustomBytes(t *testing.T) {
158+
t.Parallel()
159+
rows := NewRows([]string{"raw"}).AddRow([]byte(`one binary value with some text!`))
160+
scan := func(rs *sql.Rows) ([]byte, error) {
161+
type customBytes []byte
162+
var b customBytes
163+
return b, rs.Scan(&b)
164+
}
165+
queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`))
166+
}
167+
168+
func TestQueryRowBytesInvalidatedByClose_jsonRawMessageIntoRawBytes(t *testing.T) {
169+
t.Parallel()
170+
replace := []byte(invalid)
171+
rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`))
172+
scan := func(rs *sql.Rows) ([]byte, error) {
173+
var raw sql.RawBytes
174+
return raw, rs.Scan(&raw)
175+
}
176+
want := struct {
177+
Initial []byte
178+
Replaced []byte
179+
}{
180+
Initial: []byte(`{"thing": "one", "thing2": "two"}`),
181+
Replaced: replace[:len(replace)-6],
182+
}
183+
queryRowBytesInvalidatedByClose(t, rows, scan, want)
184+
}
185+
186+
func TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoBytes(t *testing.T) {
187+
t.Parallel()
188+
rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`))
189+
scan := func(rs *sql.Rows) ([]byte, error) {
190+
var b []byte
191+
return b, rs.Scan(&b)
192+
}
193+
queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`{"thing": "one", "thing2": "two"}`))
194+
}
195+
196+
func TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoCustomBytes(t *testing.T) {
197+
t.Parallel()
198+
rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`))
199+
scan := func(rs *sql.Rows) ([]byte, error) {
200+
type customBytes []byte
201+
var b customBytes
202+
return b, rs.Scan(&b)
203+
}
204+
queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`{"thing": "one", "thing2": "two"}`))
205+
}

0 commit comments

Comments
 (0)