Skip to content

Commit 251d4ce

Browse files
committed
feat: move metadata and stats to separate result sets
1 parent f898b49 commit 251d4ce

File tree

13 files changed

+199
-71
lines changed

13 files changed

+199
-71
lines changed

checksum_row_iterator.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,9 @@ func (it *checksumRowIterator) Metadata() (*sppb.ResultSetMetadata, error) {
252252
return it.metadata, nil
253253
}
254254

255-
func (it *checksumRowIterator) RowCount() int64 {
256-
return it.RowIterator.RowCount
255+
func (it *checksumRowIterator) ResultSetStats() *sppb.ResultSetStats {
256+
return &sppb.ResultSetStats{
257+
RowCount: &sppb.ResultSetStats_RowCountExact{RowCountExact: it.RowIterator.RowCount},
258+
QueryPlan: it.RowIterator.QueryPlan,
259+
}
257260
}

client_side_statement.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,6 @@ func (t *clientSideIterator) Metadata() (*spannerpb.ResultSetMetadata, error) {
445445
return t.metadata, nil
446446
}
447447

448-
func (t *clientSideIterator) RowCount() int64 {
449-
return 0
448+
func (t *clientSideIterator) ResultSetStats() *spannerpb.ResultSetStats {
449+
return &spannerpb.ResultSetStats{}
450450
}

conn.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -815,8 +815,15 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions ExecO
815815
return nil, err
816816
}
817817
}
818-
res := &rows{it: iter, decodeOption: execOptions.DecodeOption, decodeToNativeArrays: execOptions.DecodeToNativeArrays}
819-
if execOptions.DecodeOption == DecodeOptionProto {
818+
res := &rows{
819+
it: iter,
820+
decodeOption: execOptions.DecodeOption,
821+
decodeToNativeArrays: execOptions.DecodeToNativeArrays,
822+
returnResultSetMetadata: execOptions.ReturnResultSetMetadata,
823+
returnResultSetStats: execOptions.ReturnResultSetStats,
824+
}
825+
if execOptions.DirectExecute {
826+
// This forces the execution of the statement.
820827
res.getColumns()
821828
if res.dirtyErr != nil && !errors.Is(res.dirtyErr, iterator.Done) {
822829
_ = res.Close()

driver.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,10 @@ type ExecOptions struct {
163163
// AutoCommitDMLMode determines the type of transaction that DML statements
164164
// that are executed outside explicit transactions use.
165165
AutocommitDMLMode AutocommitDMLMode
166+
167+
ReturnResultSetMetadata bool
168+
DirectExecute bool
169+
ReturnResultSetStats bool
166170
}
167171

168172
type DecodeOption int

merged_row_iterator.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,6 @@ func (m *mergedRowIterator) Metadata() (*sppb.ResultSetMetadata, error) {
264264
return m.metadata, nil
265265
}
266266

267-
func (m *mergedRowIterator) RowCount() int64 {
268-
return 0
267+
func (m *mergedRowIterator) ResultSetStats() *sppb.ResultSetStats {
268+
return &sppb.ResultSetStats{}
269269
}

rows.go

Lines changed: 91 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ package spannerdriver
1616

1717
import (
1818
"database/sql/driver"
19-
"encoding/base64"
19+
"errors"
2020
"fmt"
2121
"io"
2222
"sync"
@@ -27,11 +27,20 @@ import (
2727
sppb "cloud.google.com/go/spanner/apiv1/spannerpb"
2828
"github.com/google/uuid"
2929
"google.golang.org/api/iterator"
30-
"google.golang.org/protobuf/proto"
3130
"google.golang.org/protobuf/types/known/structpb"
3231
)
3332

3433
var _ driver.RowsColumnTypeDatabaseTypeName = (*rows)(nil)
34+
var _ driver.RowsNextResultSet = (*rows)(nil)
35+
36+
type resultSetType int
37+
38+
const (
39+
resultSetTypeMetadata resultSetType = iota
40+
resultSetTypeResults
41+
resultSetTypeStats
42+
resultSetTypeNoMoreResults
43+
)
3544

3645
type rows struct {
3746
it rowIterator
@@ -46,6 +55,31 @@ type rows struct {
4655
decodeToNativeArrays bool
4756

4857
dirtyRow *spanner.Row
58+
59+
currentResultSetType resultSetType
60+
returnResultSetMetadata bool
61+
returnResultSetStats bool
62+
63+
hasReturnedResultSetMetadata bool
64+
hasReturnedResultSetStats bool
65+
}
66+
67+
func (r *rows) HasNextResultSet() bool {
68+
if r.currentResultSetType == resultSetTypeMetadata && r.returnResultSetMetadata {
69+
return true
70+
}
71+
if r.currentResultSetType == resultSetTypeResults && r.returnResultSetStats {
72+
return true
73+
}
74+
return false
75+
}
76+
77+
func (r *rows) NextResultSet() error {
78+
if !r.HasNextResultSet() {
79+
return io.EOF
80+
}
81+
r.currentResultSetType++
82+
return nil
4983
}
5084

5185
// Columns returns the names of the columns. The number of
@@ -54,12 +88,32 @@ type rows struct {
5488
// string should be returned for that entry.
5589
func (r *rows) Columns() []string {
5690
r.getColumns()
57-
return r.cols
91+
switch r.currentResultSetType {
92+
case resultSetTypeMetadata:
93+
return []string{"metadata"}
94+
case resultSetTypeResults:
95+
return r.cols
96+
case resultSetTypeStats:
97+
return []string{"stats"}
98+
case resultSetTypeNoMoreResults:
99+
return nil
100+
}
101+
return nil
58102
}
59103

60104
func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
61105
r.getColumns()
62-
return r.colTypeNames[index]
106+
switch r.currentResultSetType {
107+
case resultSetTypeMetadata:
108+
return "ResultSetMetadata"
109+
case resultSetTypeResults:
110+
return r.colTypeNames[index]
111+
case resultSetTypeStats:
112+
return "ResultSetStats"
113+
case resultSetTypeNoMoreResults:
114+
return ""
115+
}
116+
return ""
63117
}
64118

65119
// Close closes the rows iterator.
@@ -75,6 +129,9 @@ func (r *rows) Close() error {
75129

76130
func (r *rows) getColumns() {
77131
r.colsOnce.Do(func() {
132+
if r.currentResultSetType == resultSetTypeMetadata && !r.returnResultSetMetadata {
133+
r.currentResultSetType = resultSetTypeResults
134+
}
78135
row, err := r.it.Next()
79136
if err == nil {
80137
r.dirtyRow = row
@@ -92,22 +149,10 @@ func (r *rows) getColumns() {
92149
rowType := metadata.RowType
93150
r.cols = make([]string, len(rowType.Fields))
94151
r.colTypeNames = make([]string, len(rowType.Fields))
95-
if r.decodeOption == DecodeOptionProto {
96-
if len(rowType.Fields) == 0 {
97-
r.cols = make([]string, 1)
98-
r.colTypeNames = make([]string, 1)
99-
}
100-
metadataBytes, err := proto.Marshal(metadata)
101-
if err == nil {
102-
r.colTypeNames[0] = base64.StdEncoding.EncodeToString(metadataBytes)
103-
}
104-
r.cols[0] = fmt.Sprintf("%v", r.it.RowCount())
105-
} else {
106-
for i, c := range rowType.Fields {
107-
r.cols[i] = c.Name
108-
if r.decodeOption != DecodeOptionProto {
109-
r.colTypeNames[i] = c.Type.Code.String()
110-
}
152+
for i, c := range rowType.Fields {
153+
r.cols[i] = c.Name
154+
if r.decodeOption != DecodeOptionProto {
155+
r.colTypeNames[i] = c.Type.Code.String()
111156
}
112157
}
113158
})
@@ -124,6 +169,30 @@ func (r *rows) getColumns() {
124169
// a buffer held in dest.
125170
func (r *rows) Next(dest []driver.Value) error {
126171
r.getColumns()
172+
if r.currentResultSetType == resultSetTypeMetadata {
173+
if r.dirtyErr != nil && !errors.Is(r.dirtyErr, iterator.Done) {
174+
return r.dirtyErr
175+
}
176+
if r.hasReturnedResultSetMetadata {
177+
return io.EOF
178+
}
179+
r.hasReturnedResultSetMetadata = true
180+
metadata, err := r.it.Metadata()
181+
if err != nil {
182+
return err
183+
}
184+
dest[0] = metadata
185+
return nil
186+
}
187+
if r.currentResultSetType == resultSetTypeStats {
188+
if r.hasReturnedResultSetStats {
189+
return io.EOF
190+
}
191+
r.hasReturnedResultSetStats = true
192+
dest[0] = r.it.ResultSetStats()
193+
return nil
194+
}
195+
127196
var row *spanner.Row
128197
if r.dirtyErr != nil {
129198
err := r.dirtyErr
@@ -132,7 +201,8 @@ func (r *rows) Next(dest []driver.Value) error {
132201
return io.EOF
133202
}
134203
return err
135-
} else if r.dirtyRow != nil {
204+
}
205+
if r.dirtyRow != nil {
136206
row = r.dirtyRow
137207
r.dirtyRow = nil
138208
} else {

rows_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ func (t *testIterator) Metadata() (*sppb.ResultSetMetadata, error) {
4949
return t.metadata, nil
5050
}
5151

52-
func (t *testIterator) RowCount() int64 {
53-
return 0
52+
func (t *testIterator) ResultSetStats() *sppb.ResultSetStats {
53+
return &sppb.ResultSetStats{}
5454
}
5555

5656
func newRow(t *testing.T, cols []string, vals []interface{}) *spanner.Row {

spannerlib/exported/connection.go

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import "C"
44
import (
55
"context"
66
"database/sql"
7+
"fmt"
78
"sync"
89
"sync/atomic"
910

@@ -99,11 +100,11 @@ func (conn *Connection) BeginTransaction(txOpts *spannerpb.TransactionOptions) *
99100
if txOpts.GetReadOnly() != nil {
100101
tx, err = spannerdriver.BeginReadOnlyTransactionOnConn(
101102
context.Background(), conn.backend.Conn, convertToReadOnlyOpts(txOpts))
102-
} else if txOpts.GetReadWrite() != nil {
103+
} else if txOpts.GetPartitionedDml() != nil {
104+
err = spanner.ToSpannerError(status.Error(codes.InvalidArgument, "transaction type not supported"))
105+
} else {
103106
tx, err = spannerdriver.BeginReadWriteTransactionOnConn(
104107
context.Background(), conn.backend.Conn, convertToReadWriteTransactionOptions(txOpts))
105-
} else {
106-
err = spanner.ToSpannerError(status.Error(codes.InvalidArgument, "transaction type not supported"))
107108
}
108109
if err != nil {
109110
return errMessage(err)
@@ -140,10 +141,14 @@ func convertTimestampBound(txOpts *spannerpb.TransactionOptions) spanner.Timesta
140141
}
141142

142143
func convertToReadWriteTransactionOptions(txOpts *spannerpb.TransactionOptions) spannerdriver.ReadWriteTransactionOptions {
144+
readLockMode := spannerpb.TransactionOptions_ReadWrite_READ_LOCK_MODE_UNSPECIFIED
145+
if txOpts.GetReadWrite() != nil {
146+
readLockMode = txOpts.GetReadWrite().GetReadLockMode()
147+
}
143148
return spannerdriver.ReadWriteTransactionOptions{
144149
TransactionOptions: spanner.TransactionOptions{
145150
IsolationLevel: txOpts.GetIsolationLevel(),
146-
ReadLockMode: txOpts.GetReadWrite().GetReadLockMode(),
151+
ReadLockMode: readLockMode,
147152
},
148153
}
149154
}
@@ -172,9 +177,26 @@ func execute(conn *Connection, executor queryExecutor, statement *spannerpb.Exec
172177
if err != nil {
173178
return errMessage(err)
174179
}
180+
// The first result set should contain the metadata.
181+
if !it.Next() {
182+
return errMessage(fmt.Errorf("query returned no metadata"))
183+
}
184+
metadata := &spannerpb.ResultSetMetadata{}
185+
if err := it.Scan(&metadata); err != nil {
186+
return errMessage(err)
187+
}
188+
// Move to the next result set, which contains the normal data.
189+
if !it.NextResultSet() {
190+
return errMessage(fmt.Errorf("no results found after metadata"))
191+
}
175192
id := conn.resultsIdx.Add(1)
176193
res := &rows{
177-
backend: it,
194+
backend: it,
195+
metadata: metadata,
196+
}
197+
if len(metadata.RowType.Fields) == 0 {
198+
// No rows returned. Read the stats now.
199+
res.readStats()
178200
}
179201
conn.results.Store(id, res)
180202
return idMessage(id)
@@ -224,7 +246,12 @@ func extractParams(statement *spannerpb.ExecuteBatchDmlRequest_Statement) []any
224246
paramsLen = 1 + len(statement.Params.Fields)
225247
}
226248
params := make([]any, paramsLen)
227-
params = append(params, spannerdriver.ExecOptions{DecodeOption: spannerdriver.DecodeOptionProto})
249+
params = append(params, spannerdriver.ExecOptions{
250+
DecodeOption: spannerdriver.DecodeOptionProto,
251+
ReturnResultSetMetadata: true,
252+
ReturnResultSetStats: true,
253+
DirectExecute: true,
254+
})
228255
if statement.Params != nil {
229256
if statement.ParamTypes == nil {
230257
statement.ParamTypes = make(map[string]*spannerpb.Type)

spannerlib/exported/pool_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ func TestExecuteDml(t *testing.T) {
6868
if len(metadataValue.RowType.Fields) > 0 {
6969
fmt.Printf("Row type: %v\n", metadataValue.RowType)
7070
} else {
71-
rowCount := UpdateCount(pool.ObjectId, conn.ObjectId, rows.ObjectId)
71+
rowCount := ResultSetStats(pool.ObjectId, conn.ObjectId, rows.ObjectId)
7272
fmt.Printf("Update count: %v\n", string(rowCount.Res))
7373
}
7474
for {

0 commit comments

Comments
 (0)