@@ -129,6 +129,80 @@ func TestExplicitBeginTx(t *testing.T) {
129129 }
130130}
131131
132+ func TestExecuteBegin (t * testing.T ) {
133+ t .Parallel ()
134+
135+ db , server , teardown := setupTestDBConnection (t )
136+ defer teardown ()
137+ ctx := context .Background ()
138+
139+ for _ , end := range []string {"rollback" , "commit" } {
140+ c , err := db .Conn (ctx )
141+ if err != nil {
142+ t .Fatal (err )
143+ }
144+ if _ , err := c .ExecContext (ctx , "begin transaction" ); err != nil {
145+ t .Fatal (err )
146+ }
147+ if _ , err := c .ExecContext (ctx , testutil .UpdateBarSetFoo ); err != nil {
148+ t .Fatal (err )
149+ }
150+ if _ , err := c .ExecContext (ctx , end ); err != nil {
151+ t .Fatal (err )
152+ }
153+ if err := c .Close (); err != nil {
154+ t .Fatal (err )
155+ }
156+
157+ requests := drainRequestsFromServer (server .TestSpanner )
158+ beginRequests := requestsOfType (requests , reflect .TypeOf (& spannerpb.BeginTransactionRequest {}))
159+ if g , w := len (beginRequests ), 0 ; g != w {
160+ t .Fatalf ("begin requests count mismatch\n Got: %v\n Want: %v" , g , w )
161+ }
162+ executeRequests := requestsOfType (requests , reflect .TypeOf (& spannerpb.ExecuteSqlRequest {}))
163+ if g , w := len (executeRequests ), 1 ; g != w {
164+ t .Fatalf ("execute requests count mismatch\n Got: %v\n Want: %v" , g , w )
165+ }
166+ request := executeRequests [0 ].(* spannerpb.ExecuteSqlRequest )
167+ if request .GetTransaction () == nil || request .GetTransaction ().GetBegin () == nil {
168+ t .Fatal ("missing begin transaction on ExecuteSqlRequest" )
169+ }
170+ commitRequests := requestsOfType (requests , reflect .TypeOf (& spannerpb.CommitRequest {}))
171+ rollbackRequests := requestsOfType (requests , reflect .TypeOf (& spannerpb.RollbackRequest {}))
172+ if end == "commit" {
173+ if g , w := len (commitRequests ), 1 ; g != w {
174+ t .Fatalf ("commit requests count mismatch\n Got: %v\n Want: %v" , g , w )
175+ }
176+ } else if end == "rollback" {
177+ if g , w := len (rollbackRequests ), 1 ; g != w {
178+ t .Fatalf ("rollback requests count mismatch\n Got: %v\n Want: %v" , g , w )
179+ }
180+ }
181+ }
182+ }
183+
184+ func TestEndTransactionWithoutBegin (t * testing.T ) {
185+ t .Parallel ()
186+
187+ db , _ , teardown := setupTestDBConnection (t )
188+ defer teardown ()
189+ ctx := context .Background ()
190+
191+ for _ , end := range []string {"rollback" , "commit" } {
192+ c , err := db .Conn (ctx )
193+ if err != nil {
194+ t .Fatal (err )
195+ }
196+ _ , err = c .ExecContext (ctx , end )
197+ if g , w := spanner .ErrCode (err ), codes .FailedPrecondition ; g != w {
198+ t .Fatalf ("error code mismatch\n Got: %v\n Want: %v" , g , w )
199+ }
200+ if err := c .Close (); err != nil {
201+ t .Fatal (err )
202+ }
203+ }
204+ }
205+
132206func TestBeginTxWithIsolationLevel (t * testing.T ) {
133207 t .Parallel ()
134208
0 commit comments