@@ -32,6 +32,7 @@ import (
3232 adminapi "cloud.google.com/go/spanner/admin/database/apiv1"
3333 adminpb "cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
3434 "cloud.google.com/go/spanner/apiv1/spannerpb"
35+ "github.com/googleapis/gax-go/v2"
3536 "google.golang.org/api/option"
3637 "google.golang.org/grpc"
3738 "google.golang.org/grpc/codes"
@@ -403,6 +404,103 @@ func (c *connector) closeClients() (err error) {
403404 return err
404405}
405406
407+ // RunTransaction runs the given function in a transaction on the given database.
408+ // If the connection is a connection to a Spanner database, the transaction will
409+ // automatically be retried if the transaction is aborted by Spanner. Any other
410+ // errors will be propagated to the caller and the transaction will be rolled
411+ // back. The transaction will be committed if the supplied function did not
412+ // return an error.
413+ //
414+ // If the connection is to a non-Spanner database, no retries will be attempted,
415+ // and any error that occurs during the transaction will be propagated to the
416+ // caller.
417+ //
418+ // The application should *NOT* call tx.Commit() or tx.Rollback(). This is done
419+ // automatically by this function, depending on whether the transaction function
420+ // returned an error or not.
421+ //
422+ // This function will never return ErrAbortedDueToConcurrentModification.
423+ func RunTransaction (ctx context.Context , db * sql.DB , opts * sql.TxOptions , f func (ctx context.Context , tx * sql.Tx ) error ) error {
424+ // Get a connection from the pool that we can use to run a transaction.
425+ // Getting a connection here already makes sure that we can reserve this
426+ // connection exclusively for the duration of this method. That again
427+ // allows us to temporarily change the state of the connection (e.g. set
428+ // the retryAborts flag to false).
429+ conn , err := db .Conn (ctx )
430+ if err != nil {
431+ return err
432+ }
433+ defer conn .Close ()
434+
435+ // We don't need to keep track of a running checksum for retries when using
436+ // this method, so we disable internal retries.
437+ // Retries will instead be handled by the loop below.
438+ origRetryAborts := false
439+ var spannerConn SpannerConn
440+ if err := conn .Raw (func (driverConn any ) error {
441+ var ok bool
442+ spannerConn , ok = driverConn .(SpannerConn )
443+ if ! ok {
444+ // It is not a Spanner connection, so just ignore and continue without any special handling.
445+ return nil
446+ }
447+ origRetryAborts = spannerConn .RetryAbortsInternally ()
448+ return spannerConn .SetRetryAbortsInternally (false )
449+ }); err != nil {
450+ return err
451+ }
452+ // Reset the flag for internal retries after the transaction (if applicable).
453+ if origRetryAborts {
454+ defer func () { _ = spannerConn .SetRetryAbortsInternally (origRetryAborts ) }()
455+ }
456+
457+ tx , err := conn .BeginTx (ctx , opts )
458+ if err != nil {
459+ return err
460+ }
461+ for {
462+ err = f (ctx , tx )
463+ if err == nil {
464+ err = tx .Commit ()
465+ if err == nil {
466+ return nil
467+ }
468+ }
469+ // Rollback and return the error if:
470+ // 1. The connection is not a Spanner connection.
471+ // 2. Or the error code is not Aborted.
472+ if spannerConn == nil || spanner .ErrCode (err ) != codes .Aborted {
473+ // We don't really need to call Rollback here if the error happened
474+ // during the Commit. However, the SQL package treats this as a no-op
475+ // and just returns an ErrTxDone if we do, so this is simpler than
476+ // keeping track of where the error happened.
477+ _ = tx .Rollback ()
478+ return err
479+ }
480+
481+ // The transaction was aborted by Spanner.
482+ // Back off and retry the entire transaction.
483+ if delay , ok := spanner .ExtractRetryDelay (err ); ok {
484+ err = gax .Sleep (ctx , delay )
485+ if err != nil {
486+ // We need to 'roll back' the transaction here to tell the sql
487+ // package that there is no active transaction on the connection
488+ // anymore. It does not actually roll back the transaction, as it
489+ // has already been aborted by Spanner.
490+ _ = tx .Rollback ()
491+ return err
492+ }
493+ }
494+
495+ // TODO: Reset the existing transaction for retry instead of creating a new one.
496+ _ = tx .Rollback ()
497+ tx , err = conn .BeginTx (ctx , opts )
498+ if err != nil {
499+ return err
500+ }
501+ }
502+ }
503+
406504// SpannerConn is the public interface for the raw Spanner connection for the
407505// sql driver. This interface can be used with the db.Conn().Raw() method.
408506type SpannerConn interface {
0 commit comments