@@ -3,11 +3,11 @@ package api
33import (
44 "context"
55 "database/sql"
6+ "database/sql/driver"
67 "fmt"
78 "strings"
89 "sync"
910 "sync/atomic"
10- "time"
1111
1212 "cloud.google.com/go/spanner"
1313 "cloud.google.com/go/spanner/apiv1/spannerpb"
@@ -26,18 +26,18 @@ func CloseConnection(poolId, connId int64) error {
2626 return conn .close ()
2727}
2828
29- func Apply (poolId , connId int64 , mutations * spannerpb.BatchWriteRequest_MutationGroup ) (* spannerpb.CommitResponse , error ) {
29+ func WriteMutations (poolId , connId int64 , mutations * spannerpb.BatchWriteRequest_MutationGroup ) (* spannerpb.CommitResponse , error ) {
3030 conn , err := findConnection (poolId , connId )
3131 if err != nil {
3232 return nil , err
3333 }
34- return conn .apply (mutations )
34+ return conn .writeMutations (mutations )
3535}
3636
37- func BeginTransaction (poolId , connId int64 , txOpts * spannerpb.TransactionOptions ) ( int64 , error ) {
37+ func BeginTransaction (poolId , connId int64 , txOpts * spannerpb.TransactionOptions ) error {
3838 conn , err := findConnection (poolId , connId )
3939 if err != nil {
40- return 0 , err
40+ return err
4141 }
4242 return conn .BeginTransaction (txOpts )
4343}
@@ -62,12 +62,18 @@ type Connection struct {
6262 results * sync.Map
6363 resultsIdx atomic.Int64
6464
65- transactions * sync.Map
66- transactionsIdx atomic.Int64
67-
6865 backend * sql.Conn
6966}
7067
68+ // spannerConn is an internal interface that contains the internal functions that are used by this API.
69+ type spannerConn interface {
70+ WriteMutations (ctx context.Context , ms []* spanner.Mutation ) (* spanner.CommitResponse , error )
71+ BeginReadOnlyTransaction (ctx context.Context , options * spannerdriver.ReadOnlyTransactionOptions ) (driver.Tx , error )
72+ BeginReadWriteTransaction (ctx context.Context , options * spannerdriver.ReadWriteTransactionOptions ) (driver.Tx , error )
73+ Commit (ctx context.Context ) (* spanner.CommitResponse , error )
74+ Rollback (ctx context.Context ) error
75+ }
76+
7177type queryExecutor interface {
7278 ExecContext (ctx context.Context , query string , args ... any ) (sql.Result , error )
7379 QueryContext (ctx context.Context , query string , args ... any ) (* sql.Rows , error )
@@ -79,19 +85,14 @@ func (conn *Connection) close() error {
7985 _ = res .Close ()
8086 return true
8187 })
82- conn .transactions .Range (func (key , value interface {}) bool {
83- res := value .(* transaction )
84- _ = res .Close ()
85- return true
86- })
8788 err := conn .backend .Close ()
8889 if err != nil {
8990 return err
9091 }
9192 return nil
9293}
9394
94- func (conn * Connection ) apply (mutation * spannerpb.BatchWriteRequest_MutationGroup ) (* spannerpb.CommitResponse , error ) {
95+ func (conn * Connection ) writeMutations (mutation * spannerpb.BatchWriteRequest_MutationGroup ) (* spannerpb.CommitResponse , error ) {
9596 ctx := context .Background ()
9697 mutations := make ([]* spanner.Mutation , 0 , len (mutation .Mutations ))
9798 for _ , m := range mutation .Mutations {
@@ -101,47 +102,59 @@ func (conn *Connection) apply(mutation *spannerpb.BatchWriteRequest_MutationGrou
101102 }
102103 mutations = append (mutations , spannerMutation )
103104 }
104- var commitTimestamp time. Time
105+ var commitResponse * spanner. CommitResponse
105106 if err := conn .backend .Raw (func (driverConn any ) (err error ) {
106- spannerConn , _ := driverConn .(spannerdriver. SpannerConn )
107- commitTimestamp , err = spannerConn . Apply (ctx , mutations )
107+ sc , _ := driverConn .(spannerConn )
108+ commitResponse , err = sc . WriteMutations (ctx , mutations )
108109 return err
109110 }); err != nil {
110111 return nil , err
111112 }
113+
114+ // The commit response is nil if the connection is currently in a transaction.
115+ if commitResponse == nil {
116+ return nil , nil
117+ }
112118 response := spannerpb.CommitResponse {
113- CommitTimestamp : timestamppb .New (commitTimestamp ),
119+ CommitTimestamp : timestamppb .New (commitResponse . CommitTs ),
114120 }
115121 return & response , nil
116122}
117123
118- func (conn * Connection ) BeginTransaction (txOpts * spannerpb.TransactionOptions ) (int64 , error ) {
119- var tx * sql.Tx
124+ func (conn * Connection ) BeginTransaction (txOpts * spannerpb.TransactionOptions ) error {
120125 var err error
126+ ctx := context .Background ()
121127 if txOpts .GetReadOnly () != nil {
122- tx , err = spannerdriver .BeginReadOnlyTransactionOnConn (
123- context .Background (), conn .backend , convertToReadOnlyOpts (txOpts ))
128+ return conn .beginReadOnlyTransaction (ctx , convertToReadOnlyOpts (txOpts ))
124129 } else if txOpts .GetPartitionedDml () != nil {
125130 err = spanner .ToSpannerError (status .Error (codes .InvalidArgument , "transaction type not supported" ))
126131 } else {
127- tx , err = spannerdriver .BeginReadWriteTransactionOnConn (
128- context .Background (), conn .backend , convertToReadWriteTransactionOptions (txOpts ))
132+ return conn .beginReadWriteTransaction (ctx , convertToReadWriteTransactionOptions (txOpts ))
129133 }
130134 if err != nil {
131- return 0 , err
132- }
133- id := conn .transactionsIdx .Add (1 )
134- res := & transaction {
135- backend : tx ,
136- conn : conn ,
137- txOpts : txOpts ,
135+ return err
138136 }
139- conn .transactions .Store (id , res )
140- return id , nil
137+ return nil
138+ }
139+
140+ func (conn * Connection ) beginReadOnlyTransaction (ctx context.Context , opts * spannerdriver.ReadOnlyTransactionOptions ) error {
141+ return conn .backend .Raw (func (driverConn any ) (err error ) {
142+ sc , _ := driverConn .(spannerConn )
143+ _ , err = sc .BeginReadOnlyTransaction (ctx , opts )
144+ return err
145+ })
146+ }
147+
148+ func (conn * Connection ) beginReadWriteTransaction (ctx context.Context , opts * spannerdriver.ReadWriteTransactionOptions ) error {
149+ return conn .backend .Raw (func (driverConn any ) (err error ) {
150+ sc , _ := driverConn .(spannerConn )
151+ _ , err = sc .BeginReadWriteTransaction (ctx , opts )
152+ return err
153+ })
141154}
142155
143- func convertToReadOnlyOpts (txOpts * spannerpb.TransactionOptions ) spannerdriver.ReadOnlyTransactionOptions {
144- return spannerdriver.ReadOnlyTransactionOptions {
156+ func convertToReadOnlyOpts (txOpts * spannerpb.TransactionOptions ) * spannerdriver.ReadOnlyTransactionOptions {
157+ return & spannerdriver.ReadOnlyTransactionOptions {
145158 TimestampBound : convertTimestampBound (txOpts ),
146159 }
147160}
@@ -162,12 +175,12 @@ func convertTimestampBound(txOpts *spannerpb.TransactionOptions) spanner.Timesta
162175 return spanner.TimestampBound {}
163176}
164177
165- func convertToReadWriteTransactionOptions (txOpts * spannerpb.TransactionOptions ) spannerdriver.ReadWriteTransactionOptions {
178+ func convertToReadWriteTransactionOptions (txOpts * spannerpb.TransactionOptions ) * spannerdriver.ReadWriteTransactionOptions {
166179 readLockMode := spannerpb .TransactionOptions_ReadWrite_READ_LOCK_MODE_UNSPECIFIED
167180 if txOpts .GetReadWrite () != nil {
168181 readLockMode = txOpts .GetReadWrite ().GetReadLockMode ()
169182 }
170- return spannerdriver.ReadWriteTransactionOptions {
183+ return & spannerdriver.ReadWriteTransactionOptions {
171184 TransactionOptions : spanner.TransactionOptions {
172185 IsolationLevel : txOpts .GetIsolationLevel (),
173186 ReadLockMode : readLockMode ,
0 commit comments