@@ -128,6 +128,80 @@ func TestExplicitBeginTx(t *testing.T) {
128128 }
129129}
130130
131+ func TestExecuteBegin (t * testing.T ) {
132+ t .Parallel ()
133+
134+ db , server , teardown := setupTestDBConnection (t )
135+ defer teardown ()
136+ ctx := context .Background ()
137+
138+ for _ , end := range []string {"rollback" , "commit" } {
139+ c , err := db .Conn (ctx )
140+ if err != nil {
141+ t .Fatal (err )
142+ }
143+ if _ , err := c .ExecContext (ctx , "begin transaction" ); err != nil {
144+ t .Fatal (err )
145+ }
146+ if _ , err := c .ExecContext (ctx , testutil .UpdateBarSetFoo ); err != nil {
147+ t .Fatal (err )
148+ }
149+ if _ , err := c .ExecContext (ctx , end ); err != nil {
150+ t .Fatal (err )
151+ }
152+ if err := c .Close (); err != nil {
153+ t .Fatal (err )
154+ }
155+
156+ requests := drainRequestsFromServer (server .TestSpanner )
157+ beginRequests := requestsOfType (requests , reflect .TypeOf (& spannerpb.BeginTransactionRequest {}))
158+ if g , w := len (beginRequests ), 0 ; g != w {
159+ t .Fatalf ("begin requests count mismatch\n Got: %v\n Want: %v" , g , w )
160+ }
161+ executeRequests := requestsOfType (requests , reflect .TypeOf (& spannerpb.ExecuteSqlRequest {}))
162+ if g , w := len (executeRequests ), 1 ; g != w {
163+ t .Fatalf ("execute requests count mismatch\n Got: %v\n Want: %v" , g , w )
164+ }
165+ request := executeRequests [0 ].(* spannerpb.ExecuteSqlRequest )
166+ if request .GetTransaction () == nil || request .GetTransaction ().GetBegin () == nil {
167+ t .Fatal ("missing begin transaction on ExecuteSqlRequest" )
168+ }
169+ commitRequests := requestsOfType (requests , reflect .TypeOf (& spannerpb.CommitRequest {}))
170+ rollbackRequests := requestsOfType (requests , reflect .TypeOf (& spannerpb.RollbackRequest {}))
171+ if end == "commit" {
172+ if g , w := len (commitRequests ), 1 ; g != w {
173+ t .Fatalf ("commit requests count mismatch\n Got: %v\n Want: %v" , g , w )
174+ }
175+ } else if end == "rollback" {
176+ if g , w := len (rollbackRequests ), 1 ; g != w {
177+ t .Fatalf ("rollback requests count mismatch\n Got: %v\n Want: %v" , g , w )
178+ }
179+ }
180+ }
181+ }
182+
183+ func TestEndTransactionWithoutBegin (t * testing.T ) {
184+ t .Parallel ()
185+
186+ db , _ , teardown := setupTestDBConnection (t )
187+ defer teardown ()
188+ ctx := context .Background ()
189+
190+ for _ , end := range []string {"rollback" , "commit" } {
191+ c , err := db .Conn (ctx )
192+ if err != nil {
193+ t .Fatal (err )
194+ }
195+ _ , err = c .ExecContext (ctx , end )
196+ if g , w := spanner .ErrCode (err ), codes .FailedPrecondition ; g != w {
197+ t .Fatalf ("error code mismatch\n Got: %v\n Want: %v" , g , w )
198+ }
199+ if err := c .Close (); err != nil {
200+ t .Fatal (err )
201+ }
202+ }
203+ }
204+
131205func TestBeginTxWithIsolationLevel (t * testing.T ) {
132206 t .Parallel ()
133207
0 commit comments