@@ -431,7 +431,9 @@ func RunTransaction(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func
431431 if err != nil {
432432 return err
433433 }
434- defer conn .Close ()
434+ defer func () {
435+ _ = conn .Close ()
436+ }()
435437
436438 // We don't need to keep track of a running checksum for retries when using
437439 // this method, so we disable internal retries.
@@ -452,7 +454,9 @@ func RunTransaction(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func
452454 }
453455 // Reset the flag for internal retries after the transaction (if applicable).
454456 if origRetryAborts {
455- defer func () { _ = spannerConn .SetRetryAbortsInternally (origRetryAborts ) }()
457+ defer func () {
458+ _ = spannerConn .SetRetryAbortsInternally (origRetryAborts )
459+ }()
456460 }
457461
458462 tx , err := conn .BeginTx (ctx , opts )
@@ -461,11 +465,13 @@ func RunTransaction(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func
461465 }
462466 for {
463467 err = f (ctx , tx )
468+ errDuringCommit := false
464469 if err == nil {
465470 err = tx .Commit ()
466471 if err == nil {
467472 return nil
468473 }
474+ errDuringCommit = true
469475 }
470476 // Rollback and return the error if:
471477 // 1. The connection is not a Spanner connection.
@@ -493,12 +499,23 @@ func RunTransaction(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func
493499 }
494500 }
495501
496- // TODO: Reset the existing transaction for retry instead of creating a new one.
497- _ = tx .Rollback ()
498- tx , err = conn .BeginTx (ctx , opts )
502+ // Reset the transaction after it was aborted.
503+ err = spannerConn .resetTransactionForRetry (ctx , errDuringCommit )
499504 if err != nil {
505+ _ = tx .Rollback ()
500506 return err
501507 }
508+ // This does not actually start a new transaction, instead it
509+ // continues with the previous transaction that was already reset.
510+ // We need to do this, because the sql package registers the
511+ // transaction as 'done' when Commit has been called, also if the
512+ // commit fails.
513+ if errDuringCommit {
514+ tx , err = conn .BeginTx (ctx , opts )
515+ if err != nil {
516+ return err
517+ }
518+ }
502519 }
503520}
504521
@@ -596,17 +613,25 @@ type SpannerConn interface {
596613 // this function on different connections to the same database, can
597614 // return the same Spanner client.
598615 UnderlyingClient () (client * spanner.Client , err error )
616+
617+ // resetTransactionForRetry resets the current transaction after it has
618+ // been aborted by Spanner. Calling this function on a transaction that
619+ // has not been aborted is not supported and will cause an error to be
620+ // returned.
621+ resetTransactionForRetry (ctx context.Context , errDuringCommit bool ) error
599622}
600623
601624type conn struct {
602- connector * connector
603- closed bool
604- client * spanner.Client
605- adminClient * adminapi.DatabaseAdminClient
606- tx contextTransaction
607- commitTs * time.Time
608- database string
609- retryAborts bool
625+ connector * connector
626+ closed bool
627+ client * spanner.Client
628+ adminClient * adminapi.DatabaseAdminClient
629+ tx contextTransaction
630+ prevTx contextTransaction
631+ resetForRetry bool
632+ commitTs * time.Time
633+ database string
634+ retryAborts bool
610635
611636 execSingleQuery func (ctx context.Context , c * spanner.Client , statement spanner.Statement , bound spanner.TimestampBound ) * spanner.RowIterator
612637 execSingleDMLTransactional func (ctx context.Context , c * spanner.Client , statement spanner.Statement , transactionOptions spanner.TransactionOptions ) (int64 , time.Time , error )
@@ -1169,11 +1194,32 @@ func (c *conn) Close() error {
11691194 return c .connector .decreaseConnCount ()
11701195}
11711196
1197+ func noTransaction () error {
1198+ return status .Errorf (codes .FailedPrecondition , "connection does not have a transaction" )
1199+ }
1200+
1201+ func (c * conn ) resetTransactionForRetry (ctx context.Context , errDuringCommit bool ) error {
1202+ if errDuringCommit {
1203+ if c .prevTx == nil {
1204+ return noTransaction ()
1205+ }
1206+ c .tx = c .prevTx
1207+ c .resetForRetry = true
1208+ } else if c .tx == nil {
1209+ return noTransaction ()
1210+ }
1211+ return c .tx .resetForRetry (ctx )
1212+ }
1213+
11721214func (c * conn ) Begin () (driver.Tx , error ) {
11731215 return c .BeginTx (context .Background (), driver.TxOptions {})
11741216}
11751217
11761218func (c * conn ) BeginTx (ctx context.Context , opts driver.TxOptions ) (driver.Tx , error ) {
1219+ if c .resetForRetry {
1220+ c .resetForRetry = false
1221+ return c .tx , nil
1222+ }
11771223 if c .inTransaction () {
11781224 return nil , spanner .ToSpannerError (status .Errorf (codes .FailedPrecondition , "already in a transaction" ))
11791225 }
@@ -1202,6 +1248,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
12021248 client : c .client ,
12031249 rwTx : tx ,
12041250 close : func (commitTs * time.Time , commitErr error ) {
1251+ c .prevTx = c .tx
12051252 c .tx = nil
12061253 if commitErr == nil {
12071254 c .commitTs = commitTs
0 commit comments