Skip to content

Commit 2bf6e3e

Browse files
committed
Merge branch 'tx-statements' into spanner-lib
2 parents 23f6ea6 + ec75601 commit 2bf6e3e

File tree

9 files changed

+508
-30
lines changed

9 files changed

+508
-30
lines changed

conn.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,6 +1274,25 @@ func (c *conn) inReadWriteTransaction() bool {
12741274
return false
12751275
}
12761276

1277+
func (c *conn) commit(ctx context.Context) (*spanner.CommitResponse, error) {
1278+
if !c.inTransaction() {
1279+
return nil, status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction")
1280+
}
1281+
// TODO: Pass in context to the tx.Commit() function.
1282+
if err := c.tx.Commit(); err != nil {
1283+
return nil, err
1284+
}
1285+
return c.CommitResponse()
1286+
}
1287+
1288+
func (c *conn) rollback(ctx context.Context) error {
1289+
if !c.inTransaction() {
1290+
return status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction")
1291+
}
1292+
// TODO: Pass in context to the tx.Rollback() function.
1293+
return c.tx.Rollback()
1294+
}
1295+
12771296
func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
12781297
if options.TimestampBound != nil {
12791298
tb = *options.TimestampBound

conn_with_mockserver_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,80 @@ func TestExplicitBeginTx(t *testing.T) {
129129
}
130130
}
131131

132+
func TestExecuteBegin(t *testing.T) {
133+
t.Parallel()
134+
135+
db, server, teardown := setupTestDBConnection(t)
136+
defer teardown()
137+
ctx := context.Background()
138+
139+
for _, end := range []string{"rollback", "commit"} {
140+
c, err := db.Conn(ctx)
141+
if err != nil {
142+
t.Fatal(err)
143+
}
144+
if _, err := c.ExecContext(ctx, "begin transaction"); err != nil {
145+
t.Fatal(err)
146+
}
147+
if _, err := c.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil {
148+
t.Fatal(err)
149+
}
150+
if _, err := c.ExecContext(ctx, end); err != nil {
151+
t.Fatal(err)
152+
}
153+
if err := c.Close(); err != nil {
154+
t.Fatal(err)
155+
}
156+
157+
requests := drainRequestsFromServer(server.TestSpanner)
158+
beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{}))
159+
if g, w := len(beginRequests), 0; g != w {
160+
t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w)
161+
}
162+
executeRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{}))
163+
if g, w := len(executeRequests), 1; g != w {
164+
t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w)
165+
}
166+
request := executeRequests[0].(*spannerpb.ExecuteSqlRequest)
167+
if request.GetTransaction() == nil || request.GetTransaction().GetBegin() == nil {
168+
t.Fatal("missing begin transaction on ExecuteSqlRequest")
169+
}
170+
commitRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{}))
171+
rollbackRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.RollbackRequest{}))
172+
if end == "commit" {
173+
if g, w := len(commitRequests), 1; g != w {
174+
t.Fatalf("commit requests count mismatch\n Got: %v\nWant: %v", g, w)
175+
}
176+
} else if end == "rollback" {
177+
if g, w := len(rollbackRequests), 1; g != w {
178+
t.Fatalf("rollback requests count mismatch\n Got: %v\nWant: %v", g, w)
179+
}
180+
}
181+
}
182+
}
183+
184+
func TestEndTransactionWithoutBegin(t *testing.T) {
185+
t.Parallel()
186+
187+
db, _, teardown := setupTestDBConnection(t)
188+
defer teardown()
189+
ctx := context.Background()
190+
191+
for _, end := range []string{"rollback", "commit"} {
192+
c, err := db.Conn(ctx)
193+
if err != nil {
194+
t.Fatal(err)
195+
}
196+
_, err = c.ExecContext(ctx, end)
197+
if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w {
198+
t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w)
199+
}
200+
if err := c.Close(); err != nil {
201+
t.Fatal(err)
202+
}
203+
}
204+
}
205+
132206
func TestBeginTxWithIsolationLevel(t *testing.T) {
133207
t.Parallel()
134208

parser/simple_parser.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,19 +282,26 @@ func (p *simpleParser) eatLiteral() (Literal, error) {
282282
func (p *simpleParser) eatKeywords(keywords []string) bool {
283283
startPos := p.pos
284284
for _, keyword := range keywords {
285-
if _, ok := p.eatKeyword(keyword); !ok {
285+
if !p.eatKeyword(keyword) {
286286
p.pos = startPos
287287
return false
288288
}
289289
}
290290
return true
291291
}
292292

293-
// eatKeyword eats the given keyword at the current position of the parser if it exists.
293+
// eatKeyword eats the given keyword at the current position of the parser if it exists
294+
// and returns true if the keyword was found. Otherwise, it returns false.
295+
func (p *simpleParser) eatKeyword(keyword string) bool {
296+
_, ok := p.eatAndReturnKeyword(keyword)
297+
return ok
298+
}
299+
300+
// eatAndReturnKeyword eats the given keyword at the current position of the parser if it exists.
294301
//
295302
// Returns the actual keyword that was read and true if the keyword is found, and updates the position of the parser.
296303
// Returns an empty string and false without updating the position of the parser if the keyword was not found.
297-
func (p *simpleParser) eatKeyword(keyword string) (string, bool) {
304+
func (p *simpleParser) eatAndReturnKeyword(keyword string) (string, bool) {
298305
startPos := p.pos
299306
found := p.readKeyword()
300307
if !strings.EqualFold(found, keyword) {

parser/statement_parser.go

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,17 @@ var updateStatements = map[string]bool{"UPDATE": true}
3535
var deleteStatements = map[string]bool{"DELETE": true}
3636
var dmlStatements = union(insertStatements, union(updateStatements, deleteStatements))
3737
var clientSideKeywords = map[string]bool{
38-
"SHOW": true,
39-
"SET": true,
40-
"RESET": true,
41-
"START": true,
42-
"RUN": true,
43-
"ABORT": true,
44-
"CREATE": true, // CREATE DATABASE is handled as a client-side statement
45-
"DROP": true, // DROP DATABASE is handled as a client-side statement
38+
"SHOW": true,
39+
"SET": true,
40+
"RESET": true,
41+
"START": true,
42+
"RUN": true,
43+
"ABORT": true,
44+
"BEGIN": true,
45+
"COMMIT": true,
46+
"ROLLBACK": true,
47+
"CREATE": true, // CREATE DATABASE is handled as a client-side statement
48+
"DROP": true, // DROP DATABASE is handled as a client-side statement
4649
}
4750
var createStatements = map[string]bool{"CREATE": true}
4851
var dropStatements = map[string]bool{"DROP": true}
@@ -52,6 +55,9 @@ var resetStatements = map[string]bool{"RESET": true}
5255
var startStatements = map[string]bool{"START": true}
5356
var runStatements = map[string]bool{"RUN": true}
5457
var abortStatements = map[string]bool{"ABORT": true}
58+
var beginStatements = map[string]bool{"BEGIN": true}
59+
var commitStatements = map[string]bool{"COMMIT": true}
60+
var rollbackStatements = map[string]bool{"ROLLBACK": true}
5561

5662
func union(m1 map[string]bool, m2 map[string]bool) map[string]bool {
5763
res := make(map[string]bool, len(m1)+len(m2))
@@ -660,6 +666,18 @@ func isAbortStatementKeyword(keyword string) bool {
660666
return isStatementKeyword(keyword, abortStatements)
661667
}
662668

669+
func isBeginStatementKeyword(keyword string) bool {
670+
return isStatementKeyword(keyword, beginStatements)
671+
}
672+
673+
func isCommitStatementKeyword(keyword string) bool {
674+
return isStatementKeyword(keyword, commitStatements)
675+
}
676+
677+
func isRollbackStatementKeyword(keyword string) bool {
678+
return isStatementKeyword(keyword, rollbackStatements)
679+
}
680+
663681
func isStatementKeyword(keyword string, keywords map[string]bool) bool {
664682
_, ok := keywords[keyword]
665683
return ok

parser/statement_parser_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2467,7 +2467,7 @@ func TestEatKeyword(t *testing.T) {
24672467
for _, test := range tests {
24682468
sp := &simpleParser{sql: []byte(test.input), statementParser: parser}
24692469
startPos := sp.pos
2470-
keyword, ok := sp.eatKeyword(test.keyword)
2470+
keyword, ok := sp.eatAndReturnKeyword(test.keyword)
24712471
if g, w := ok, test.wantOk; g != w {
24722472
t.Errorf("found mismatch\n Got: %v\nWant: %v", g, w)
24732473
}

0 commit comments

Comments
 (0)