Skip to content

Commit 49e945e

Browse files
authored
feat: support transaction options in BEGIN statements (#550)
* feat: parse SET TRANSACTION statements Parse SET TRANSACTION statements and translate these to SET LOCAL statements. SET TRANSACTION may only be executed in a transaction block, and can only be used for a specific, limited set of connection properties. The syntax is specified by the SQL standard and PostgreSQL. See also https://www.postgresql.org/docs/current/sql-set-transaction.html This change only adds partial support. The following features will be added in future changes: 1. SET TRANSACTION READ {WRITE | ONLY} is not picked up by the driver, as the type of transaction is set directly when BeginTx is called. A refactor of this transaction handling is needed to be able to pick up SET TRANSACTION READ ONLY / SET TRANSACTION READ WRITE statements that are executed after BeginTx has been called. 2. PostgreSQL allows multiple transaction modes to be set in a single SET TRANSACTION statement. E.g. the following is allowed: SET TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE The current implementation only supports one transaction mode per SET statement. * feat: support multiple transaction options in one statement * feat: support transaction options in BEGIN statements Adds support for including transaction options in BEGIN statements, like: ```sql BEGIN READ ONLY; BEGIN READ WRITE; BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ; BEGIN READ WRITE, ISOLATION LEVEL SERIALIZABLE; ``` * chore: re-trigger checks
1 parent 6a396d3 commit 49e945e

File tree

4 files changed

+264
-28
lines changed

4 files changed

+264
-28
lines changed

parser/statements.go

Lines changed: 52 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -230,77 +230,92 @@ func (s *ParsedSetStatement) parseSetTransaction(sp *simpleParser, query string)
230230
s.IsLocal = true
231231
s.IsTransaction = true
232232

233+
var err error
234+
s.Identifiers, s.Literals, err = parseTransactionOptions(sp)
235+
if err != nil {
236+
return err
237+
}
238+
return nil
239+
}
240+
241+
func parseTransactionOptions(sp *simpleParser) ([]Identifier, []Literal, error) {
242+
identifiers := make([]Identifier, 0, 2)
243+
literals := make([]Literal, 0, 2)
244+
var err error
233245
for {
234246
if sp.peekKeyword("ISOLATION") {
235-
if err := s.parseSetTransactionIsolationLevel(sp, query); err != nil {
236-
return err
247+
identifiers, literals, err = parseTransactionIsolationLevel(sp, identifiers, literals)
248+
if err != nil {
249+
return nil, nil, err
237250
}
238251
} else if sp.peekKeyword("READ") {
239-
if err := s.parseSetTransactionMode(sp, query); err != nil {
240-
return err
252+
identifiers, literals, err = parseTransactionMode(sp, identifiers, literals)
253+
if err != nil {
254+
return nil, nil, err
241255
}
242256
} else if sp.statementParser.Dialect == databasepb.DatabaseDialect_POSTGRESQL && (sp.peekKeyword("DEFERRABLE") || sp.peekKeyword("NOT")) {
243257
// https://www.postgresql.org/docs/current/sql-set-transaction.html
244-
if err := s.parseSetTransactionDeferrable(sp, query); err != nil {
245-
return err
258+
identifiers, literals, err = parseTransactionDeferrable(sp, identifiers, literals)
259+
if err != nil {
260+
return nil, nil, err
246261
}
247262
} else {
248-
return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected one of ISOLATION LEVEL, READ WRITE, or READ ONLY")
263+
return nil, nil, status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected one of ISOLATION LEVEL, READ WRITE, or READ ONLY")
249264
}
250265
if !sp.hasMoreTokens() {
251-
return nil
266+
return identifiers, literals, nil
252267
}
253268
// Eat and ignore any commas separating the various options.
254269
sp.eatToken(',')
255270
}
256271
}
257272

258-
func (s *ParsedSetStatement) parseSetTransactionIsolationLevel(sp *simpleParser, query string) error {
273+
func parseTransactionIsolationLevel(sp *simpleParser, identifiers []Identifier, literals []Literal) ([]Identifier, []Literal, error) {
259274
if !sp.eatKeywords([]string{"ISOLATION", "LEVEL"}) {
260-
return status.Errorf(codes.InvalidArgument, "syntax error: expected ISOLATION LEVEL")
275+
return nil, nil, status.Errorf(codes.InvalidArgument, "syntax error: expected ISOLATION LEVEL")
261276
}
262277
var value Literal
263278
if sp.eatKeyword("SERIALIZABLE") {
264279
value = Literal{Value: "serializable"}
265280
} else if sp.eatKeywords([]string{"REPEATABLE", "READ"}) {
266281
value = Literal{Value: "repeatable_read"}
267282
} else {
268-
return status.Errorf(codes.InvalidArgument, "syntax error: expected SERIALIZABLE OR REPETABLE READ")
283+
return nil, nil, status.Errorf(codes.InvalidArgument, "syntax error: expected SERIALIZABLE OR REPETABLE READ")
269284
}
270285

271-
s.Identifiers = append(s.Identifiers, Identifier{Parts: []string{"isolation_level"}})
272-
s.Literals = append(s.Literals, value)
273-
return nil
286+
identifiers = append(identifiers, Identifier{Parts: []string{"isolation_level"}})
287+
literals = append(literals, value)
288+
return identifiers, literals, nil
274289
}
275290

276-
func (s *ParsedSetStatement) parseSetTransactionMode(sp *simpleParser, query string) error {
291+
func parseTransactionMode(sp *simpleParser, identifiers []Identifier, literals []Literal) ([]Identifier, []Literal, error) {
277292
readOnly := false
278293
if sp.eatKeywords([]string{"READ", "ONLY"}) {
279294
readOnly = true
280295
} else if sp.eatKeywords([]string{"READ", "WRITE"}) {
281296
readOnly = false
282297
} else {
283-
return status.Errorf(codes.InvalidArgument, "syntax error: expected READ ONLY or READ WRITE")
298+
return nil, nil, status.Errorf(codes.InvalidArgument, "syntax error: expected READ ONLY or READ WRITE")
284299
}
285300

286-
s.Identifiers = append(s.Identifiers, Identifier{Parts: []string{"transaction_read_only"}})
287-
s.Literals = append(s.Literals, Literal{Value: fmt.Sprintf("%v", readOnly)})
288-
return nil
301+
identifiers = append(identifiers, Identifier{Parts: []string{"transaction_read_only"}})
302+
literals = append(literals, Literal{Value: fmt.Sprintf("%v", readOnly)})
303+
return identifiers, literals, nil
289304
}
290305

291-
func (s *ParsedSetStatement) parseSetTransactionDeferrable(sp *simpleParser, query string) error {
306+
func parseTransactionDeferrable(sp *simpleParser, identifiers []Identifier, literals []Literal) ([]Identifier, []Literal, error) {
292307
deferrable := false
293308
if sp.eatKeywords([]string{"NOT", "DEFERRABLE"}) {
294309
deferrable = false
295310
} else if sp.eatKeyword("DEFERRABLE") {
296311
deferrable = true
297312
} else {
298-
return status.Errorf(codes.InvalidArgument, "syntax error: expected [NOT] DEFERRABLE")
313+
return nil, nil, status.Errorf(codes.InvalidArgument, "syntax error: expected [NOT] DEFERRABLE")
299314
}
300315

301-
s.Identifiers = append(s.Identifiers, Identifier{Parts: []string{"transaction_deferrable"}})
302-
s.Literals = append(s.Literals, Literal{Value: fmt.Sprintf("%v", deferrable)})
303-
return nil
316+
identifiers = append(identifiers, Identifier{Parts: []string{"transaction_deferrable"}})
317+
literals = append(literals, Literal{Value: fmt.Sprintf("%v", deferrable)})
318+
return identifiers, literals, nil
304319
}
305320

306321
// ParsedResetStatement is a statement of the form
@@ -496,6 +511,12 @@ func (s *ParsedAbortBatchStatement) parse(parser *StatementParser, query string)
496511

497512
type ParsedBeginStatement struct {
498513
query string
514+
// Identifiers contains the transaction properties that were included in the BEGIN statement. E.g. the statement
515+
// BEGIN TRANSACTION READ ONLY contains the transaction property 'transaction_read_only'.
516+
Identifiers []Identifier
517+
// Literals contains the transaction property values that were included in the BEGIN statement. E.g. the statement
518+
// BEGIN TRANSACTION READ ONLY contains the value 'true' for the property 'transaction_read_only'.
519+
Literals []Literal
499520
}
500521

501522
func (s *ParsedBeginStatement) Name() string {
@@ -508,7 +529,7 @@ func (s *ParsedBeginStatement) Query() string {
508529

509530
func (s *ParsedBeginStatement) parse(parser *StatementParser, query string) error {
510531
// Parse a statement of the form
511-
// GoogleSQL: BEGIN [TRANSACTION]
532+
// GoogleSQL: BEGIN [TRANSACTION] [READ WRITE | READ ONLY | ISOLATION LEVEL {SERIALIZABLE | READ COMMITTED}]
512533
// PostgreSQL: {START | BEGIN} [{TRANSACTION | WORK}] (https://www.postgresql.org/docs/current/sql-begin.html)
513534
// TODO: Support transaction modes in the BEGIN / START statement.
514535
sp := &simpleParser{sql: []byte(query), statementParser: parser}
@@ -531,8 +552,13 @@ func (s *ParsedBeginStatement) parse(parser *StatementParser, query string) erro
531552
}
532553

533554
if sp.hasMoreTokens() {
534-
return status.Errorf(codes.InvalidArgument, "unexpected tokens at position %d in %q", sp.pos, sp.sql)
555+
var err error
556+
s.Identifiers, s.Literals, err = parseTransactionOptions(sp)
557+
if err != nil {
558+
return err
559+
}
535560
}
561+
536562
s.query = query
537563
return nil
538564
}

parser/statements_test.go

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,38 @@ func TestParseBeginStatementGoogleSQL(t *testing.T) {
431431
input: "begin transaction foo",
432432
wantErr: true,
433433
},
434+
{
435+
input: "begin read only",
436+
want: ParsedBeginStatement{
437+
query: "begin read only",
438+
Identifiers: []Identifier{{Parts: []string{"transaction_read_only"}}},
439+
Literals: []Literal{{Value: "true"}},
440+
},
441+
},
442+
{
443+
input: "begin read write",
444+
want: ParsedBeginStatement{
445+
query: "begin read write",
446+
Identifiers: []Identifier{{Parts: []string{"transaction_read_only"}}},
447+
Literals: []Literal{{Value: "false"}},
448+
},
449+
},
450+
{
451+
input: "begin transaction isolation level serializable",
452+
want: ParsedBeginStatement{
453+
query: "begin transaction isolation level serializable",
454+
Identifiers: []Identifier{{Parts: []string{"isolation_level"}}},
455+
Literals: []Literal{{Value: "serializable"}},
456+
},
457+
},
458+
{
459+
input: "begin transaction isolation level repeatable read, read write",
460+
want: ParsedBeginStatement{
461+
query: "begin transaction isolation level repeatable read, read write",
462+
Identifiers: []Identifier{{Parts: []string{"isolation_level"}}, {Parts: []string{"transaction_read_only"}}},
463+
Literals: []Literal{{Value: "repeatable_read"}, {Value: "false"}},
464+
},
465+
},
434466
}
435467
parser, err := NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
436468
if err != nil {
@@ -454,7 +486,7 @@ func TestParseBeginStatementGoogleSQL(t *testing.T) {
454486
t.Fatalf("parseStatement(%q) should have returned a *parsedBeginStatement", test.input)
455487
}
456488
if !reflect.DeepEqual(*showStmt, test.want) {
457-
t.Errorf("parseStatement(%q) = %v, want %v", test.input, *showStmt, test.want)
489+
t.Errorf("parseStatement(%q) mismatch\n Got: %v\nWant: %v", test.input, *showStmt, test.want)
458490
}
459491
}
460492
})
@@ -506,6 +538,56 @@ func TestParseBeginStatementPostgreSQL(t *testing.T) {
506538
query: "start work",
507539
},
508540
},
541+
{
542+
input: "start work read only",
543+
want: ParsedBeginStatement{
544+
query: "start work read only",
545+
Identifiers: []Identifier{{Parts: []string{"transaction_read_only"}}},
546+
Literals: []Literal{{Value: "true"}},
547+
},
548+
},
549+
{
550+
input: "begin read write",
551+
want: ParsedBeginStatement{
552+
query: "begin read write",
553+
Identifiers: []Identifier{{Parts: []string{"transaction_read_only"}}},
554+
Literals: []Literal{{Value: "false"}},
555+
},
556+
},
557+
{
558+
input: "begin read write, isolation level repeatable read",
559+
want: ParsedBeginStatement{
560+
query: "begin read write, isolation level repeatable read",
561+
Identifiers: []Identifier{{Parts: []string{"transaction_read_only"}}, {Parts: []string{"isolation_level"}}},
562+
Literals: []Literal{{Value: "false"}, {Value: "repeatable_read"}},
563+
},
564+
},
565+
{
566+
// Note that it is possible to set multiple conflicting transaction options in one statement.
567+
// This statement for example sets the transaction to both read/write and read-only.
568+
// The last option will take precedence, as these options are essentially the same as executing the
569+
// following statements sequentially after the BEGIN TRANSACTION statement:
570+
// SET TRANSACTION READ WRITE
571+
// SET TRANSACTION ISOLATION LEVEL REPEATABLE READ
572+
// SET TRANSACTION READ ONLY
573+
// SET TRANSACTION DEFERRABLE
574+
input: "begin transaction \nread write,\nisolation level repeatable read\nread only\ndeferrable",
575+
want: ParsedBeginStatement{
576+
query: "begin transaction \nread write,\nisolation level repeatable read\nread only\ndeferrable",
577+
Identifiers: []Identifier{
578+
{Parts: []string{"transaction_read_only"}},
579+
{Parts: []string{"isolation_level"}},
580+
{Parts: []string{"transaction_read_only"}},
581+
{Parts: []string{"transaction_deferrable"}},
582+
},
583+
Literals: []Literal{
584+
{Value: "false"},
585+
{Value: "repeatable_read"},
586+
{Value: "true"},
587+
{Value: "true"},
588+
},
589+
},
590+
},
509591
{
510592
input: "start foo",
511593
wantErr: true,
@@ -541,7 +623,7 @@ func TestParseBeginStatementPostgreSQL(t *testing.T) {
541623
t.Fatalf("parseStatement(%q) should have returned a *parsedBeginStatement", test.input)
542624
}
543625
if !reflect.DeepEqual(*showStmt, test.want) {
544-
t.Errorf("parseStatement(%q) = %v, want %v", test.input, *showStmt, test.want)
626+
t.Errorf("parseStatement(%q) mismatch\n Got: %v\nWant: %v", test.input, *showStmt, test.want)
545627
}
546628
}
547629
})

statements.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,19 @@ type executableBeginStatement struct {
279279
}
280280

281281
func (s *executableBeginStatement) execContext(ctx context.Context, c *conn, opts *ExecOptions) (driver.Result, error) {
282+
if len(s.stmt.Identifiers) != len(s.stmt.Literals) {
283+
return nil, status.Errorf(codes.InvalidArgument, "statement contains %d identifiers, but %d values given", len(s.stmt.Identifiers), len(s.stmt.Literals))
284+
}
282285
_, err := c.BeginTx(ctx, driver.TxOptions{})
283286
if err != nil {
284287
return nil, err
285288
}
289+
for index := range s.stmt.Identifiers {
290+
if err := c.setConnectionVariable(s.stmt.Identifiers[index], s.stmt.Literals[index].Value /*IsLocal=*/, true /*IsTransaction=*/, true); err != nil {
291+
return nil, err
292+
}
293+
}
294+
286295
return driver.ResultNoRows, nil
287296
}
288297

0 commit comments

Comments
 (0)