Skip to content

Commit ec75601

Browse files
committed
feat: support BEGIN, COMMIT and ROLLBACK statements
Add support for BEGIN, COMMIT and ROLLBACK SQL statements.
1 parent 31d67ce commit ec75601

File tree

9 files changed

+509
-30
lines changed

9 files changed

+509
-30
lines changed

conn.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,6 +1253,25 @@ func (c *conn) inReadWriteTransaction() bool {
12531253
return false
12541254
}
12551255

1256+
func (c *conn) commit(ctx context.Context) (*spanner.CommitResponse, error) {
1257+
if !c.inTransaction() {
1258+
return nil, status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction")
1259+
}
1260+
// TODO: Pass in context to the tx.Commit() function.
1261+
if err := c.tx.Commit(); err != nil {
1262+
return nil, err
1263+
}
1264+
return c.CommitResponse()
1265+
}
1266+
1267+
func (c *conn) rollback(ctx context.Context) error {
1268+
if !c.inTransaction() {
1269+
return status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction")
1270+
}
1271+
// TODO: Pass in context to the tx.Rollback() function.
1272+
return c.tx.Rollback()
1273+
}
1274+
12561275
func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
12571276
return c.Single().WithTimestampBound(tb).QueryWithOptions(ctx, statement, options.QueryOptions)
12581277
}

conn_with_mockserver_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
"cloud.google.com/go/spanner/apiv1/spannerpb"
2828
"github.com/googleapis/go-sql-spanner/connectionstate"
2929
"github.com/googleapis/go-sql-spanner/testutil"
30+
"google.golang.org/grpc/codes"
3031
"google.golang.org/protobuf/proto"
3132
"google.golang.org/protobuf/types/known/anypb"
3233
"google.golang.org/protobuf/types/known/emptypb"
@@ -108,6 +109,80 @@ func TestExplicitBeginTx(t *testing.T) {
108109
}
109110
}
110111

112+
func TestExecuteBegin(t *testing.T) {
113+
t.Parallel()
114+
115+
db, server, teardown := setupTestDBConnection(t)
116+
defer teardown()
117+
ctx := context.Background()
118+
119+
for _, end := range []string{"rollback", "commit"} {
120+
c, err := db.Conn(ctx)
121+
if err != nil {
122+
t.Fatal(err)
123+
}
124+
if _, err := c.ExecContext(ctx, "begin transaction"); err != nil {
125+
t.Fatal(err)
126+
}
127+
if _, err := c.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil {
128+
t.Fatal(err)
129+
}
130+
if _, err := c.ExecContext(ctx, end); err != nil {
131+
t.Fatal(err)
132+
}
133+
if err := c.Close(); err != nil {
134+
t.Fatal(err)
135+
}
136+
137+
requests := drainRequestsFromServer(server.TestSpanner)
138+
beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{}))
139+
if g, w := len(beginRequests), 0; g != w {
140+
t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w)
141+
}
142+
executeRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{}))
143+
if g, w := len(executeRequests), 1; g != w {
144+
t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w)
145+
}
146+
request := executeRequests[0].(*spannerpb.ExecuteSqlRequest)
147+
if request.GetTransaction() == nil || request.GetTransaction().GetBegin() == nil {
148+
t.Fatal("missing begin transaction on ExecuteSqlRequest")
149+
}
150+
commitRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{}))
151+
rollbackRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.RollbackRequest{}))
152+
if end == "commit" {
153+
if g, w := len(commitRequests), 1; g != w {
154+
t.Fatalf("commit requests count mismatch\n Got: %v\nWant: %v", g, w)
155+
}
156+
} else if end == "rollback" {
157+
if g, w := len(rollbackRequests), 1; g != w {
158+
t.Fatalf("rollback requests count mismatch\n Got: %v\nWant: %v", g, w)
159+
}
160+
}
161+
}
162+
}
163+
164+
func TestEndTransactionWithoutBegin(t *testing.T) {
165+
t.Parallel()
166+
167+
db, _, teardown := setupTestDBConnection(t)
168+
defer teardown()
169+
ctx := context.Background()
170+
171+
for _, end := range []string{"rollback", "commit"} {
172+
c, err := db.Conn(ctx)
173+
if err != nil {
174+
t.Fatal(err)
175+
}
176+
_, err = c.ExecContext(ctx, end)
177+
if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w {
178+
t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w)
179+
}
180+
if err := c.Close(); err != nil {
181+
t.Fatal(err)
182+
}
183+
}
184+
}
185+
111186
func TestBeginTxWithIsolationLevel(t *testing.T) {
112187
t.Parallel()
113188

parser/simple_parser.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,19 +282,26 @@ func (p *simpleParser) eatLiteral() (Literal, error) {
282282
func (p *simpleParser) eatKeywords(keywords []string) bool {
283283
startPos := p.pos
284284
for _, keyword := range keywords {
285-
if _, ok := p.eatKeyword(keyword); !ok {
285+
if !p.eatKeyword(keyword) {
286286
p.pos = startPos
287287
return false
288288
}
289289
}
290290
return true
291291
}
292292

293-
// eatKeyword eats the given keyword at the current position of the parser if it exists.
293+
// eatKeyword eats the given keyword at the current position of the parser if it exists
294+
// and returns true if the keyword was found. Otherwise, it returns false.
295+
func (p *simpleParser) eatKeyword(keyword string) bool {
296+
_, ok := p.eatAndReturnKeyword(keyword)
297+
return ok
298+
}
299+
300+
// eatAndReturnKeyword eats the given keyword at the current position of the parser if it exists.
294301
//
295302
// Returns the actual keyword that was read and true if the keyword is found, and updates the position of the parser.
296303
// Returns an empty string and false without updating the position of the parser if the keyword was not found.
297-
func (p *simpleParser) eatKeyword(keyword string) (string, bool) {
304+
func (p *simpleParser) eatAndReturnKeyword(keyword string) (string, bool) {
298305
startPos := p.pos
299306
found := p.readKeyword()
300307
if !strings.EqualFold(found, keyword) {

parser/statement_parser.go

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,17 @@ var updateStatements = map[string]bool{"UPDATE": true}
3535
var deleteStatements = map[string]bool{"DELETE": true}
3636
var dmlStatements = union(insertStatements, union(updateStatements, deleteStatements))
3737
var clientSideKeywords = map[string]bool{
38-
"SHOW": true,
39-
"SET": true,
40-
"RESET": true,
41-
"START": true,
42-
"RUN": true,
43-
"ABORT": true,
44-
"CREATE": true, // CREATE DATABASE is handled as a client-side statement
45-
"DROP": true, // DROP DATABASE is handled as a client-side statement
38+
"SHOW": true,
39+
"SET": true,
40+
"RESET": true,
41+
"START": true,
42+
"RUN": true,
43+
"ABORT": true,
44+
"BEGIN": true,
45+
"COMMIT": true,
46+
"ROLLBACK": true,
47+
"CREATE": true, // CREATE DATABASE is handled as a client-side statement
48+
"DROP": true, // DROP DATABASE is handled as a client-side statement
4649
}
4750
var createStatements = map[string]bool{"CREATE": true}
4851
var dropStatements = map[string]bool{"DROP": true}
@@ -52,6 +55,9 @@ var resetStatements = map[string]bool{"RESET": true}
5255
var startStatements = map[string]bool{"START": true}
5356
var runStatements = map[string]bool{"RUN": true}
5457
var abortStatements = map[string]bool{"ABORT": true}
58+
var beginStatements = map[string]bool{"BEGIN": true}
59+
var commitStatements = map[string]bool{"COMMIT": true}
60+
var rollbackStatements = map[string]bool{"ROLLBACK": true}
5561

5662
func union(m1 map[string]bool, m2 map[string]bool) map[string]bool {
5763
res := make(map[string]bool, len(m1)+len(m2))
@@ -660,6 +666,18 @@ func isAbortStatementKeyword(keyword string) bool {
660666
return isStatementKeyword(keyword, abortStatements)
661667
}
662668

669+
func isBeginStatementKeyword(keyword string) bool {
670+
return isStatementKeyword(keyword, beginStatements)
671+
}
672+
673+
func isCommitStatementKeyword(keyword string) bool {
674+
return isStatementKeyword(keyword, commitStatements)
675+
}
676+
677+
func isRollbackStatementKeyword(keyword string) bool {
678+
return isStatementKeyword(keyword, rollbackStatements)
679+
}
680+
663681
func isStatementKeyword(keyword string, keywords map[string]bool) bool {
664682
_, ok := keywords[keyword]
665683
return ok

parser/statement_parser_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2467,7 +2467,7 @@ func TestEatKeyword(t *testing.T) {
24672467
for _, test := range tests {
24682468
sp := &simpleParser{sql: []byte(test.input), statementParser: parser}
24692469
startPos := sp.pos
2470-
keyword, ok := sp.eatKeyword(test.keyword)
2470+
keyword, ok := sp.eatAndReturnKeyword(test.keyword)
24712471
if g, w := ok, test.wantOk; g != w {
24722472
t.Errorf("found mismatch\n Got: %v\nWant: %v", g, w)
24732473
}

0 commit comments

Comments
 (0)