Skip to content

Commit 95ef5f3

Browse files
authored
feat: support isolation level REPEATABLE READ (#403)
Add support for isolation level REPEATABLE READ with the BeginTx function. This allows the caller to specify the isolation level for a single transaction. A follow-up pull request will add support for setting the default isolation level that should be used by a connection.
1 parent 77ab84a commit 95ef5f3

File tree

4 files changed

+202
-2
lines changed

4 files changed

+202
-2
lines changed

conn.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -942,14 +942,22 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
942942
disableRetryAborts := false
943943
batchReadOnly := false
944944
sil := opts.Isolation >> 8
945-
// TODO: Fix this, the original isolation level is not correctly restored.
946-
opts.Isolation = opts.Isolation - sil
945+
opts.Isolation = opts.Isolation - sil<<8
946+
if opts.Isolation != driver.IsolationLevel(sql.LevelDefault) {
947+
level, err := toProtoIsolationLevel(sql.IsolationLevel(opts.Isolation))
948+
if err != nil {
949+
return nil, err
950+
}
951+
readWriteTransactionOptions.TransactionOptions.IsolationLevel = level
952+
}
947953
if sil > 0 {
948954
switch spannerIsolationLevel(sil) {
949955
case levelDisableRetryAborts:
950956
disableRetryAborts = true
951957
case levelBatchReadOnly:
952958
batchReadOnly = true
959+
default:
960+
// ignore
953961
}
954962
}
955963
if batchReadOnly && !opts.ReadOnly {

conn_with_mockserver_test.go

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package spannerdriver
16+
17+
import (
18+
"context"
19+
"database/sql"
20+
"reflect"
21+
"testing"
22+
23+
"cloud.google.com/go/spanner/apiv1/spannerpb"
24+
"github.com/googleapis/go-sql-spanner/testutil"
25+
)
26+
27+
func TestBeginTx(t *testing.T) {
28+
t.Parallel()
29+
30+
db, server, teardown := setupTestDBConnection(t)
31+
defer teardown()
32+
ctx := context.Background()
33+
34+
tx, _ := db.BeginTx(ctx, &sql.TxOptions{})
35+
_, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo)
36+
_ = tx.Rollback()
37+
38+
requests := drainRequestsFromServer(server.TestSpanner)
39+
beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{}))
40+
if g, w := len(beginRequests), 1; g != w {
41+
t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w)
42+
}
43+
request := beginRequests[0].(*spannerpb.BeginTransactionRequest)
44+
if g, w := request.Options.GetIsolationLevel(), spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED; g != w {
45+
t.Fatalf("begin isolation level mismatch\n Got: %v\nWant: %v", g, w)
46+
}
47+
}
48+
49+
func TestBeginTxWithIsolationLevel(t *testing.T) {
50+
t.Parallel()
51+
52+
db, server, teardown := setupTestDBConnection(t)
53+
defer teardown()
54+
ctx := context.Background()
55+
56+
for _, level := range []sql.IsolationLevel{
57+
sql.LevelDefault,
58+
sql.LevelSnapshot,
59+
sql.LevelRepeatableRead,
60+
sql.LevelSerializable,
61+
} {
62+
originalLevel := level
63+
for _, disableRetryAborts := range []bool{true, false} {
64+
if disableRetryAborts {
65+
level = WithDisableRetryAborts(originalLevel)
66+
} else {
67+
level = originalLevel
68+
}
69+
tx, _ := db.BeginTx(ctx, &sql.TxOptions{
70+
Isolation: level,
71+
})
72+
_, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo)
73+
_ = tx.Rollback()
74+
75+
requests := drainRequestsFromServer(server.TestSpanner)
76+
beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{}))
77+
if g, w := len(beginRequests), 1; g != w {
78+
t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w)
79+
}
80+
request := beginRequests[0].(*spannerpb.BeginTransactionRequest)
81+
wantIsolationLevel, _ := toProtoIsolationLevel(originalLevel)
82+
if g, w := request.Options.GetIsolationLevel(), wantIsolationLevel; g != w {
83+
t.Fatalf("begin isolation level mismatch\n Got: %v\nWant: %v", g, w)
84+
}
85+
}
86+
}
87+
}
88+
89+
func TestBeginTxWithInvalidIsolationLevel(t *testing.T) {
90+
t.Parallel()
91+
92+
db, _, teardown := setupTestDBConnection(t)
93+
defer teardown()
94+
ctx := context.Background()
95+
96+
for _, level := range []sql.IsolationLevel{
97+
sql.LevelReadUncommitted,
98+
sql.LevelReadCommitted,
99+
sql.LevelWriteCommitted,
100+
sql.LevelLinearizable,
101+
} {
102+
originalLevel := level
103+
for _, disableRetryAborts := range []bool{true, false} {
104+
if disableRetryAborts {
105+
level = WithDisableRetryAborts(originalLevel)
106+
} else {
107+
level = originalLevel
108+
}
109+
_, err := db.BeginTx(ctx, &sql.TxOptions{
110+
Isolation: level,
111+
})
112+
if err == nil {
113+
t.Fatalf("BeginTx should have failed with invalid isolation level: %v", level)
114+
}
115+
}
116+
}
117+
}

driver.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,6 +1058,27 @@ func checkIsValidType(v driver.Value) bool {
10581058
return true
10591059
}
10601060

1061+
func toProtoIsolationLevel(level sql.IsolationLevel) (spannerpb.TransactionOptions_IsolationLevel, error) {
1062+
switch level {
1063+
case sql.LevelSerializable:
1064+
return spannerpb.TransactionOptions_SERIALIZABLE, nil
1065+
case sql.LevelRepeatableRead:
1066+
return spannerpb.TransactionOptions_REPEATABLE_READ, nil
1067+
case sql.LevelSnapshot:
1068+
return spannerpb.TransactionOptions_REPEATABLE_READ, nil
1069+
case sql.LevelDefault:
1070+
return spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED, nil
1071+
1072+
// Unsupported and unknown isolation levels.
1073+
case sql.LevelReadUncommitted:
1074+
case sql.LevelReadCommitted:
1075+
case sql.LevelWriteCommitted:
1076+
case sql.LevelLinearizable:
1077+
default:
1078+
}
1079+
return spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "invalid or unsupported isolation level: %v", level))
1080+
}
1081+
10611082
type spannerIsolationLevel sql.IsolationLevel
10621083

10631084
const (

driver_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,61 @@ func TestExtractDnsParts(t *testing.T) {
234234
}
235235
})
236236
}
237+
}
237238

239+
func TestToProtoIsolationLevel(t *testing.T) {
240+
tests := []struct {
241+
input sql.IsolationLevel
242+
want spannerpb.TransactionOptions_IsolationLevel
243+
wantErr bool
244+
}{
245+
{
246+
input: sql.LevelSerializable,
247+
want: spannerpb.TransactionOptions_SERIALIZABLE,
248+
},
249+
{
250+
input: sql.LevelRepeatableRead,
251+
want: spannerpb.TransactionOptions_REPEATABLE_READ,
252+
},
253+
{
254+
input: sql.LevelSnapshot,
255+
want: spannerpb.TransactionOptions_REPEATABLE_READ,
256+
},
257+
{
258+
input: sql.LevelDefault,
259+
want: spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED,
260+
},
261+
{
262+
input: sql.LevelReadUncommitted,
263+
wantErr: true,
264+
},
265+
{
266+
input: sql.LevelReadCommitted,
267+
wantErr: true,
268+
},
269+
{
270+
input: sql.LevelWriteCommitted,
271+
wantErr: true,
272+
},
273+
{
274+
input: sql.LevelLinearizable,
275+
wantErr: true,
276+
},
277+
{
278+
input: sql.IsolationLevel(1000),
279+
wantErr: true,
280+
},
281+
}
282+
for i, test := range tests {
283+
g, err := toProtoIsolationLevel(test.input)
284+
if test.wantErr && err == nil {
285+
t.Errorf("test %d: expected error for input %v, got none", i, test.input)
286+
} else if !test.wantErr && err != nil {
287+
t.Errorf("test %d: unexpected error for input %v: %v", i, test.input, err)
288+
} else if g != test.want {
289+
t.Errorf("test %d:\n Got: %v\nWant: %v", i, g, test.want)
290+
}
291+
}
238292
}
239293

240294
func ExampleCreateConnector() {

0 commit comments

Comments
 (0)