Skip to content

Commit 900e003

Browse files
committed
chore: remove transaction id from api
1 parent 2bf6e3e commit 900e003

File tree

17 files changed

+186
-225
lines changed

17 files changed

+186
-225
lines changed

conn.go

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,17 @@ func sum(affected []int64) int64 {
694694
return sum
695695
}
696696

697+
func (c *conn) WriteMutations(ctx context.Context, ms []*spanner.Mutation) (*spanner.CommitResponse, error) {
698+
if c.inTransaction() {
699+
return nil, c.BufferWrite(ms)
700+
}
701+
ts, err := c.Apply(ctx, ms)
702+
if err != nil {
703+
return nil, err
704+
}
705+
return &spanner.CommitResponse{CommitTs: ts}, nil
706+
}
707+
697708
func (c *conn) Apply(ctx context.Context, ms []*spanner.Mutation, opts ...spanner.ApplyOption) (commitTimestamp time.Time, err error) {
698709
if c.inTransaction() {
699710
return time.Time{}, spanner.ToSpannerError(
@@ -1091,6 +1102,26 @@ func (c *conn) getBatchReadOnlyTransactionOptions() BatchReadOnlyTransactionOpti
10911102
return BatchReadOnlyTransactionOptions{TimestampBound: c.ReadOnlyStaleness()}
10921103
}
10931104

1105+
func (c *conn) BeginReadOnlyTransaction(ctx context.Context, options *ReadOnlyTransactionOptions) (driver.Tx, error) {
1106+
c.withTempReadOnlyTransactionOptions(options)
1107+
tx, err := c.BeginTx(ctx, driver.TxOptions{ReadOnly: true})
1108+
if err != nil {
1109+
c.withTempReadOnlyTransactionOptions(nil)
1110+
return nil, err
1111+
}
1112+
return tx, nil
1113+
}
1114+
1115+
func (c *conn) BeginReadWriteTransaction(ctx context.Context, options *ReadWriteTransactionOptions) (driver.Tx, error) {
1116+
c.withTempTransactionOptions(options)
1117+
tx, err := c.BeginTx(ctx, driver.TxOptions{})
1118+
if err != nil {
1119+
c.withTempTransactionOptions(nil)
1120+
return nil, err
1121+
}
1122+
return tx, nil
1123+
}
1124+
10941125
func (c *conn) Begin() (driver.Tx, error) {
10951126
return c.BeginTx(context.Background(), driver.TxOptions{})
10961127
}
@@ -1274,18 +1305,21 @@ func (c *conn) inReadWriteTransaction() bool {
12741305
return false
12751306
}
12761307

1277-
func (c *conn) commit(ctx context.Context) (*spanner.CommitResponse, error) {
1308+
func (c *conn) Commit(ctx context.Context) (*spanner.CommitResponse, error) {
12781309
if !c.inTransaction() {
12791310
return nil, status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction")
12801311
}
12811312
// TODO: Pass in context to the tx.Commit() function.
12821313
if err := c.tx.Commit(); err != nil {
12831314
return nil, err
12841315
}
1285-
return c.CommitResponse()
1316+
1317+
// This will return either the commit response or nil, depending on whether the transaction was a
1318+
// read/write transaction or a read-only transaction.
1319+
return propertyCommitResponse.GetValueOrDefault(c.state), nil
12861320
}
12871321

1288-
func (c *conn) rollback(ctx context.Context) error {
1322+
func (c *conn) Rollback(ctx context.Context) error {
12891323
if !c.inTransaction() {
12901324
return status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction")
12911325
}

spannerlib/api/connection.go

Lines changed: 50 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ package api
33
import (
44
"context"
55
"database/sql"
6+
"database/sql/driver"
67
"fmt"
78
"strings"
89
"sync"
910
"sync/atomic"
10-
"time"
1111

1212
"cloud.google.com/go/spanner"
1313
"cloud.google.com/go/spanner/apiv1/spannerpb"
@@ -26,18 +26,18 @@ func CloseConnection(poolId, connId int64) error {
2626
return conn.close()
2727
}
2828

29-
func Apply(poolId, connId int64, mutations *spannerpb.BatchWriteRequest_MutationGroup) (*spannerpb.CommitResponse, error) {
29+
func WriteMutations(poolId, connId int64, mutations *spannerpb.BatchWriteRequest_MutationGroup) (*spannerpb.CommitResponse, error) {
3030
conn, err := findConnection(poolId, connId)
3131
if err != nil {
3232
return nil, err
3333
}
34-
return conn.apply(mutations)
34+
return conn.writeMutations(mutations)
3535
}
3636

37-
func BeginTransaction(poolId, connId int64, txOpts *spannerpb.TransactionOptions) (int64, error) {
37+
func BeginTransaction(poolId, connId int64, txOpts *spannerpb.TransactionOptions) error {
3838
conn, err := findConnection(poolId, connId)
3939
if err != nil {
40-
return 0, err
40+
return err
4141
}
4242
return conn.BeginTransaction(txOpts)
4343
}
@@ -62,12 +62,18 @@ type Connection struct {
6262
results *sync.Map
6363
resultsIdx atomic.Int64
6464

65-
transactions *sync.Map
66-
transactionsIdx atomic.Int64
67-
6865
backend *sql.Conn
6966
}
7067

68+
// spannerConn is an internal interface that contains the internal functions that are used by this API.
69+
type spannerConn interface {
70+
WriteMutations(ctx context.Context, ms []*spanner.Mutation) (*spanner.CommitResponse, error)
71+
BeginReadOnlyTransaction(ctx context.Context, options *spannerdriver.ReadOnlyTransactionOptions) (driver.Tx, error)
72+
BeginReadWriteTransaction(ctx context.Context, options *spannerdriver.ReadWriteTransactionOptions) (driver.Tx, error)
73+
Commit(ctx context.Context) (*spanner.CommitResponse, error)
74+
Rollback(ctx context.Context) error
75+
}
76+
7177
type queryExecutor interface {
7278
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
7379
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
@@ -79,19 +85,14 @@ func (conn *Connection) close() error {
7985
_ = res.Close()
8086
return true
8187
})
82-
conn.transactions.Range(func(key, value interface{}) bool {
83-
res := value.(*transaction)
84-
_ = res.Close()
85-
return true
86-
})
8788
err := conn.backend.Close()
8889
if err != nil {
8990
return err
9091
}
9192
return nil
9293
}
9394

94-
func (conn *Connection) apply(mutation *spannerpb.BatchWriteRequest_MutationGroup) (*spannerpb.CommitResponse, error) {
95+
func (conn *Connection) writeMutations(mutation *spannerpb.BatchWriteRequest_MutationGroup) (*spannerpb.CommitResponse, error) {
9596
ctx := context.Background()
9697
mutations := make([]*spanner.Mutation, 0, len(mutation.Mutations))
9798
for _, m := range mutation.Mutations {
@@ -101,47 +102,59 @@ func (conn *Connection) apply(mutation *spannerpb.BatchWriteRequest_MutationGrou
101102
}
102103
mutations = append(mutations, spannerMutation)
103104
}
104-
var commitTimestamp time.Time
105+
var commitResponse *spanner.CommitResponse
105106
if err := conn.backend.Raw(func(driverConn any) (err error) {
106-
spannerConn, _ := driverConn.(spannerdriver.SpannerConn)
107-
commitTimestamp, err = spannerConn.Apply(ctx, mutations)
107+
sc, _ := driverConn.(spannerConn)
108+
commitResponse, err = sc.WriteMutations(ctx, mutations)
108109
return err
109110
}); err != nil {
110111
return nil, err
111112
}
113+
114+
// The commit response is nil if the connection is currently in a transaction.
115+
if commitResponse == nil {
116+
return nil, nil
117+
}
112118
response := spannerpb.CommitResponse{
113-
CommitTimestamp: timestamppb.New(commitTimestamp),
119+
CommitTimestamp: timestamppb.New(commitResponse.CommitTs),
114120
}
115121
return &response, nil
116122
}
117123

118-
func (conn *Connection) BeginTransaction(txOpts *spannerpb.TransactionOptions) (int64, error) {
119-
var tx *sql.Tx
124+
func (conn *Connection) BeginTransaction(txOpts *spannerpb.TransactionOptions) error {
120125
var err error
126+
ctx := context.Background()
121127
if txOpts.GetReadOnly() != nil {
122-
tx, err = spannerdriver.BeginReadOnlyTransactionOnConn(
123-
context.Background(), conn.backend, convertToReadOnlyOpts(txOpts))
128+
return conn.beginReadOnlyTransaction(ctx, convertToReadOnlyOpts(txOpts))
124129
} else if txOpts.GetPartitionedDml() != nil {
125130
err = spanner.ToSpannerError(status.Error(codes.InvalidArgument, "transaction type not supported"))
126131
} else {
127-
tx, err = spannerdriver.BeginReadWriteTransactionOnConn(
128-
context.Background(), conn.backend, convertToReadWriteTransactionOptions(txOpts))
132+
return conn.beginReadWriteTransaction(ctx, convertToReadWriteTransactionOptions(txOpts))
129133
}
130134
if err != nil {
131-
return 0, err
132-
}
133-
id := conn.transactionsIdx.Add(1)
134-
res := &transaction{
135-
backend: tx,
136-
conn: conn,
137-
txOpts: txOpts,
135+
return err
138136
}
139-
conn.transactions.Store(id, res)
140-
return id, nil
137+
return nil
138+
}
139+
140+
func (conn *Connection) beginReadOnlyTransaction(ctx context.Context, opts *spannerdriver.ReadOnlyTransactionOptions) error {
141+
return conn.backend.Raw(func(driverConn any) (err error) {
142+
sc, _ := driverConn.(spannerConn)
143+
_, err = sc.BeginReadOnlyTransaction(ctx, opts)
144+
return err
145+
})
146+
}
147+
148+
func (conn *Connection) beginReadWriteTransaction(ctx context.Context, opts *spannerdriver.ReadWriteTransactionOptions) error {
149+
return conn.backend.Raw(func(driverConn any) (err error) {
150+
sc, _ := driverConn.(spannerConn)
151+
_, err = sc.BeginReadWriteTransaction(ctx, opts)
152+
return err
153+
})
141154
}
142155

143-
func convertToReadOnlyOpts(txOpts *spannerpb.TransactionOptions) spannerdriver.ReadOnlyTransactionOptions {
144-
return spannerdriver.ReadOnlyTransactionOptions{
156+
func convertToReadOnlyOpts(txOpts *spannerpb.TransactionOptions) *spannerdriver.ReadOnlyTransactionOptions {
157+
return &spannerdriver.ReadOnlyTransactionOptions{
145158
TimestampBound: convertTimestampBound(txOpts),
146159
}
147160
}
@@ -162,12 +175,12 @@ func convertTimestampBound(txOpts *spannerpb.TransactionOptions) spanner.Timesta
162175
return spanner.TimestampBound{}
163176
}
164177

165-
func convertToReadWriteTransactionOptions(txOpts *spannerpb.TransactionOptions) spannerdriver.ReadWriteTransactionOptions {
178+
func convertToReadWriteTransactionOptions(txOpts *spannerpb.TransactionOptions) *spannerdriver.ReadWriteTransactionOptions {
166179
readLockMode := spannerpb.TransactionOptions_ReadWrite_READ_LOCK_MODE_UNSPECIFIED
167180
if txOpts.GetReadWrite() != nil {
168181
readLockMode = txOpts.GetReadWrite().GetReadLockMode()
169182
}
170-
return spannerdriver.ReadWriteTransactionOptions{
183+
return &spannerdriver.ReadWriteTransactionOptions{
171184
TransactionOptions: spanner.TransactionOptions{
172185
IsolationLevel: txOpts.GetIsolationLevel(),
173186
ReadLockMode: readLockMode,

spannerlib/api/pool.go

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,8 @@ func CreateConnection(poolId int64) (int64, error) {
6969
}
7070
id := poolsIdx.Add(1)
7171
conn := &Connection{
72-
backend: sqlConn,
73-
results: &sync.Map{},
74-
transactions: &sync.Map{},
72+
backend: sqlConn,
73+
results: &sync.Map{},
7574
}
7675
pool.connections.Store(id, conn)
7776

@@ -104,16 +103,3 @@ func findRows(poolId, connId, rowsId int64) (*rows, error) {
104103
res := r.(*rows)
105104
return res, nil
106105
}
107-
108-
func findTx(poolId, connId, txId int64) (*transaction, error) {
109-
conn, err := findConnection(poolId, connId)
110-
if err != nil {
111-
return nil, err
112-
}
113-
r, ok := conn.transactions.Load(txId)
114-
if !ok {
115-
return nil, fmt.Errorf("tx %v not found", txId)
116-
}
117-
res := r.(*transaction)
118-
return res, nil
119-
}

spannerlib/api/transaction.go

Lines changed: 19 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package api
22

33
import (
4+
"context"
45
"database/sql"
56

67
"cloud.google.com/go/spanner"
@@ -9,46 +10,20 @@ import (
910
"google.golang.org/protobuf/types/known/timestamppb"
1011
)
1112

12-
func BufferWrite(poolId, connId, txId int64, mutations *spannerpb.BatchWriteRequest_MutationGroup) error {
13-
tx, err := findTx(poolId, connId, txId)
14-
if err != nil {
15-
return err
16-
}
17-
return tx.bufferWrite(mutations)
18-
}
19-
20-
func ExecuteTransaction(poolId, connId, txId int64, request *spannerpb.ExecuteSqlRequest) (int64, error) {
21-
tx, err := findTx(poolId, connId, txId)
22-
if err != nil {
23-
return 0, err
24-
}
25-
return tx.Execute(request)
26-
}
27-
28-
func Commit(poolId, connId, txId int64) (*spannerpb.CommitResponse, error) {
29-
tx, err := findTx(poolId, connId, txId)
30-
if err != nil {
31-
return nil, err
32-
}
13+
func Commit(poolId, connId int64) (*spannerpb.CommitResponse, error) {
3314
conn, err := findConnection(poolId, connId)
3415
if err != nil {
3516
return nil, err
3617
}
37-
conn.transactions.Delete(txId)
38-
return tx.Commit()
18+
return commit(conn)
3919
}
4020

41-
func Rollback(poolId, connId, txId int64) error {
42-
tx, err := findTx(poolId, connId, txId)
43-
if err != nil {
44-
return err
45-
}
21+
func Rollback(poolId, connId int64) error {
4622
conn, err := findConnection(poolId, connId)
4723
if err != nil {
4824
return err
4925
}
50-
conn.transactions.Delete(txId)
51-
return tx.Rollback()
26+
return rollback(conn)
5227
}
5328

5429
type transaction struct {
@@ -91,33 +66,30 @@ func (tx *transaction) Execute(statement *spannerpb.ExecuteSqlRequest) (int64, e
9166
return execute(tx.conn, tx.backend, statement)
9267
}
9368

94-
func (tx *transaction) Commit() (*spannerpb.CommitResponse, error) {
95-
tx.closed = true
96-
if err := tx.backend.Commit(); err != nil {
97-
return nil, err
98-
}
69+
func commit(conn *Connection) (*spannerpb.CommitResponse, error) {
9970
var response *spanner.CommitResponse
100-
if tx.txOpts.GetReadWrite() == nil {
101-
return &spannerpb.CommitResponse{}, nil
102-
}
103-
if err := tx.conn.backend.Raw(func(driverConn any) (err error) {
104-
spannerConn, _ := driverConn.(spannerdriver.SpannerConn)
105-
response, err = spannerConn.CommitResponse()
71+
if err := conn.backend.Raw(func(driverConn any) (err error) {
72+
spannerConn, _ := driverConn.(spannerConn)
73+
response, err = spannerConn.Commit(context.Background())
10674
if err != nil {
10775
return err
10876
}
10977
return nil
11078
}); err != nil {
11179
return nil, err
11280
}
81+
82+
// The commit response is nil for read-only transactions.
83+
if response == nil {
84+
return nil, nil
85+
}
11386
// TODO: Include commit stats
11487
return &spannerpb.CommitResponse{CommitTimestamp: timestamppb.New(response.CommitTs)}, nil
11588
}
11689

117-
func (tx *transaction) Rollback() error {
118-
tx.closed = true
119-
if err := tx.backend.Rollback(); err != nil {
120-
return err
121-
}
122-
return nil
90+
func rollback(conn *Connection) error {
91+
return conn.backend.Raw(func(driverConn any) (err error) {
92+
spannerConn, _ := driverConn.(spannerConn)
93+
return spannerConn.Rollback(context.Background())
94+
})
12395
}

0 commit comments

Comments
 (0)