@@ -16,7 +16,7 @@ package spannerdriver
1616
1717import (
1818 "database/sql/driver"
19- "encoding/base64 "
19+ "errors "
2020 "fmt"
2121 "io"
2222 "sync"
@@ -27,11 +27,20 @@ import (
2727 sppb "cloud.google.com/go/spanner/apiv1/spannerpb"
2828 "github.com/google/uuid"
2929 "google.golang.org/api/iterator"
30- "google.golang.org/protobuf/proto"
3130 "google.golang.org/protobuf/types/known/structpb"
3231)
3332
3433var _ driver.RowsColumnTypeDatabaseTypeName = (* rows )(nil )
34+ var _ driver.RowsNextResultSet = (* rows )(nil )
35+
36+ type resultSetType int
37+
38+ const (
39+ resultSetTypeMetadata resultSetType = iota
40+ resultSetTypeResults
41+ resultSetTypeStats
42+ resultSetTypeNoMoreResults
43+ )
3544
3645type rows struct {
3746 it rowIterator
@@ -46,6 +55,31 @@ type rows struct {
4655 decodeToNativeArrays bool
4756
4857 dirtyRow * spanner.Row
58+
59+ currentResultSetType resultSetType
60+ returnResultSetMetadata bool
61+ returnResultSetStats bool
62+
63+ hasReturnedResultSetMetadata bool
64+ hasReturnedResultSetStats bool
65+ }
66+
67+ func (r * rows ) HasNextResultSet () bool {
68+ if r .currentResultSetType == resultSetTypeMetadata && r .returnResultSetMetadata {
69+ return true
70+ }
71+ if r .currentResultSetType == resultSetTypeResults && r .returnResultSetStats {
72+ return true
73+ }
74+ return false
75+ }
76+
77+ func (r * rows ) NextResultSet () error {
78+ if ! r .HasNextResultSet () {
79+ return io .EOF
80+ }
81+ r .currentResultSetType ++
82+ return nil
4983}
5084
5185// Columns returns the names of the columns. The number of
@@ -54,12 +88,32 @@ type rows struct {
5488// string should be returned for that entry.
5589func (r * rows ) Columns () []string {
5690 r .getColumns ()
57- return r .cols
91+ switch r .currentResultSetType {
92+ case resultSetTypeMetadata :
93+ return []string {"metadata" }
94+ case resultSetTypeResults :
95+ return r .cols
96+ case resultSetTypeStats :
97+ return []string {"stats" }
98+ case resultSetTypeNoMoreResults :
99+ return nil
100+ }
101+ return nil
58102}
59103
60104func (r * rows ) ColumnTypeDatabaseTypeName (index int ) string {
61105 r .getColumns ()
62- return r .colTypeNames [index ]
106+ switch r .currentResultSetType {
107+ case resultSetTypeMetadata :
108+ return "ResultSetMetadata"
109+ case resultSetTypeResults :
110+ return r .colTypeNames [index ]
111+ case resultSetTypeStats :
112+ return "ResultSetStats"
113+ case resultSetTypeNoMoreResults :
114+ return ""
115+ }
116+ return ""
63117}
64118
65119// Close closes the rows iterator.
@@ -75,6 +129,9 @@ func (r *rows) Close() error {
75129
76130func (r * rows ) getColumns () {
77131 r .colsOnce .Do (func () {
132+ if r .currentResultSetType == resultSetTypeMetadata && ! r .returnResultSetMetadata {
133+ r .currentResultSetType = resultSetTypeResults
134+ }
78135 row , err := r .it .Next ()
79136 if err == nil {
80137 r .dirtyRow = row
@@ -92,22 +149,10 @@ func (r *rows) getColumns() {
92149 rowType := metadata .RowType
93150 r .cols = make ([]string , len (rowType .Fields ))
94151 r .colTypeNames = make ([]string , len (rowType .Fields ))
95- if r .decodeOption == DecodeOptionProto {
96- if len (rowType .Fields ) == 0 {
97- r .cols = make ([]string , 1 )
98- r .colTypeNames = make ([]string , 1 )
99- }
100- metadataBytes , err := proto .Marshal (metadata )
101- if err == nil {
102- r .colTypeNames [0 ] = base64 .StdEncoding .EncodeToString (metadataBytes )
103- }
104- r .cols [0 ] = fmt .Sprintf ("%v" , r .it .RowCount ())
105- } else {
106- for i , c := range rowType .Fields {
107- r .cols [i ] = c .Name
108- if r .decodeOption != DecodeOptionProto {
109- r .colTypeNames [i ] = c .Type .Code .String ()
110- }
152+ for i , c := range rowType .Fields {
153+ r .cols [i ] = c .Name
154+ if r .decodeOption != DecodeOptionProto {
155+ r .colTypeNames [i ] = c .Type .Code .String ()
111156 }
112157 }
113158 })
@@ -124,6 +169,30 @@ func (r *rows) getColumns() {
124169// a buffer held in dest.
125170func (r * rows ) Next (dest []driver.Value ) error {
126171 r .getColumns ()
172+ if r .currentResultSetType == resultSetTypeMetadata {
173+ if r .dirtyErr != nil && ! errors .Is (r .dirtyErr , iterator .Done ) {
174+ return r .dirtyErr
175+ }
176+ if r .hasReturnedResultSetMetadata {
177+ return io .EOF
178+ }
179+ r .hasReturnedResultSetMetadata = true
180+ metadata , err := r .it .Metadata ()
181+ if err != nil {
182+ return err
183+ }
184+ dest [0 ] = metadata
185+ return nil
186+ }
187+ if r .currentResultSetType == resultSetTypeStats {
188+ if r .hasReturnedResultSetStats {
189+ return io .EOF
190+ }
191+ r .hasReturnedResultSetStats = true
192+ dest [0 ] = r .it .ResultSetStats ()
193+ return nil
194+ }
195+
127196 var row * spanner.Row
128197 if r .dirtyErr != nil {
129198 err := r .dirtyErr
@@ -132,7 +201,8 @@ func (r *rows) Next(dest []driver.Value) error {
132201 return io .EOF
133202 }
134203 return err
135- } else if r .dirtyRow != nil {
204+ }
205+ if r .dirtyRow != nil {
136206 row = r .dirtyRow
137207 r .dirtyRow = nil
138208 } else {
0 commit comments