Skip to content

Commit 463bb09

Browse files
authored
feat: return commit response for r/w transactions (#491)
Add a function for returning the entire commit response from a read/write transaction. Also expose this response in the RunTransactionWithCommitResponse function. Fixes #488
1 parent c38a8e6 commit 463bb09

File tree

7 files changed

+178
-72
lines changed

7 files changed

+178
-72
lines changed

client_side_statement_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,11 @@ func TestShowCommitTimestamp(t *testing.T) {
293293
{&ts},
294294
{nil},
295295
} {
296-
c.commitTs = test.wantValue
296+
if test.wantValue == nil {
297+
c.commitResponse = nil
298+
} else {
299+
c.commitResponse = &spanner.CommitResponse{CommitTs: *test.wantValue}
300+
}
297301

298302
it, err := s.ShowCommitTimestamp(ctx, c, "", ExecOptions{}, nil)
299303
if err != nil {

conn.go

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,10 @@ type SpannerConn interface {
173173
// was executed on the connection, or an error if the connection has not executed a read/write transaction
174174
// that committed successfully. The timestamp is in the local timezone.
175175
CommitTimestamp() (commitTimestamp time.Time, err error)
176+
// CommitResponse returns the commit response of the last implicit or explicit read/write transaction that
177+
// was executed on the connection, or an error if the connection has not executed a read/write transaction
178+
// that committed successfully.
179+
CommitResponse() (commitResponse *spanner.CommitResponse, err error)
176180

177181
// UnderlyingClient returns the underlying Spanner client for the database.
178182
// The client cannot be used to access the current transaction or batch on
@@ -208,23 +212,23 @@ type SpannerConn interface {
208212
var _ SpannerConn = &conn{}
209213

210214
type conn struct {
211-
parser *statementParser
212-
connector *connector
213-
closed bool
214-
client *spanner.Client
215-
adminClient *adminapi.DatabaseAdminClient
216-
connId string
217-
logger *slog.Logger
218-
tx contextTransaction
219-
prevTx contextTransaction
220-
resetForRetry bool
221-
commitTs *time.Time
222-
database string
223-
retryAborts bool
215+
parser *statementParser
216+
connector *connector
217+
closed bool
218+
client *spanner.Client
219+
adminClient *adminapi.DatabaseAdminClient
220+
connId string
221+
logger *slog.Logger
222+
tx contextTransaction
223+
prevTx contextTransaction
224+
resetForRetry bool
225+
commitResponse *spanner.CommitResponse
226+
database string
227+
retryAborts bool
224228

225229
execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, bound spanner.TimestampBound, options ExecOptions) *spanner.RowIterator
226-
execSingleQueryTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (rowIterator, time.Time, error)
227-
execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *statementInfo, options ExecOptions) (*result, time.Time, error)
230+
execSingleQueryTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (rowIterator, *spanner.CommitResponse, error)
231+
execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *statementInfo, options ExecOptions) (*result, *spanner.CommitResponse, error)
228232
execSingleDMLPartitioned func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (int64, error)
229233

230234
// batch is the currently active DDL or DML batch on this connection.
@@ -273,10 +277,17 @@ func (c *conn) UnderlyingClient() (*spanner.Client, error) {
273277
}
274278

275279
func (c *conn) CommitTimestamp() (time.Time, error) {
276-
if c.commitTs == nil {
280+
if c.commitResponse == nil {
277281
return time.Time{}, spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "this connection has not executed a read/write transaction that committed successfully"))
278282
}
279-
return *c.commitTs, nil
283+
return c.commitResponse.CommitTs, nil
284+
}
285+
286+
func (c *conn) CommitResponse() (commitResponse *spanner.CommitResponse, err error) {
287+
if c.commitResponse == nil {
288+
return nil, spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "this connection has not executed a read/write transaction that committed successfully"))
289+
}
290+
return c.commitResponse, nil
280291
}
281292

282293
func (c *conn) RetryAbortsInternally() bool {
@@ -670,7 +681,7 @@ func (c *conn) ResetSession(_ context.Context) error {
670681
return driver.ErrBadConn
671682
}
672683
}
673-
c.commitTs = nil
684+
c.commitResponse = nil
674685
c.batch = nil
675686
c.autoBatchDml = c.connector.connectorConfig.AutoBatchDml
676687
c.autoBatchDmlUpdateCount = c.connector.connectorConfig.AutoBatchDmlUpdateCount
@@ -771,7 +782,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
771782

772783
func (c *conn) queryContext(ctx context.Context, query string, execOptions ExecOptions, args []driver.NamedValue) (driver.Rows, error) {
773784
// Clear the commit timestamp of this connection before we execute the query.
774-
c.commitTs = nil
785+
c.commitResponse = nil
775786
// Check if the execution options contains an instruction to execute
776787
// a specific partition of a PartitionedQuery.
777788
if pq := execOptions.PartitionedQueryOptions.ExecutePartition.PartitionedQuery; pq != nil {
@@ -791,12 +802,12 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions ExecO
791802
if c.tx == nil {
792803
if statementType.statementType == statementTypeDml {
793804
// Use a read/write transaction to execute the statement.
794-
var commitTs time.Time
795-
iter, commitTs, err = c.execSingleQueryTransactional(ctx, c.client, stmt, execOptions)
805+
var commitResponse *spanner.CommitResponse
806+
iter, commitResponse, err = c.execSingleQueryTransactional(ctx, c.client, stmt, execOptions)
796807
if err != nil {
797808
return nil, err
798809
}
799-
c.commitTs = &commitTs
810+
c.commitResponse = commitResponse
800811
} else if execOptions.PartitionedQueryOptions.PartitionQuery {
801812
return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "PartitionQuery is only supported in batch read-only transactions"))
802813
} else if execOptions.PartitionedQueryOptions.AutoPartitionQuery {
@@ -843,7 +854,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
843854

844855
func (c *conn) execContext(ctx context.Context, query string, execOptions ExecOptions, args []driver.NamedValue) (driver.Result, error) {
845856
// Clear the commit timestamp of this connection before we execute the statement.
846-
c.commitTs = nil
857+
c.commitResponse = nil
847858

848859
statementInfo := c.parser.detectStatementType(query)
849860
// Use admin API if DDL statement is provided.
@@ -870,7 +881,7 @@ func (c *conn) execContext(ctx context.Context, query string, execOptions ExecOp
870881
}
871882

872883
var res *result
873-
var commitTs time.Time
884+
var commitResponse *spanner.CommitResponse
874885
if c.tx == nil {
875886
if c.InDMLBatch() {
876887
c.batch.statements = append(c.batch.statements, ss)
@@ -881,9 +892,9 @@ func (c *conn) execContext(ctx context.Context, query string, execOptions ExecOp
881892
dmlMode = execOptions.AutocommitDMLMode
882893
}
883894
if dmlMode == Transactional {
884-
res, commitTs, err = c.execSingleDMLTransactional(ctx, c.client, ss, statementInfo, execOptions)
895+
res, commitResponse, err = c.execSingleDMLTransactional(ctx, c.client, ss, statementInfo, execOptions)
885896
if err == nil {
886-
c.commitTs = &commitTs
897+
c.commitResponse = commitResponse
887898
}
888899
} else if dmlMode == PartitionedNonAtomic {
889900
var rowsAffected int64
@@ -1084,20 +1095,20 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
10841095
conn: c,
10851096
logger: logger,
10861097
rwTx: tx,
1087-
close: func(commitTs *time.Time, commitErr error) {
1098+
close: func(commitResponse *spanner.CommitResponse, commitErr error) {
10881099
if readWriteTransactionOptions.close != nil {
10891100
readWriteTransactionOptions.close()
10901101
}
10911102
c.prevTx = c.tx
10921103
c.tx = nil
10931104
if commitErr == nil {
1094-
c.commitTs = commitTs
1105+
c.commitResponse = commitResponse
10951106
}
10961107
},
10971108
// Disable internal retries if any of these options have been set.
10981109
retryAborts: !readWriteTransactionOptions.DisableInternalRetries && !disableRetryAborts,
10991110
}
1100-
c.commitTs = nil
1111+
c.commitResponse = nil
11011112
return c.tx, nil
11021113
}
11031114

@@ -1153,7 +1164,7 @@ func (c *conn) executeAutoPartitionedQuery(ctx context.Context, query string, ar
11531164
return r, nil
11541165
}
11551166

1156-
func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (rowIterator, time.Time, error) {
1167+
func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (rowIterator, *spanner.CommitResponse, error) {
11571168
var result *wrappedRowIterator
11581169
options.QueryOptions.LastStatement = true
11591170
fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error {
@@ -1177,14 +1188,14 @@ func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement s
11771188
}
11781189
resp, err := c.ReadWriteTransactionWithOptions(ctx, fn, options.TransactionOptions)
11791190
if err != nil {
1180-
return nil, time.Time{}, err
1191+
return nil, nil, err
11811192
}
1182-
return result, resp.CommitTs, nil
1193+
return result, &resp, nil
11831194
}
11841195

11851196
var errInvalidDmlForExecContext = spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "Exec and ExecContext can only be used with INSERT statements with a THEN RETURN clause that return exactly one row with one column of type INT64. Use Query or QueryContext for DML statements other than INSERT and/or with THEN RETURN clauses that return other/more data."))
11861197

1187-
func execInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *statementInfo, options ExecOptions) (*result, time.Time, error) {
1198+
func execInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *statementInfo, options ExecOptions) (*result, *spanner.CommitResponse, error) {
11881199
var res *result
11891200
options.QueryOptions.LastStatement = true
11901201
fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error {
@@ -1197,9 +1208,9 @@ func execInNewRWTransaction(ctx context.Context, c *spanner.Client, statement sp
11971208
}
11981209
resp, err := c.ReadWriteTransactionWithOptions(ctx, fn, options.TransactionOptions)
11991210
if err != nil {
1200-
return &result{}, time.Time{}, err
1211+
return &result{}, nil, err
12011212
}
1202-
return res, resp.CommitTs, nil
1213+
return res, &resp, nil
12031214
}
12041215

12051216
func execTransactionalDML(ctx context.Context, tx spannerTransaction, statement spanner.Statement, statementInfo *statementInfo, options spanner.QueryOptions) (*result, error) {

driver.go

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -851,7 +851,8 @@ func (c *connector) closeClients() (err error) {
851851
//
852852
// This function will never return ErrAbortedDueToConcurrentModification.
853853
func RunTransaction(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func(ctx context.Context, tx *sql.Tx) error) error {
854-
return runTransactionWithOptions(ctx, db, opts, f, spanner.TransactionOptions{})
854+
_, err := runTransactionWithOptions(ctx, db, opts, f, spanner.TransactionOptions{})
855+
return err
855856
}
856857

857858
// RunTransactionWithOptions runs the given function in a transaction on the given database.
@@ -873,18 +874,44 @@ func RunTransaction(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func
873874
//
874875
// This function will never return ErrAbortedDueToConcurrentModification.
875876
func RunTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func(ctx context.Context, tx *sql.Tx) error, spannerOptions spanner.TransactionOptions) error {
877+
_, err := runTransactionWithOptions(ctx, db, opts, f, spannerOptions)
878+
return err
879+
}
880+
881+
// RunTransactionWithCommitResponse runs the given function in a transaction on
882+
// the given database. If the connection is a connection to a Spanner database,
883+
// the transaction will automatically be retried if the transaction is aborted
884+
// by Spanner. Any other errors will be propagated to the caller and the
885+
// transaction will be rolled back. The transaction will be committed if the
886+
// supplied function did not return an error.
887+
//
888+
// If the connection is to a non-Spanner database, no retries will be attempted,
889+
// and any error that occurs during the transaction will be propagated to the
890+
// caller.
891+
//
892+
// The application should *NOT* call tx.Commit() or tx.Rollback(). This is done
893+
// automatically by this function, depending on whether the transaction function
894+
// returned an error or not.
895+
//
896+
// The given spanner.TransactionOptions will be used for the transaction.
897+
//
898+
// This function returns a spanner.CommitResponse if the transaction committed
899+
// successfully.
900+
//
901+
// This function will never return ErrAbortedDueToConcurrentModification.
902+
func RunTransactionWithCommitResponse(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func(ctx context.Context, tx *sql.Tx) error, spannerOptions spanner.TransactionOptions) (*spanner.CommitResponse, error) {
876903
return runTransactionWithOptions(ctx, db, opts, f, spannerOptions)
877904
}
878905

879-
func runTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func(ctx context.Context, tx *sql.Tx) error, spannerOptions spanner.TransactionOptions) error {
906+
func runTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func(ctx context.Context, tx *sql.Tx) error, spannerOptions spanner.TransactionOptions) (*spanner.CommitResponse, error) {
880907
// Get a connection from the pool that we can use to run a transaction.
881908
// Getting a connection here already makes sure that we can reserve this
882909
// connection exclusively for the duration of this method. That again
883910
// allows us to temporarily change the state of the connection (e.g. set
884911
// the retryAborts flag to false).
885912
conn, err := db.Conn(ctx)
886913
if err != nil {
887-
return err
914+
return nil, err
888915
}
889916
defer func() {
890917
_ = conn.Close()
@@ -908,20 +935,24 @@ func runTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOpti
908935
spannerConn.withTempTransactionOptions(transactionOptions)
909936
return nil
910937
}); err != nil {
911-
return err
938+
return nil, err
912939
}
913940

914941
tx, err := conn.BeginTx(ctx, opts)
915942
if err != nil {
916-
return err
943+
return nil, err
917944
}
918945
for {
919946
err = protected(ctx, tx, f)
920947
errDuringCommit := false
921948
if err == nil {
922949
err = tx.Commit()
923950
if err == nil {
924-
return nil
951+
resp, err := getCommitResponse(conn)
952+
if err != nil {
953+
return nil, err
954+
}
955+
return resp, nil
925956
}
926957
errDuringCommit = true
927958
}
@@ -934,7 +965,7 @@ func runTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOpti
934965
// and just returns an ErrTxDone if we do, so this is simpler than
935966
// keeping track of where the error happened.
936967
_ = tx.Rollback()
937-
return err
968+
return nil, err
938969
}
939970

940971
// The transaction was aborted by Spanner.
@@ -947,15 +978,15 @@ func runTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOpti
947978
// anymore. It does not actually roll back the transaction, as it
948979
// has already been aborted by Spanner.
949980
_ = tx.Rollback()
950-
return err
981+
return nil, err
951982
}
952983
}
953984

954985
// Reset the transaction after it was aborted.
955986
err = resetTransactionForRetry(ctx, conn, errDuringCommit)
956987
if err != nil {
957988
_ = tx.Rollback()
958-
return err
989+
return nil, err
959990
}
960991
// This does not actually start a new transaction, instead it
961992
// continues with the previous transaction that was already reset.
@@ -965,12 +996,13 @@ func runTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOpti
965996
if errDuringCommit {
966997
tx, err = conn.BeginTx(ctx, opts)
967998
if err != nil {
968-
return err
999+
return nil, err
9691000
}
9701001
}
9711002
}
9721003

9731004
}
1005+
9741006
func protected(ctx context.Context, tx *sql.Tx, f func(ctx context.Context, tx *sql.Tx) error) (err error) {
9751007
defer func() {
9761008
if x := recover(); x != nil {
@@ -990,6 +1022,20 @@ func resetTransactionForRetry(ctx context.Context, conn *sql.Conn, errDuringComm
9901022
})
9911023
}
9921024

1025+
func getCommitResponse(conn *sql.Conn) (resp *spanner.CommitResponse, err error) {
1026+
if err := conn.Raw(func(driverConn any) error {
1027+
spannerConn, ok := driverConn.(SpannerConn)
1028+
if !ok {
1029+
return spanner.ToSpannerError(status.Error(codes.InvalidArgument, "not a Spanner connection"))
1030+
}
1031+
resp, err = spannerConn.CommitResponse()
1032+
return err
1033+
}); err != nil {
1034+
return nil, err
1035+
}
1036+
return resp, nil
1037+
}
1038+
9931039
type ReadWriteTransactionOptions struct {
9941040
// TransactionOptions are passed through to the Spanner client to use for
9951041
// the read/write transaction.

0 commit comments

Comments
 (0)