@@ -27,6 +27,7 @@ import (
2727 "cloud.google.com/go/spanner/apiv1/spannerpb"
2828 "github.com/googleapis/go-sql-spanner/connectionstate"
2929 "github.com/googleapis/go-sql-spanner/testutil"
30+ "google.golang.org/grpc/codes"
3031 "google.golang.org/protobuf/proto"
3132 "google.golang.org/protobuf/types/known/anypb"
3233 "google.golang.org/protobuf/types/known/emptypb"
@@ -108,6 +109,80 @@ func TestExplicitBeginTx(t *testing.T) {
108109 }
109110}
110111
112+ func TestExecuteBegin (t * testing.T ) {
113+ t .Parallel ()
114+
115+ db , server , teardown := setupTestDBConnection (t )
116+ defer teardown ()
117+ ctx := context .Background ()
118+
119+ for _ , end := range []string {"rollback" , "commit" } {
120+ c , err := db .Conn (ctx )
121+ if err != nil {
122+ t .Fatal (err )
123+ }
124+ if _ , err := c .ExecContext (ctx , "begin transaction" ); err != nil {
125+ t .Fatal (err )
126+ }
127+ if _ , err := c .ExecContext (ctx , testutil .UpdateBarSetFoo ); err != nil {
128+ t .Fatal (err )
129+ }
130+ if _ , err := c .ExecContext (ctx , end ); err != nil {
131+ t .Fatal (err )
132+ }
133+ if err := c .Close (); err != nil {
134+ t .Fatal (err )
135+ }
136+
137+ requests := drainRequestsFromServer (server .TestSpanner )
138+ beginRequests := requestsOfType (requests , reflect .TypeOf (& spannerpb.BeginTransactionRequest {}))
139+ if g , w := len (beginRequests ), 0 ; g != w {
140+ t .Fatalf ("begin requests count mismatch\n Got: %v\n Want: %v" , g , w )
141+ }
142+ executeRequests := requestsOfType (requests , reflect .TypeOf (& spannerpb.ExecuteSqlRequest {}))
143+ if g , w := len (executeRequests ), 1 ; g != w {
144+ t .Fatalf ("execute requests count mismatch\n Got: %v\n Want: %v" , g , w )
145+ }
146+ request := executeRequests [0 ].(* spannerpb.ExecuteSqlRequest )
147+ if request .GetTransaction () == nil || request .GetTransaction ().GetBegin () == nil {
148+ t .Fatal ("missing begin transaction on ExecuteSqlRequest" )
149+ }
150+ commitRequests := requestsOfType (requests , reflect .TypeOf (& spannerpb.CommitRequest {}))
151+ rollbackRequests := requestsOfType (requests , reflect .TypeOf (& spannerpb.RollbackRequest {}))
152+ if end == "commit" {
153+ if g , w := len (commitRequests ), 1 ; g != w {
154+ t .Fatalf ("commit requests count mismatch\n Got: %v\n Want: %v" , g , w )
155+ }
156+ } else if end == "rollback" {
157+ if g , w := len (rollbackRequests ), 1 ; g != w {
158+ t .Fatalf ("rollback requests count mismatch\n Got: %v\n Want: %v" , g , w )
159+ }
160+ }
161+ }
162+ }
163+
164+ func TestEndTransactionWithoutBegin (t * testing.T ) {
165+ t .Parallel ()
166+
167+ db , _ , teardown := setupTestDBConnection (t )
168+ defer teardown ()
169+ ctx := context .Background ()
170+
171+ for _ , end := range []string {"rollback" , "commit" } {
172+ c , err := db .Conn (ctx )
173+ if err != nil {
174+ t .Fatal (err )
175+ }
176+ _ , err = c .ExecContext (ctx , end )
177+ if g , w := spanner .ErrCode (err ), codes .FailedPrecondition ; g != w {
178+ t .Fatalf ("error code mismatch\n Got: %v\n Want: %v" , g , w )
179+ }
180+ if err := c .Close (); err != nil {
181+ t .Fatal (err )
182+ }
183+ }
184+ }
185+
111186func TestBeginTxWithIsolationLevel (t * testing.T ) {
112187 t .Parallel ()
113188
0 commit comments