Skip to content

Commit 9c99bcb

Browse files
committed
chore: add more tests
1 parent 16dd72e commit 9c99bcb

File tree

5 files changed

+284
-2
lines changed

5 files changed

+284
-2
lines changed

spannerlib/api/connection.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"cloud.google.com/go/spanner"
1313
"cloud.google.com/go/spanner/apiv1/spannerpb"
1414
spannerdriver "github.com/googleapis/go-sql-spanner"
15+
"github.com/googleapis/go-sql-spanner/parser"
1516
"google.golang.org/grpc/codes"
1617
"google.golang.org/grpc/status"
1718
"google.golang.org/protobuf/types/known/timestamppb"
@@ -350,9 +351,9 @@ func determineBatchType(conn *Connection, statements []*spannerpb.ExecuteBatchDm
350351
if err := conn.backend.Conn.Raw(func(driverConn any) error {
351352
spannerConn, _ := driverConn.(spannerdriver.SpannerConn)
352353
firstStatementType := spannerConn.DetectStatementType(statements[0].Sql)
353-
if firstStatementType == spannerdriver.StatementTypeDml {
354+
if firstStatementType == parser.StatementTypeDml {
354355
batchType = spannerdriver.BatchTypeDml
355-
} else if firstStatementType == spannerdriver.StatementTypeDdl {
356+
} else if firstStatementType == parser.StatementTypeDdl {
356357
batchType = spannerdriver.BatchTypeDdl
357358
} else {
358359
return status.Errorf(codes.InvalidArgument, "unsupported statement type for batching: %s", firstStatementType)
File renamed without changes.

spannerlib/shared_lib.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,10 @@ func ResultSetStats(poolId, connId, rowsId int64) (int64, int32, int64, int32, u
155155
// ListValue that contains all the columns of the row. The message is empty if there are
156156
// no more rows in the Rows object.
157157
//
158+
// TODO: Add support for:
159+
// 1. Fetching more than one row at a time.
160+
// 2. Specifying the return type (e.g. proto, struct, ...)
161+
//
158162
//export Next
159163
func Next(poolId, connId, rowsId int64) (int64, int32, int64, int32, unsafe.Pointer) {
160164
msg := lib.Next(poolId, connId, rowsId)

spannerlib/shared_lib_test.go

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"reflect"
6+
"testing"
7+
"unsafe"
8+
9+
"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
10+
"cloud.google.com/go/spanner/apiv1/spannerpb"
11+
"github.com/google/uuid"
12+
"github.com/googleapis/go-sql-spanner/testutil"
13+
"google.golang.org/grpc/codes"
14+
"google.golang.org/protobuf/proto"
15+
"google.golang.org/protobuf/types/known/structpb"
16+
)
17+
18+
func TestCreateAndClosePool(t *testing.T) {
19+
server, teardown := setupMockServer(t, databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL)
20+
defer teardown()
21+
22+
project := "my-project"
23+
instance := "my-instance"
24+
database := "my-database"
25+
dsn := fmt.Sprintf("//%s/projects/%s/instances/%s/databases/%s;usePlainText=true", server.Address, project, instance, database)
26+
pinner, code, poolId, length, ptr := CreatePool(dsn)
27+
verifyEmptyIdMessage(t, "CreatePool", pinner, code, poolId, length, ptr)
28+
29+
pinner, code, _, length, ptr = ClosePool(poolId)
30+
verifyEmptyMessage(t, "ClosePool", pinner, code, length, ptr)
31+
}
32+
33+
func TestCreateAndCloseConnection(t *testing.T) {
34+
server, teardown := setupMockServer(t, databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL)
35+
defer teardown()
36+
37+
project := "my-project"
38+
instance := "my-instance"
39+
database := "my-database"
40+
dsn := fmt.Sprintf("//%s/projects/%s/instances/%s/databases/%s;usePlainText=true", server.Address, project, instance, database)
41+
_, code, poolId, _, _ := CreatePool(dsn)
42+
if g, w := code, int32(codes.OK); g != w {
43+
t.Fatalf("CreatePool returned non-OK code\n Got: %v\nWant: %d", g, w)
44+
}
45+
defer ClosePool(poolId)
46+
47+
pinner, code, connId, length, ptr := CreateConnection(poolId)
48+
verifyEmptyIdMessage(t, "CreateConnection", pinner, code, connId, length, ptr)
49+
50+
pinner, code, _, length, ptr = CloseConnection(poolId, connId)
51+
verifyEmptyMessage(t, "CloseConnection", pinner, code, length, ptr)
52+
}
53+
54+
func TestApply(t *testing.T) {
55+
poolId, connId, server, teardown := setupTestConnection(t, databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL)
56+
defer teardown()
57+
58+
mutations := &spannerpb.BatchWriteRequest_MutationGroup{
59+
Mutations: []*spannerpb.Mutation{
60+
{Operation: &spannerpb.Mutation_Insert{
61+
Insert: &spannerpb.Mutation_Write{
62+
Table: "my_table",
63+
Columns: []string{"col1", "col2", "col3", "col4"},
64+
Values: []*structpb.ListValue{
65+
{Values: []*structpb.Value{
66+
{Kind: &structpb.Value_StringValue{StringValue: "val1"}},
67+
{Kind: &structpb.Value_NumberValue{NumberValue: 3.14}},
68+
{Kind: &structpb.Value_NullValue{NullValue: structpb.NullValue_NULL_VALUE}},
69+
{Kind: &structpb.Value_BoolValue{BoolValue: true}},
70+
}},
71+
{Values: []*structpb.Value{
72+
{Kind: &structpb.Value_StringValue{StringValue: "val2"}},
73+
{Kind: &structpb.Value_NumberValue{NumberValue: 6.626}},
74+
{Kind: &structpb.Value_NullValue{NullValue: structpb.NullValue_NULL_VALUE}},
75+
{Kind: &structpb.Value_BoolValue{BoolValue: false}},
76+
}},
77+
},
78+
},
79+
}},
80+
},
81+
}
82+
mutationBytes, err := proto.Marshal(mutations)
83+
if err != nil {
84+
t.Fatal(err)
85+
}
86+
pinner, code, id, length, ptr := Apply(poolId, connId, mutationBytes)
87+
verifyNonEmptyMessage(t, "Apply", pinner, code, id, length, ptr)
88+
defer Release(pinner)
89+
90+
responseBytes := reflect.SliceAt(reflect.TypeOf(byte(0)), ptr, int(length)).Bytes()
91+
response := &spannerpb.CommitResponse{}
92+
if err := proto.Unmarshal(responseBytes, response); err != nil {
93+
t.Fatal(err)
94+
}
95+
if response.CommitTimestamp == nil {
96+
t.Fatal("CommitTimestamp is nil")
97+
}
98+
99+
requests := server.TestSpanner.DrainRequestsFromServer()
100+
beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{}))
101+
if g, w := len(beginRequests), 1; g != w {
102+
t.Fatalf("num begin requests mismatch\n Got: %v\nWant: %v", g, w)
103+
}
104+
commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{}))
105+
if g, w := len(commitRequests), 1; g != w {
106+
t.Fatalf("num commit requests mismatch\n Got: %v\nWant: %v", g, w)
107+
}
108+
commitRequest := commitRequests[0].(*spannerpb.CommitRequest)
109+
if g, w := len(commitRequest.Mutations), 1; g != w {
110+
t.Fatalf("num mutations mismatch\n Got: %v\nWant: %v", g, w)
111+
}
112+
if g, w := commitRequest.Mutations[0].GetInsert().GetTable(), "my_table"; g != w {
113+
t.Fatalf("insert table mismatch\n Got: %v\nWant: %v", g, w)
114+
}
115+
if g, w := len(commitRequest.Mutations[0].GetInsert().GetValues()), 2; g != w {
116+
t.Fatalf("num rows mismatch\n Got: %v\nWant: %v", g, w)
117+
}
118+
}
119+
120+
func TestExecute(t *testing.T) {
121+
poolId, connId, server, teardown := setupTestConnection(t, databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL)
122+
defer teardown()
123+
124+
query := "select * from my_table"
125+
_ = server.TestSpanner.PutStatementResult(query, &testutil.StatementResult{
126+
Type: testutil.StatementResultResultSet,
127+
ResultSet: generateResultSet(5),
128+
})
129+
130+
request := &spannerpb.ExecuteSqlRequest{
131+
Sql: query,
132+
}
133+
requestBytes, err := proto.Marshal(request)
134+
if err != nil {
135+
t.Fatal(err)
136+
}
137+
pinner, code, rowsId, length, ptr := Execute(poolId, connId, requestBytes)
138+
verifyEmptyIdMessage(t, "Execute", pinner, code, rowsId, length, ptr)
139+
140+
requests := server.TestSpanner.DrainRequestsFromServer()
141+
executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{}))
142+
if g, w := len(executeRequests), 1; g != w {
143+
t.Fatalf("num execute requests mismatch\n Got: %v\nWant: %v", g, w)
144+
}
145+
}
146+
147+
func verifyEmptyIdMessage(t *testing.T, name string, pinner int64, code int32, id int64, length int32, ptr unsafe.Pointer) {
148+
verifyEmptyMessage(t, name, pinner, code, length, ptr)
149+
if id <= int64(0) {
150+
t.Fatalf("%s returned zero or negative id\n Got: %v", name, id)
151+
}
152+
}
153+
154+
func verifyEmptyMessage(t *testing.T, name string, pinner int64, code int32, length int32, ptr unsafe.Pointer) {
155+
if pinner != int64(0) && code != int32(0) {
156+
msg := unsafe.String((*byte)(ptr), length)
157+
t.Fatalf("%s returned error: %v", name, msg)
158+
}
159+
160+
if g, w := pinner, int64(0); g != w {
161+
t.Fatalf("%s returned non-zero pinner\n Got: %v\nWant: %v", name, g, w)
162+
}
163+
if g, w := code, int32(codes.OK); g != w {
164+
t.Fatalf("%s returned non-OK code\n Got: %v\nWant: %d", name, g, w)
165+
}
166+
if g, w := length, int32(0); g != w {
167+
t.Fatalf("%s returned non-empty length\n Got: %v\nWant: %v", name, g, w)
168+
}
169+
if g, w := ptr, unsafe.Pointer(nil); g != w {
170+
t.Fatalf("%s returned non-nil pointer\n Got: %v\nWant: %v", name, g, w)
171+
}
172+
}
173+
174+
func verifyNonEmptyMessage(t *testing.T, name string, pinner int64, code int32, id int64, length int32, ptr unsafe.Pointer) {
175+
if pinner == int64(0) {
176+
t.Fatalf("%s returned zero pinner", name)
177+
}
178+
if g, w := code, int32(codes.OK); g != w {
179+
t.Fatalf("%s returned non-OK code\n Got: %v\nWant: %d", name, g, w)
180+
}
181+
if g, w := id, int64(0); g != w {
182+
t.Fatalf("%s returned non-zero id\n Got: %v\nWant: %v", name, g, w)
183+
}
184+
if length == int32(0) {
185+
t.Fatalf("%s returned empty length", name)
186+
}
187+
if ptr == unsafe.Pointer(nil) {
188+
t.Fatalf("%s returned nil pointer", name)
189+
}
190+
}
191+
192+
func generateResultSet(numRows int) *spannerpb.ResultSet {
193+
res := &spannerpb.ResultSet{
194+
Metadata: &spannerpb.ResultSetMetadata{
195+
RowType: &spannerpb.StructType{
196+
Fields: []*spannerpb.StructType_Field{
197+
{Type: &spannerpb.Type{Code: spannerpb.TypeCode_STRING}, Name: "col1"},
198+
{Type: &spannerpb.Type{Code: spannerpb.TypeCode_STRING}, Name: "col2"},
199+
{Type: &spannerpb.Type{Code: spannerpb.TypeCode_STRING}, Name: "col3"},
200+
{Type: &spannerpb.Type{Code: spannerpb.TypeCode_STRING}, Name: "col4"},
201+
{Type: &spannerpb.Type{Code: spannerpb.TypeCode_STRING}, Name: "col5"},
202+
},
203+
},
204+
},
205+
}
206+
rows := make([]*structpb.ListValue, 0, numRows)
207+
for i := 0; i < numRows; i++ {
208+
values := make([]*structpb.Value, 0, 5)
209+
for j := 0; j < 5; j++ {
210+
values = append(values, &structpb.Value{
211+
Kind: &structpb.Value_StringValue{StringValue: uuid.New().String()},
212+
})
213+
}
214+
rows = append(rows, &structpb.ListValue{
215+
Values: values,
216+
})
217+
}
218+
res.Rows = rows
219+
return res
220+
}
221+
222+
func setupMockServer(t *testing.T, dialect databasepb.DatabaseDialect) (server *testutil.MockedSpannerInMemTestServer, teardown func()) {
223+
server, _, serverTeardown := testutil.NewMockedSpannerInMemTestServer(t)
224+
server.SetupSelectDialectResult(dialect)
225+
226+
return server, serverTeardown
227+
}
228+
229+
func setupTestConnection(t *testing.T, dialect databasepb.DatabaseDialect) (poolId int64, connId int64, server *testutil.MockedSpannerInMemTestServer, teardown func()) {
230+
server, serverTeardown := setupMockServer(t, databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL)
231+
232+
project := "my-project"
233+
instance := "my-instance"
234+
database := "my-database"
235+
dsn := fmt.Sprintf("//%s/projects/%s/instances/%s/databases/%s;usePlainText=true", server.Address, project, instance, database)
236+
_, code, poolId, _, _ := CreatePool(dsn)
237+
if code != int32(codes.OK) {
238+
t.Fatalf("CreatePool returned non-OK code\n Got: %v", code)
239+
}
240+
_, code, connId, _, _ = CreateConnection(poolId)
241+
if code != int32(codes.OK) {
242+
t.Fatalf("CreateConnection returned non-OK code\n Got: %v", code)
243+
}
244+
245+
return poolId, connId, server, func() {
246+
CloseConnection(poolId, connId)
247+
ClosePool(poolId)
248+
serverTeardown()
249+
}
250+
}

testutil/inmem_spanner_server.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"encoding/binary"
2121
"fmt"
2222
"math/rand"
23+
"reflect"
2324
"sort"
2425
"strings"
2526
"sync"
@@ -330,6 +331,8 @@ type InMemSpannerServer interface {
330331
ClearPings()
331332
DumpPings() []string
332333
IsPartitionedDmlTransaction(id []byte) bool
334+
335+
DrainRequestsFromServer() []interface{}
333336
}
334337

335338
type inMemSpannerServer struct {
@@ -1187,6 +1190,20 @@ func (s *inMemSpannerServer) PartitionRead(ctx context.Context, req *spannerpb.P
11871190
})
11881191
}
11891192

1193+
func (s *inMemSpannerServer) DrainRequestsFromServer() []interface{} {
1194+
var reqs []interface{}
1195+
loop:
1196+
for {
1197+
select {
1198+
case req := <-s.ReceivedRequests():
1199+
reqs = append(reqs, req)
1200+
default:
1201+
break loop
1202+
}
1203+
}
1204+
return reqs
1205+
}
1206+
11901207
// EncodeResumeToken return mock resume token encoding for an uint64 integer.
11911208
func EncodeResumeToken(t uint64) []byte {
11921209
rt := make([]byte, 16)
@@ -1202,3 +1219,13 @@ func DecodeResumeToken(t []byte) (uint64, error) {
12021219
}
12031220
return s, nil
12041221
}
1222+
1223+
func RequestsOfType(requests []interface{}, t reflect.Type) []interface{} {
1224+
res := make([]interface{}, 0)
1225+
for _, req := range requests {
1226+
if reflect.TypeOf(req) == t {
1227+
res = append(res, req)
1228+
}
1229+
}
1230+
return res
1231+
}

0 commit comments

Comments
 (0)