Skip to content

Commit c0781a7

Browse files
authored
chore: add generic transactional connection state (#493)
* chore: add generic transactional connection state Adds data structures for generic transactional connection state. These structures will be used to keep all connection state in one place, making it easier to add new connection variables. This also adds support for transactional connection state; Changes that are made during a transaction are only persisted if the transaction is committed. It also allows for setting temporary (local) values during a transaction. This change is the first step in a multi-step process for moving all connection variables into a generic structure. Following changes will move the other connection variables into this structure, and will add support for executing `set local ...` statements. * fix: return error for invalid value
1 parent 2787305 commit c0781a7

12 files changed

+1289
-36
lines changed

.github/workflows/unit-tests.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ jobs:
1919
uses: actions/checkout@v5
2020
- name: Run unit tests
2121
run: go test -race -short
22+
- name: Run connection state unit tests
23+
run: go test -race -short
24+
working-directory: connectionstate
2225

2326
lint:
2427
runs-on: ubuntu-latest

client_side_statement_test.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,13 @@ import (
2525
"cloud.google.com/go/spanner/apiv1/spannerpb"
2626
"github.com/google/go-cmp/cmp"
2727
"github.com/google/go-cmp/cmp/cmpopts"
28+
"github.com/googleapis/go-sql-spanner/connectionstate"
2829
"google.golang.org/grpc/codes"
2930
"google.golang.org/protobuf/types/known/structpb"
3031
)
3132

3233
func TestStatementExecutor_StartBatchDdl(t *testing.T) {
33-
c := &conn{retryAborts: true, logger: noopLogger}
34+
c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})}
3435
s := &statementExecutor{}
3536
ctx := context.Background()
3637

@@ -61,7 +62,7 @@ func TestStatementExecutor_StartBatchDdl(t *testing.T) {
6162
}
6263

6364
func TestStatementExecutor_StartBatchDml(t *testing.T) {
64-
c := &conn{retryAborts: true, logger: noopLogger}
65+
c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})}
6566
s := &statementExecutor{}
6667
ctx := context.Background()
6768

@@ -98,7 +99,7 @@ func TestStatementExecutor_StartBatchDml(t *testing.T) {
9899
}
99100

100101
func TestStatementExecutor_RetryAbortsInternally(t *testing.T) {
101-
c := &conn{retryAborts: true, logger: noopLogger}
102+
c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})}
102103
s := &statementExecutor{}
103104
ctx := context.Background()
104105
for i, test := range []struct {
@@ -154,7 +155,7 @@ func TestStatementExecutor_RetryAbortsInternally(t *testing.T) {
154155
}
155156

156157
func TestStatementExecutor_AutocommitDmlMode(t *testing.T) {
157-
c := &conn{logger: noopLogger, connector: &connector{}}
158+
c := &conn{logger: noopLogger, connector: &connector{}, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})}
158159
_ = c.ResetSession(context.Background())
159160
s := &statementExecutor{}
160161
ctx := context.Background()
@@ -211,7 +212,7 @@ func TestStatementExecutor_AutocommitDmlMode(t *testing.T) {
211212
}
212213

213214
func TestStatementExecutor_ReadOnlyStaleness(t *testing.T) {
214-
c := &conn{logger: noopLogger}
215+
c := &conn{logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})}
215216
s := &statementExecutor{}
216217
ctx := context.Background()
217218
for i, test := range []struct {
@@ -282,7 +283,7 @@ func TestStatementExecutor_ReadOnlyStaleness(t *testing.T) {
282283
func TestShowCommitTimestamp(t *testing.T) {
283284
t.Parallel()
284285

285-
c := &conn{retryAborts: true, logger: noopLogger}
286+
c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})}
286287
s := &statementExecutor{}
287288
ctx := context.Background()
288289

@@ -328,7 +329,7 @@ func TestShowCommitTimestamp(t *testing.T) {
328329
}
329330

330331
func TestStatementExecutor_ExcludeTxnFromChangeStreams(t *testing.T) {
331-
c := &conn{retryAborts: true, logger: noopLogger}
332+
c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})}
332333
s := &statementExecutor{}
333334
ctx := context.Background()
334335
for i, test := range []struct {
@@ -384,7 +385,7 @@ func TestStatementExecutor_ExcludeTxnFromChangeStreams(t *testing.T) {
384385
}
385386

386387
func TestStatementExecutor_MaxCommitDelay(t *testing.T) {
387-
c := &conn{logger: noopLogger}
388+
c := &conn{logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})}
388389
s := &statementExecutor{}
389390
ctx := context.Background()
390391
for i, test := range []struct {
@@ -457,7 +458,7 @@ func TestStatementExecutor_SetTransactionTag(t *testing.T) {
457458
{"", "tag-with-missing-opening-quote'", true},
458459
{"", "'tag-with-missing-closing-quote", true},
459460
} {
460-
c := &conn{retryAborts: true, logger: noopLogger}
461+
c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})}
461462
s := &statementExecutor{}
462463

463464
it, err := s.ShowTransactionTag(ctx, c, "", ExecOptions{}, nil)
@@ -517,7 +518,7 @@ func TestStatementExecutor_SetTransactionTag(t *testing.T) {
517518

518519
func TestStatementExecutor_UsesExecOptions(t *testing.T) {
519520
ctx := context.Background()
520-
c := &conn{retryAborts: true, logger: noopLogger}
521+
c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})}
521522
s := &statementExecutor{}
522523

523524
it, err := s.ShowTransactionTag(ctx, c, "", ExecOptions{DecodeOption: DecodeOptionProto, ReturnResultSetMetadata: true, ReturnResultSetStats: true}, nil)

conn.go

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
adminapi "cloud.google.com/go/spanner/admin/database/apiv1"
2828
adminpb "cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
2929
"cloud.google.com/go/spanner/apiv1/spannerpb"
30+
"github.com/googleapis/go-sql-spanner/connectionstate"
3031
"google.golang.org/api/iterator"
3132
"google.golang.org/grpc/codes"
3233
"google.golang.org/grpc/status"
@@ -231,6 +232,8 @@ type conn struct {
231232
execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *statementInfo, options ExecOptions) (*result, *spanner.CommitResponse, error)
232233
execSingleDMLPartitioned func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (int64, error)
233234

235+
// state contains the current ConnectionState for this connection.
236+
state *connectionstate.ConnectionState
234237
// batch is the currently active DDL or DML batch on this connection.
235238
batch *batch
236239
// autoBatchDml determines whether DML statements should automatically
@@ -244,11 +247,6 @@ type conn struct {
244247
// statements was correct.
245248
autoBatchDmlUpdateCountVerification bool
246249

247-
// autocommitDMLMode determines the type of DML to use when a single DML
248-
// statement is executed on a connection. The default is Transactional, but
249-
// it can also be set to PartitionedNonAtomic to execute the statement as
250-
// Partitioned DML.
251-
autocommitDMLMode AutocommitDMLMode
252250
// readOnlyStaleness is used for queries in autocommit mode and for read-only transactions.
253251
readOnlyStaleness spanner.TimestampBound
254252
// isolationLevel determines the default isolation level that is used for read/write
@@ -308,7 +306,7 @@ func (c *conn) setRetryAbortsInternally(retry bool) (driver.Result, error) {
308306
}
309307

310308
func (c *conn) AutocommitDMLMode() AutocommitDMLMode {
311-
return c.autocommitDMLMode
309+
return propertyAutocommitDmlMode.GetValueOrDefault(c.state)
312310
}
313311

314312
func (c *conn) SetAutocommitDMLMode(mode AutocommitDMLMode) error {
@@ -320,7 +318,9 @@ func (c *conn) SetAutocommitDMLMode(mode AutocommitDMLMode) error {
320318
}
321319

322320
func (c *conn) setAutocommitDMLMode(mode AutocommitDMLMode) (driver.Result, error) {
323-
c.autocommitDMLMode = mode
321+
if err := propertyAutocommitDmlMode.SetValue(c.state, mode, connectionstate.ContextUser); err != nil {
322+
return nil, err
323+
}
324324
return driver.ResultNoRows, nil
325325
}
326326

@@ -689,8 +689,9 @@ func (c *conn) ResetSession(_ context.Context) error {
689689
c.retryAborts = c.connector.retryAbortsInternally
690690
c.isolationLevel = c.connector.connectorConfig.IsolationLevel
691691
c.beginTransactionOption = c.connector.connectorConfig.BeginTransactionOption
692+
693+
_ = c.state.Reset(connectionstate.ContextUser)
692694
// TODO: Reset the following fields to the connector default
693-
c.autocommitDMLMode = Transactional
694695
c.readOnlyStaleness = spanner.TimestampBound{}
695696
c.execOptions = ExecOptions{
696697
DecodeToNativeArrays: c.connector.connectorConfig.DecodeToNativeArrays,
@@ -887,7 +888,7 @@ func (c *conn) execContext(ctx context.Context, query string, execOptions ExecOp
887888
c.batch.statements = append(c.batch.statements, ss)
888889
res = &result{}
889890
} else {
890-
dmlMode := c.autocommitDMLMode
891+
dmlMode := c.AutocommitDMLMode()
891892
if execOptions.AutocommitDMLMode != Unspecified {
892893
dmlMode = execOptions.AutocommitDMLMode
893894
}
@@ -1015,6 +1016,13 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
10151016
c.resetForRetry = false
10161017
return c.tx, nil
10171018
}
1019+
// Also start a transaction on the ConnectionState if the BeginTx call was successful.
1020+
defer func() {
1021+
if c.tx != nil {
1022+
_ = c.state.Begin()
1023+
}
1024+
}()
1025+
10181026
readOnlyTxOpts := c.getReadOnlyTransactionOptions()
10191027
batchReadOnlyTxOpts := c.getBatchReadOnlyTransactionOptions()
10201028
readWriteTransactionOptions := c.getTransactionOptions()
@@ -1072,13 +1080,18 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
10721080
roTx: ro,
10731081
boTx: bo,
10741082
logger: logger,
1075-
close: func() {
1083+
close: func(result txResult) {
10761084
if batchReadOnlyTxOpts.close != nil {
10771085
batchReadOnlyTxOpts.close()
10781086
}
10791087
if readOnlyTxOpts.close != nil {
10801088
readOnlyTxOpts.close()
10811089
}
1090+
if result == txResultCommit {
1091+
_ = c.state.Commit()
1092+
} else {
1093+
_ = c.state.Rollback()
1094+
}
10821095
c.tx = nil
10831096
},
10841097
}
@@ -1095,14 +1108,21 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
10951108
conn: c,
10961109
logger: logger,
10971110
rwTx: tx,
1098-
close: func(commitResponse *spanner.CommitResponse, commitErr error) {
1111+
close: func(result txResult, commitResponse *spanner.CommitResponse, commitErr error) {
10991112
if readWriteTransactionOptions.close != nil {
11001113
readWriteTransactionOptions.close()
11011114
}
11021115
c.prevTx = c.tx
11031116
c.tx = nil
11041117
if commitErr == nil {
11051118
c.commitResponse = commitResponse
1119+
if result == txResultCommit {
1120+
_ = c.state.Commit()
1121+
} else {
1122+
_ = c.state.Rollback()
1123+
}
1124+
} else {
1125+
_ = c.state.Rollback()
11061126
}
11071127
},
11081128
// Disable internal retries if any of these options have been set.

conn_with_mockserver_test.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323

2424
"cloud.google.com/go/spanner"
2525
"cloud.google.com/go/spanner/apiv1/spannerpb"
26+
"github.com/googleapis/go-sql-spanner/connectionstate"
2627
"github.com/googleapis/go-sql-spanner/testutil"
2728
)
2829

@@ -448,3 +449,105 @@ func TestSetRetryAbortsInternallyInActiveTransaction(t *testing.T) {
448449
}
449450
_ = tx.Rollback()
450451
}
452+
453+
func TestSetAutocommitDMLMode(t *testing.T) {
454+
t.Parallel()
455+
456+
for _, tp := range []connectionstate.Type{connectionstate.TypeTransactional, connectionstate.TypeNonTransactional} {
457+
db, _, teardown := setupTestDBConnectionWithConnectorConfig(t, ConnectorConfig{
458+
Project: "p",
459+
Instance: "i",
460+
Database: "d",
461+
ConnectionStateType: tp,
462+
})
463+
defer teardown()
464+
465+
conn, err := db.Conn(context.Background())
466+
if err != nil {
467+
t.Fatal(err)
468+
}
469+
defer func() { _ = conn.Close() }()
470+
471+
_ = conn.Raw(func(driverConn interface{}) error {
472+
c, _ := driverConn.(SpannerConn)
473+
if g, w := c.AutocommitDMLMode(), Transactional; g != w {
474+
t.Fatalf("initial value mismatch\n Got: %v\nWant: %v", g, w)
475+
}
476+
if err := c.SetAutocommitDMLMode(PartitionedNonAtomic); err != nil {
477+
t.Fatal(err)
478+
}
479+
if g, w := c.AutocommitDMLMode(), PartitionedNonAtomic; g != w {
480+
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
481+
}
482+
return nil
483+
})
484+
485+
// Set the value in a transaction and commit.
486+
tx, err := conn.BeginTx(context.Background(), &sql.TxOptions{})
487+
if err != nil {
488+
t.Fatal(err)
489+
}
490+
_ = conn.Raw(func(driverConn interface{}) error {
491+
c, _ := driverConn.(SpannerConn)
492+
// The value should be the same as before the transaction started.
493+
if g, w := c.AutocommitDMLMode(), PartitionedNonAtomic; g != w {
494+
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
495+
}
496+
// Changes in a transaction should be visible in the transaction.
497+
if err := c.SetAutocommitDMLMode(Transactional); err != nil {
498+
t.Fatal(err)
499+
}
500+
if g, w := c.AutocommitDMLMode(), Transactional; g != w {
501+
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
502+
}
503+
return nil
504+
})
505+
// Committing the transaction should make the change durable (and is a no-op if the connection state type is
506+
// non-transactional).
507+
if err := tx.Commit(); err != nil {
508+
t.Fatal(err)
509+
}
510+
_ = conn.Raw(func(driverConn interface{}) error {
511+
c, _ := driverConn.(SpannerConn)
512+
if g, w := c.AutocommitDMLMode(), Transactional; g != w {
513+
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
514+
}
515+
return nil
516+
})
517+
518+
// Set the value in a transaction and rollback.
519+
tx, err = conn.BeginTx(context.Background(), &sql.TxOptions{})
520+
if err != nil {
521+
t.Fatal(err)
522+
}
523+
_ = conn.Raw(func(driverConn interface{}) error {
524+
c, _ := driverConn.(SpannerConn)
525+
if err := c.SetAutocommitDMLMode(PartitionedNonAtomic); err != nil {
526+
t.Fatal(err)
527+
}
528+
if g, w := c.AutocommitDMLMode(), PartitionedNonAtomic; g != w {
529+
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
530+
}
531+
return nil
532+
})
533+
// Rolling back the transaction will undo the change if the connection state is transactional.
534+
// In case of non-transactional state, the rollback does not have an effect, as the state change was persisted
535+
// directly when SetAutocommitDMLMode was called.
536+
if err := tx.Rollback(); err != nil {
537+
t.Fatal(err)
538+
}
539+
_ = conn.Raw(func(driverConn interface{}) error {
540+
c, _ := driverConn.(SpannerConn)
541+
var expected AutocommitDMLMode
542+
if tp == connectionstate.TypeTransactional {
543+
expected = Transactional
544+
} else {
545+
expected = PartitionedNonAtomic
546+
}
547+
if g, w := c.AutocommitDMLMode(), expected; g != w {
548+
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
549+
}
550+
return nil
551+
})
552+
}
553+
}

connection_properties.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package spannerdriver
16+
17+
import "github.com/googleapis/go-sql-spanner/connectionstate"
18+
19+
// connectionProperties contains all supported connection properties for Spanner.
20+
// These properties are added to all connectionstate.ConnectionState instances that are created for Spanner connections.
21+
var connectionProperties = map[string]connectionstate.ConnectionProperty{}
22+
23+
// The following variables define the various connectionstate.ConnectionProperty instances that are supported and used
24+
// by the Spanner database/sql driver. They are defined as global variables, so they can be used directly in the driver
25+
// to get/set the state of exactly that property.
26+
27+
var propertyConnectionStateType = createConnectionProperty(
28+
"connection_state_type",
29+
"The type of connection state to use for this connection. Can only be set at start up. "+
30+
"If no value is set, then the database dialect default will be used, "+
31+
"which is NON_TRANSACTIONAL for GoogleSQL and TRANSACTIONAL for PostgreSQL.",
32+
connectionstate.TypeDefault,
33+
[]connectionstate.Type{connectionstate.TypeDefault, connectionstate.TypeTransactional, connectionstate.TypeNonTransactional},
34+
connectionstate.ContextStartup,
35+
)
36+
var propertyAutocommitDmlMode = createConnectionProperty(
37+
"autocommit_dml_mode",
38+
"Determines the transaction type that is used to execute DML statements when the connection is in auto-commit mode.",
39+
Transactional,
40+
[]AutocommitDMLMode{Transactional, PartitionedNonAtomic},
41+
connectionstate.ContextUser,
42+
)
43+
44+
func createConnectionProperty[T comparable](name, description string, defaultValue T, validValues []T, context connectionstate.Context) *connectionstate.TypedConnectionProperty[T] {
45+
prop := connectionstate.CreateConnectionProperty(name, description, defaultValue, validValues, context)
46+
connectionProperties[prop.Key()] = prop
47+
return prop
48+
}
49+
50+
func createInitialConnectionState(connectionStateType connectionstate.Type, initialValues map[string]connectionstate.ConnectionPropertyValue) *connectionstate.ConnectionState {
51+
state, _ := connectionstate.NewConnectionState(connectionStateType, connectionProperties, initialValues)
52+
return state
53+
}

0 commit comments

Comments
 (0)