Skip to content

Commit 5da0cbe

Browse files
committed
chore: use a callback to supply tx opts
Use a callback to supply transaction options, so changes to the connection variables at the start of a transaction (before it has actually been activated) are also included in the transaction. This is necessary to support SET LOCAL statements that have an impact on the actual transaction, such as the following example script: ``` BEGIN TRANSACTION; SET LOCAL ISOLATION_LEVEL='repeatable_read'; UPDATE my_table SET my_col=1 WHERE id=1; COMMIT; ``` This change depends on googleapis/google-cloud-go#12779
1 parent 7f78b3b commit 5da0cbe

File tree

5 files changed

+3113
-20
lines changed

5 files changed

+3113
-20
lines changed

conn.go

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -926,7 +926,10 @@ func (c *conn) execContext(ctx context.Context, query string, execOptions ExecOp
926926
func (c *conn) options(reset bool) ExecOptions {
927927
if reset {
928928
defer func() {
929-
c.execOptions.TransactionOptions.TransactionTag = ""
929+
// Only reset the transaction tag if there is no active transaction on the connection.
930+
if !c.inTransaction() {
931+
c.execOptions.TransactionOptions.TransactionTag = ""
932+
}
930933
c.execOptions.QueryOptions.RequestTag = ""
931934
}()
932935
}
@@ -958,7 +961,7 @@ func (c *conn) withTempTransactionOptions(options *ReadWriteTransactionOptions)
958961
c.tempTransactionOptions = options
959962
}
960963

961-
func (c *conn) getTransactionOptions() ReadWriteTransactionOptions {
964+
func (c *conn) getTransactionOptions(execOptions ExecOptions) ReadWriteTransactionOptions {
962965
if c.tempTransactionOptions != nil {
963966
defer func() { c.tempTransactionOptions = nil }()
964967
opts := *c.tempTransactionOptions
@@ -971,7 +974,7 @@ func (c *conn) getTransactionOptions() ReadWriteTransactionOptions {
971974
c.execOptions.TransactionOptions.TransactionTag = ""
972975
}()
973976
txOpts := ReadWriteTransactionOptions{
974-
TransactionOptions: c.execOptions.TransactionOptions,
977+
TransactionOptions: execOptions.TransactionOptions,
975978
DisableInternalRetries: !c.RetryAbortsInternally(),
976979
}
977980
// Only use the default isolation level from the connection if the ExecOptions
@@ -1019,7 +1022,7 @@ func (c *conn) Begin() (driver.Tx, error) {
10191022
return c.BeginTx(context.Background(), driver.TxOptions{})
10201023
}
10211024

1022-
func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
1025+
func (c *conn) BeginTx(ctx context.Context, driverOpts driver.TxOptions) (driver.Tx, error) {
10231026
if c.resetForRetry {
10241027
c.resetForRetry = false
10251028
return c.tx, nil
@@ -1033,26 +1036,27 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
10331036

10341037
readOnlyTxOpts := c.getReadOnlyTransactionOptions()
10351038
batchReadOnlyTxOpts := c.getBatchReadOnlyTransactionOptions()
1036-
readWriteTransactionOptions := c.getTransactionOptions()
1039+
execOptions := c.execOptions
10371040
if c.inTransaction() {
10381041
return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "already in a transaction"))
10391042
}
10401043
if c.inBatch() {
10411044
return nil, status.Error(codes.FailedPrecondition, "This connection has an active batch. Run or abort the batch before starting a new transaction.")
10421045
}
10431046

1047+
isolationLevelFromTxOpts := spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED
10441048
// Determine whether internal retries have been disabled using a special
10451049
// value for the transaction isolation level.
10461050
disableRetryAborts := false
10471051
batchReadOnly := false
1048-
sil := opts.Isolation >> 8
1049-
opts.Isolation = opts.Isolation - sil<<8
1050-
if opts.Isolation != driver.IsolationLevel(sql.LevelDefault) {
1051-
level, err := toProtoIsolationLevel(sql.IsolationLevel(opts.Isolation))
1052+
sil := driverOpts.Isolation >> 8
1053+
driverOpts.Isolation = driverOpts.Isolation - sil<<8
1054+
if driverOpts.Isolation != driver.IsolationLevel(sql.LevelDefault) {
1055+
level, err := toProtoIsolationLevel(sql.IsolationLevel(driverOpts.Isolation))
10521056
if err != nil {
10531057
return nil, err
10541058
}
1055-
readWriteTransactionOptions.TransactionOptions.IsolationLevel = level
1059+
isolationLevelFromTxOpts = level
10561060
}
10571061
if sil > 0 {
10581062
switch spannerIsolationLevel(sil) {
@@ -1064,11 +1068,11 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
10641068
// ignore
10651069
}
10661070
}
1067-
if batchReadOnly && !opts.ReadOnly {
1071+
if batchReadOnly && !driverOpts.ReadOnly {
10681072
return nil, status.Error(codes.InvalidArgument, "levelBatchReadOnly can only be used for read-only transactions")
10691073
}
10701074

1071-
if opts.ReadOnly {
1075+
if driverOpts.ReadOnly {
10721076
var logger *slog.Logger
10731077
var ro *spanner.ReadOnlyTransaction
10741078
var bo *spanner.BatchReadOnlyTransaction
@@ -1106,7 +1110,23 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
11061110
return c.tx, nil
11071111
}
11081112

1109-
tx, err := spanner.NewReadWriteStmtBasedTransactionWithOptions(ctx, c.client, readWriteTransactionOptions.TransactionOptions)
1113+
opts := spanner.TransactionOptions{}
1114+
if c.tempTransactionOptions != nil {
1115+
opts = c.tempTransactionOptions.TransactionOptions
1116+
}
1117+
opts.BeginTransactionOption = c.convertDefaultBeginTransactionOption(opts.BeginTransactionOption)
1118+
tempCloseFunc := func() {}
1119+
if c.tempTransactionOptions != nil && c.tempTransactionOptions.close != nil {
1120+
tempCloseFunc = c.tempTransactionOptions.close
1121+
}
1122+
disableInternalRetries := !c.RetryAbortsInternally()
1123+
if c.tempTransactionOptions != nil {
1124+
disableInternalRetries = c.tempTransactionOptions.DisableInternalRetries
1125+
}
1126+
1127+
tx, err := spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(ctx, c.client, opts, func() spanner.TransactionOptions {
1128+
return c.effectiveTransactionOptions(isolationLevelFromTxOpts, execOptions)
1129+
})
11101130
if err != nil {
11111131
return nil, err
11121132
}
@@ -1117,9 +1137,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
11171137
logger: logger,
11181138
rwTx: tx,
11191139
close: func(result txResult, commitResponse *spanner.CommitResponse, commitErr error) {
1120-
if readWriteTransactionOptions.close != nil {
1121-
readWriteTransactionOptions.close()
1122-
}
1140+
tempCloseFunc()
11231141
c.prevTx = c.tx
11241142
c.tx = nil
11251143
if commitErr == nil {
@@ -1134,12 +1152,21 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
11341152
}
11351153
},
11361154
// Disable internal retries if any of these options have been set.
1137-
retryAborts: !readWriteTransactionOptions.DisableInternalRetries && !disableRetryAborts,
1155+
retryAborts: !disableInternalRetries && !disableRetryAborts,
11381156
}
11391157
c.commitResponse = nil
11401158
return c.tx, nil
11411159
}
11421160

1161+
func (c *conn) effectiveTransactionOptions(isolationLevelFromTxOpts spannerpb.TransactionOptions_IsolationLevel, execOptions ExecOptions) spanner.TransactionOptions {
1162+
readWriteTransactionOptions := c.getTransactionOptions(execOptions)
1163+
res := readWriteTransactionOptions.TransactionOptions
1164+
if isolationLevelFromTxOpts != spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED {
1165+
res.IsolationLevel = isolationLevelFromTxOpts
1166+
}
1167+
return res
1168+
}
1169+
11431170
func (c *conn) convertDefaultBeginTransactionOption(opt spanner.BeginTransactionOption) spanner.BeginTransactionOption {
11441171
if opt == spanner.DefaultBeginTransaction {
11451172
if propertyBeginTransactionOption.GetValueOrDefault(c.state) == spanner.DefaultBeginTransaction {

driver_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ func TestExtractDnsParts(t *testing.T) {
284284
if err != nil {
285285
t.Fatalf("failed to get connector for %q: %v", tc.input, err)
286286
}
287-
if diff := cmp.Diff(conn.spannerClientConfig, tc.wantSpannerConfig, cmpopts.IgnoreUnexported(spanner.ClientConfig{}, spanner.SessionPoolConfig{}, spanner.InactiveTransactionRemovalOptions{}, spannerpb.ExecuteSqlRequest_QueryOptions{})); diff != "" {
287+
if diff := cmp.Diff(conn.spannerClientConfig, tc.wantSpannerConfig, cmpopts.IgnoreUnexported(spanner.ClientConfig{}, spanner.SessionPoolConfig{}, spanner.InactiveTransactionRemovalOptions{}, spannerpb.ExecuteSqlRequest_QueryOptions{}, spanner.TransactionOptions{})); diff != "" {
288288
t.Errorf("connector Spanner client config mismatch for %q\n%v", tc.input, diff)
289289
}
290290
actualConfig := conn.connectorConfig

driver_with_mockserver_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,9 @@ func TestSimpleReadWriteTransaction(t *testing.T) {
660660
t.Fatalf("commit requests count mismatch\n Got: %v\nWant: %v", g, w)
661661
}
662662
commitReq := commitRequests[0].(*sppb.CommitRequest)
663+
if commitReq.MaxCommitDelay == nil {
664+
t.Fatal("missing max commit delay for CommitRequest")
665+
}
663666
if g, w := commitReq.MaxCommitDelay.Nanos, int32(time.Millisecond*10); g != w {
664667
t.Fatalf("max_commit_delay mismatch\n Got: %v\nWant: %v", g, w)
665668
}

go.mod

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ go 1.24
44

55
toolchain go1.25.0
66

7+
replace cloud.google.com/go/spanner => ../google-cloud-go/spanner
8+
79
require (
810
cloud.google.com/go v0.121.6
911
cloud.google.com/go/longrunning v0.6.7
@@ -14,7 +16,7 @@ require (
1416
github.com/googleapis/gax-go/v2 v2.15.0
1517
github.com/hashicorp/golang-lru/v2 v2.0.7
1618
google.golang.org/api v0.247.0
17-
google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a
19+
google.golang.org/genproto/googleapis/rpc v0.0.0-20250818200422-3122310a409c
1820
google.golang.org/grpc v1.74.2
1921
google.golang.org/protobuf v1.36.7
2022
)
@@ -60,5 +62,5 @@ require (
6062
golang.org/x/text v0.28.0 // indirect
6163
golang.org/x/time v0.12.0 // indirect
6264
google.golang.org/genproto v0.0.0-20250804133106-a7a43d27e69b // indirect
63-
google.golang.org/genproto/googleapis/api v0.0.0-20250804133106-a7a43d27e69b // indirect
65+
google.golang.org/genproto/googleapis/api v0.0.0-20250818200422-3122310a409c // indirect
6466
)

0 commit comments

Comments
 (0)