@@ -44,7 +44,7 @@ type contextTransaction interface {
4444 Commit () error
4545 Rollback () error
4646 resetForRetry (ctx context.Context ) error
47- Query (ctx context.Context , stmt spanner.Statement , execOptions * ExecOptions ) (rowIterator , error )
47+ Query (ctx context.Context , stmt spanner.Statement , stmtType parser. StatementType , execOptions * ExecOptions ) (rowIterator , error )
4848 partitionQuery (ctx context.Context , stmt spanner.Statement , execOptions * ExecOptions ) (driver.Rows , error )
4949 ExecContext (ctx context.Context , stmt spanner.Statement , statementInfo * parser.StatementInfo , options spanner.QueryOptions ) (* result , error )
5050
@@ -67,6 +67,7 @@ var _ rowIterator = &readOnlyRowIterator{}
6767
6868type readOnlyRowIterator struct {
6969 * spanner.RowIterator
70+ stmtType parser.StatementType
7071}
7172
7273func (ri * readOnlyRowIterator ) Next () (* spanner.Row , error ) {
@@ -84,10 +85,13 @@ func (ri *readOnlyRowIterator) Metadata() (*sppb.ResultSetMetadata, error) {
8485func (ri * readOnlyRowIterator ) ResultSetStats () * sppb.ResultSetStats {
8586 // TODO: The Spanner client library should offer an option to get the full
8687 // ResultSetStats, instead of only the RowCount and QueryPlan.
87- return & sppb.ResultSetStats {
88- RowCount : & sppb.ResultSetStats_RowCountExact {RowCountExact : ri .RowIterator .RowCount },
88+ stats := & sppb.ResultSetStats {
8989 QueryPlan : ri .RowIterator .QueryPlan ,
9090 }
91+ if ri .stmtType == parser .StatementTypeDml {
92+ stats .RowCount = & sppb.ResultSetStats_RowCountExact {RowCountExact : ri .RowIterator .RowCount }
93+ }
94+ return stats
9195}
9296
9397type txResult int
@@ -135,7 +139,7 @@ func (tx *readOnlyTransaction) resetForRetry(ctx context.Context) error {
135139 return nil
136140}
137141
138- func (tx * readOnlyTransaction ) Query (ctx context.Context , stmt spanner.Statement , execOptions * ExecOptions ) (rowIterator , error ) {
142+ func (tx * readOnlyTransaction ) Query (ctx context.Context , stmt spanner.Statement , stmtType parser. StatementType , execOptions * ExecOptions ) (rowIterator , error ) {
139143 tx .logger .DebugContext (ctx , "Query" , "stmt" , stmt .SQL )
140144 if execOptions .PartitionedQueryOptions .AutoPartitionQuery {
141145 if tx .boTx == nil {
@@ -152,7 +156,7 @@ func (tx *readOnlyTransaction) Query(ctx context.Context, stmt spanner.Statement
152156 }
153157 return mi , nil
154158 }
155- return & readOnlyRowIterator {tx .roTx .QueryWithOptions (ctx , stmt , execOptions .QueryOptions )}, nil
159+ return & readOnlyRowIterator {tx .roTx .QueryWithOptions (ctx , stmt , execOptions .QueryOptions ), stmtType }, nil
156160}
157161
158162func (tx * readOnlyTransaction ) partitionQuery (ctx context.Context , stmt spanner.Statement , execOptions * ExecOptions ) (driver.Rows , error ) {
@@ -456,7 +460,7 @@ func (tx *readWriteTransaction) resetForRetry(ctx context.Context) error {
456460// Query executes a query using the read/write transaction and returns a
457461// rowIterator that will automatically retry the read/write transaction if the
458462// transaction is aborted during the query or while iterating the returned rows.
459- func (tx * readWriteTransaction ) Query (ctx context.Context , stmt spanner.Statement , execOptions * ExecOptions ) (rowIterator , error ) {
463+ func (tx * readWriteTransaction ) Query (ctx context.Context , stmt spanner.Statement , stmtType parser. StatementType , execOptions * ExecOptions ) (rowIterator , error ) {
460464 tx .logger .Debug ("Query" , "stmt" , stmt .SQL )
461465 tx .active = true
462466 if err := tx .maybeRunAutoDmlBatch (ctx ); err != nil {
@@ -465,7 +469,7 @@ func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statemen
465469 // If internal retries have been disabled, we don't need to keep track of a
466470 // running checksum for all results that we have seen.
467471 if ! tx .retryAborts () {
468- return & readOnlyRowIterator {tx .rwTx .QueryWithOptions (ctx , stmt , execOptions .QueryOptions )}, nil
472+ return & readOnlyRowIterator {tx .rwTx .QueryWithOptions (ctx , stmt , execOptions .QueryOptions ), stmtType }, nil
469473 }
470474
471475 // If retries are enabled, we need to use a row iterator that will keep
@@ -476,6 +480,7 @@ func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statemen
476480 ctx : ctx ,
477481 tx : tx ,
478482 stmt : stmt ,
483+ stmtType : stmtType ,
479484 options : execOptions .QueryOptions ,
480485 buffer : buffer ,
481486 enc : gob .NewEncoder (buffer ),
0 commit comments