Skip to content

Commit e88c303

Browse files
authored
feat: support BEGIN, COMMIT and ROLLBACK statements (#520)
Add support for BEGIN, COMMIT and ROLLBACK SQL statements.
1 parent 262b143 commit e88c303

File tree

9 files changed

+508
-30
lines changed

9 files changed

+508
-30
lines changed

conn.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,6 +1257,25 @@ func (c *conn) inReadWriteTransaction() bool {
12571257
return false
12581258
}
12591259

1260+
func (c *conn) commit(ctx context.Context) (*spanner.CommitResponse, error) {
1261+
if !c.inTransaction() {
1262+
return nil, status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction")
1263+
}
1264+
// TODO: Pass in context to the tx.Commit() function.
1265+
if err := c.tx.Commit(); err != nil {
1266+
return nil, err
1267+
}
1268+
return c.CommitResponse()
1269+
}
1270+
1271+
func (c *conn) rollback(ctx context.Context) error {
1272+
if !c.inTransaction() {
1273+
return status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction")
1274+
}
1275+
// TODO: Pass in context to the tx.Rollback() function.
1276+
return c.tx.Rollback()
1277+
}
1278+
12601279
func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
12611280
return c.Single().WithTimestampBound(tb).QueryWithOptions(ctx, statement, options.QueryOptions)
12621281
}

conn_with_mockserver_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,80 @@ func TestExplicitBeginTx(t *testing.T) {
128128
}
129129
}
130130

131+
func TestExecuteBegin(t *testing.T) {
132+
t.Parallel()
133+
134+
db, server, teardown := setupTestDBConnection(t)
135+
defer teardown()
136+
ctx := context.Background()
137+
138+
for _, end := range []string{"rollback", "commit"} {
139+
c, err := db.Conn(ctx)
140+
if err != nil {
141+
t.Fatal(err)
142+
}
143+
if _, err := c.ExecContext(ctx, "begin transaction"); err != nil {
144+
t.Fatal(err)
145+
}
146+
if _, err := c.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil {
147+
t.Fatal(err)
148+
}
149+
if _, err := c.ExecContext(ctx, end); err != nil {
150+
t.Fatal(err)
151+
}
152+
if err := c.Close(); err != nil {
153+
t.Fatal(err)
154+
}
155+
156+
requests := drainRequestsFromServer(server.TestSpanner)
157+
beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{}))
158+
if g, w := len(beginRequests), 0; g != w {
159+
t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w)
160+
}
161+
executeRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{}))
162+
if g, w := len(executeRequests), 1; g != w {
163+
t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w)
164+
}
165+
request := executeRequests[0].(*spannerpb.ExecuteSqlRequest)
166+
if request.GetTransaction() == nil || request.GetTransaction().GetBegin() == nil {
167+
t.Fatal("missing begin transaction on ExecuteSqlRequest")
168+
}
169+
commitRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{}))
170+
rollbackRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.RollbackRequest{}))
171+
if end == "commit" {
172+
if g, w := len(commitRequests), 1; g != w {
173+
t.Fatalf("commit requests count mismatch\n Got: %v\nWant: %v", g, w)
174+
}
175+
} else if end == "rollback" {
176+
if g, w := len(rollbackRequests), 1; g != w {
177+
t.Fatalf("rollback requests count mismatch\n Got: %v\nWant: %v", g, w)
178+
}
179+
}
180+
}
181+
}
182+
183+
func TestEndTransactionWithoutBegin(t *testing.T) {
184+
t.Parallel()
185+
186+
db, _, teardown := setupTestDBConnection(t)
187+
defer teardown()
188+
ctx := context.Background()
189+
190+
for _, end := range []string{"rollback", "commit"} {
191+
c, err := db.Conn(ctx)
192+
if err != nil {
193+
t.Fatal(err)
194+
}
195+
_, err = c.ExecContext(ctx, end)
196+
if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w {
197+
t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w)
198+
}
199+
if err := c.Close(); err != nil {
200+
t.Fatal(err)
201+
}
202+
}
203+
}
204+
131205
func TestBeginTxWithIsolationLevel(t *testing.T) {
132206
t.Parallel()
133207

parser/simple_parser.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,19 +282,26 @@ func (p *simpleParser) eatLiteral() (Literal, error) {
282282
func (p *simpleParser) eatKeywords(keywords []string) bool {
283283
startPos := p.pos
284284
for _, keyword := range keywords {
285-
if _, ok := p.eatKeyword(keyword); !ok {
285+
if !p.eatKeyword(keyword) {
286286
p.pos = startPos
287287
return false
288288
}
289289
}
290290
return true
291291
}
292292

293-
// eatKeyword eats the given keyword at the current position of the parser if it exists.
293+
// eatKeyword eats the given keyword at the current position of the parser if it exists
294+
// and returns true if the keyword was found. Otherwise, it returns false.
295+
func (p *simpleParser) eatKeyword(keyword string) bool {
296+
_, ok := p.eatAndReturnKeyword(keyword)
297+
return ok
298+
}
299+
300+
// eatAndReturnKeyword eats the given keyword at the current position of the parser if it exists.
294301
//
295302
// Returns the actual keyword that was read and true if the keyword is found, and updates the position of the parser.
296303
// Returns an empty string and false without updating the position of the parser if the keyword was not found.
297-
func (p *simpleParser) eatKeyword(keyword string) (string, bool) {
304+
func (p *simpleParser) eatAndReturnKeyword(keyword string) (string, bool) {
298305
startPos := p.pos
299306
found := p.readKeyword()
300307
if !strings.EqualFold(found, keyword) {

parser/statement_parser.go

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,17 @@ var updateStatements = map[string]bool{"UPDATE": true}
3535
var deleteStatements = map[string]bool{"DELETE": true}
3636
var dmlStatements = union(insertStatements, union(updateStatements, deleteStatements))
3737
var clientSideKeywords = map[string]bool{
38-
"SHOW": true,
39-
"SET": true,
40-
"RESET": true,
41-
"START": true,
42-
"RUN": true,
43-
"ABORT": true,
44-
"CREATE": true, // CREATE DATABASE is handled as a client-side statement
45-
"DROP": true, // DROP DATABASE is handled as a client-side statement
38+
"SHOW": true,
39+
"SET": true,
40+
"RESET": true,
41+
"START": true,
42+
"RUN": true,
43+
"ABORT": true,
44+
"BEGIN": true,
45+
"COMMIT": true,
46+
"ROLLBACK": true,
47+
"CREATE": true, // CREATE DATABASE is handled as a client-side statement
48+
"DROP": true, // DROP DATABASE is handled as a client-side statement
4649
}
4750
var createStatements = map[string]bool{"CREATE": true}
4851
var dropStatements = map[string]bool{"DROP": true}
@@ -52,6 +55,9 @@ var resetStatements = map[string]bool{"RESET": true}
5255
var startStatements = map[string]bool{"START": true}
5356
var runStatements = map[string]bool{"RUN": true}
5457
var abortStatements = map[string]bool{"ABORT": true}
58+
var beginStatements = map[string]bool{"BEGIN": true}
59+
var commitStatements = map[string]bool{"COMMIT": true}
60+
var rollbackStatements = map[string]bool{"ROLLBACK": true}
5561

5662
func union(m1 map[string]bool, m2 map[string]bool) map[string]bool {
5763
res := make(map[string]bool, len(m1)+len(m2))
@@ -660,6 +666,18 @@ func isAbortStatementKeyword(keyword string) bool {
660666
return isStatementKeyword(keyword, abortStatements)
661667
}
662668

669+
func isBeginStatementKeyword(keyword string) bool {
670+
return isStatementKeyword(keyword, beginStatements)
671+
}
672+
673+
func isCommitStatementKeyword(keyword string) bool {
674+
return isStatementKeyword(keyword, commitStatements)
675+
}
676+
677+
func isRollbackStatementKeyword(keyword string) bool {
678+
return isStatementKeyword(keyword, rollbackStatements)
679+
}
680+
663681
func isStatementKeyword(keyword string, keywords map[string]bool) bool {
664682
_, ok := keywords[keyword]
665683
return ok

parser/statement_parser_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2467,7 +2467,7 @@ func TestEatKeyword(t *testing.T) {
24672467
for _, test := range tests {
24682468
sp := &simpleParser{sql: []byte(test.input), statementParser: parser}
24692469
startPos := sp.pos
2470-
keyword, ok := sp.eatKeyword(test.keyword)
2470+
keyword, ok := sp.eatAndReturnKeyword(test.keyword)
24712471
if g, w := ok, test.wantOk; g != w {
24722472
t.Errorf("found mismatch\n Got: %v\nWant: %v", g, w)
24732473
}

0 commit comments

Comments
 (0)