diff --git a/parser/simple_parser.go b/parser/simple_parser.go index e8fef032..6127c729 100644 --- a/parser/simple_parser.go +++ b/parser/simple_parser.go @@ -19,6 +19,7 @@ import ( "unicode" "unicode/utf8" + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -43,16 +44,29 @@ func (p *simpleParser) hasMoreTokens() bool { // // Any whitespace and/or comments before or between the tokens will be skipped. func (p *simpleParser) eatTokens(tokens []byte) bool { + return p.eatTokensWithWhitespaceOption(tokens, true) +} + +// eatTokensOnly advances the parser by len(tokens) positions and returns true +// if the next len(tokens) bytes are equal to the given bytes. This function +// only works for characters that can be encoded in one byte. +// +// Any whitespace and/or comments before or between the tokens will be interpreted as actual tokens. +func (p *simpleParser) eatTokensOnly(tokens []byte) bool { + return p.eatTokensWithWhitespaceOption(tokens, false) +} + +func (p *simpleParser) eatTokensWithWhitespaceOption(tokens []byte, eatWhitespaces bool) bool { if len(tokens) == 0 { return true } if len(tokens) == 1 { - return p.eatToken(tokens[0]) + return p.eatTokenWithWhitespaceOption(tokens[0], eatWhitespaces) } startPos := p.pos for _, t := range tokens { - if !p.eatToken(t) { + if !p.eatTokenWithWhitespaceOption(t, eatWhitespaces) { p.pos = startPos return false } @@ -369,25 +383,48 @@ func (p *simpleParser) readUnquotedLiteral() string { // position until it encounters a non-whitespace / non-comment. // The position of the parser is updated. 
func (p *simpleParser) skipWhitespacesAndComments() { - p.pos = p.statementParser.skipWhitespacesAndComments(p.sql, p.pos) + p.skipWhitespacesAndCommentsWithPgHintOption( /*skipPgHints = */ true) +} + +func (p *simpleParser) skipWhitespacesAndCommentsWithPgHintOption(skipPgHints bool) { + p.pos = p.statementParser.skipWhitespacesAndComments(p.sql, p.pos, skipPgHints) } -var statementHintPrefix = []byte{'@', '{'} +var googleSqlStatementHintPrefix = []byte{'@', '{'} +var postgreSqlStatementHintPrefix = []byte{'/', '*', '@'} // skipStatementHint skips any statement hint at the start of the statement. -func (p *simpleParser) skipStatementHint() bool { - if p.eatTokens(statementHintPrefix) { - for ; p.pos < len(p.sql); p.pos++ { - // We don't have to worry about an '}' being inside a statement hint - // key or value, as it is not a valid part of an identifier, and - // statement hints have a fixed set of possible values. - if p.sql[p.pos] == '}' { - p.pos++ - return true +// Returns true if a statement hint was actually found, and the start index of that statement hint. +func (p *simpleParser) skipStatementHint() (bool, int) { + if p.statementParser.Dialect == databasepb.DatabaseDialect_POSTGRESQL { + // Statement hints in PostgreSQL are encoded in comments of the following form: + // /*@ hint_key=hint_value[, hint_key2=hint_value2[,...]] */ + // Skip all other whitespaces and comments, but not comments that contain a PG hint. + p.skipWhitespacesAndCommentsWithPgHintOption( /*skipPgHints=*/ false) + // Check if the next tokens are a PG hint. + if len(p.sql) > p.pos+2 && p.sql[p.pos] == '/' && p.sql[p.pos+1] == '*' && p.sql[p.pos+2] == '@' { + startPos := p.pos + // Move to the end of this comment. + p.pos = p.statementParser.skipMultiLineComment(p.sql, p.pos) + // Note that this also returns true if the multiline comment is not terminated, and we just reached the + // end of the SQL string. 
+ return true, startPos + } + } else { + startPos := p.pos + if p.eatTokens(googleSqlStatementHintPrefix) { + for ; p.pos < len(p.sql); p.pos++ { + // We don't have to worry about an '}' being inside a statement hint + // key or value, as it is not a valid part of an identifier, and + // statement hints have a fixed set of possible values. + if p.sql[p.pos] == '}' { + p.pos++ + return true, startPos + } } } } - return false + return false, 0 } // isMultibyte returns true if the character at the current position diff --git a/parser/statement_parser.go b/parser/statement_parser.go index d86d8295..686f44e5 100644 --- a/parser/statement_parser.go +++ b/parser/statement_parser.go @@ -279,7 +279,11 @@ func (p *StatementParser) ParseParameters(sql string) (string, []string, error) // Skips all whitespaces from the given position and returns the // position of the next non-whitespace character or len(sql) if // the string does not contain any whitespaces after pos. -func (p *StatementParser) skipWhitespacesAndComments(sql []byte, pos int) int { +// +// PostgreSQL hints are encoded as comments in the following form: +// /*@ hint_key=hint_value[, hint_key2=hint_value2[,...]] */ +// The skipPgHints argument indicates whether those comments should also be skipped or not. +func (p *StatementParser) skipWhitespacesAndComments(sql []byte, pos int, skipPgHints bool) int { for pos < len(sql) { c := sql[pos] if isMultibyte(c) { @@ -293,6 +297,10 @@ func (p *StatementParser) skipWhitespacesAndComments(sql []byte, pos int) int { pos = p.skipSingleLineComment(sql, pos+1) } else if c == '/' && len(sql) > pos+1 && sql[pos+1] == '*' { // This is a multi line comment starting with '/*'. + if !skipPgHints && len(sql) > pos+2 && sql[pos+2] == '@' { + // This is a PostgreSQL hint, and we should not skip it. 
+ break + } pos = p.skipMultiLineComment(sql, pos) } else if !isSpace(c) { break @@ -745,7 +753,7 @@ func (p *StatementParser) DetectStatementType(sql string) *StatementInfo { func (p *StatementParser) calculateDetectStatementType(sql string) *StatementInfo { parser := &simpleParser{sql: []byte(sql), statementParser: p} - _ = parser.skipStatementHint() + _, _ = parser.skipStatementHint() keyword := strings.ToUpper(parser.readKeyword()) if isQueryKeyword(keyword) { return &StatementInfo{StatementType: StatementTypeQuery} @@ -770,3 +778,62 @@ func detectDmlKeyword(keyword string) DmlType { } return DmlTypeUnknown } + +func (p *StatementParser) extractSetStatementsFromHints(sql string) (*ParsedSetStatement, error) { + sp := &simpleParser{sql: []byte(sql), statementParser: p} + if ok, startPos := sp.skipStatementHint(); ok { + // Mark the start and end of the statement hint and extract the values in the hint. + endPos := sp.pos + sp.pos = startPos + if p.Dialect == databasepb.DatabaseDialect_POSTGRESQL { + // eatTokensOnly will only look for the following character sequence: '/*@' + // It will not interpret it as a comment. + sp.eatTokensOnly(postgreSqlStatementHintPrefix) + } else { + sp.eatTokens(googleSqlStatementHintPrefix) + } + // The default is that the hint ends with a single '}'. 
+ endIndex := endPos - 1 + if p.Dialect == databasepb.DatabaseDialect_POSTGRESQL { + // The hint ends with '*/' + endIndex = endPos - 2 + } + if endIndex > sp.pos && endIndex < len(sql) { + return p.extractConnectionVariables(sql[sp.pos:endIndex]) + } + } + return nil, nil +} + +func (p *StatementParser) extractConnectionVariables(sql string) (*ParsedSetStatement, error) { + sp := &simpleParser{sql: []byte(sql), statementParser: p} + statement := &ParsedSetStatement{ + Identifiers: make([]Identifier, 0, 2), + Literals: make([]Literal, 0, 2), + } + for { + if !sp.hasMoreTokens() { + break + } + identifier, err := sp.eatIdentifier() + if err != nil { + return nil, err + } + if !sp.eatToken('=') { + return nil, status.Errorf(codes.InvalidArgument, "missing '=' token after %s in hint", identifier) + } + literal, err := sp.eatLiteral() + if err != nil { + return nil, err + } + statement.Identifiers = append(statement.Identifiers, identifier) + statement.Literals = append(statement.Literals, literal) + if !sp.eatToken(',') { + break + } + } + if sp.hasMoreTokens() { + return nil, status.Errorf(codes.InvalidArgument, "unexpected tokens: %s", string(sp.sql[sp.pos:])) + } + return statement, nil +} diff --git a/parser/statement_parser_test.go b/parser/statement_parser_test.go index 2ab64399..eef2ac58 100644 --- a/parser/statement_parser_test.go +++ b/parser/statement_parser_test.go @@ -21,6 +21,7 @@ import ( "cloud.google.com/go/spanner" "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -1744,7 +1745,7 @@ func TestSkipWhitespaces(t *testing.T) { } for _, test := range tests { t.Run(fmt.Sprintf("%s: %s", dialect, test.name), func(t *testing.T) { - pos := p.skipWhitespacesAndComments([]byte(test.input), 0) + pos := p.skipWhitespacesAndComments([]byte(test.input), 0 /*skipPgHints=*/, true) if g, w := test.input[:pos], 
test.want; g != w { t.Errorf("skip whitespace mismatch\n Got: %q\nWant: %q", g, w) } @@ -2634,6 +2635,262 @@ func TestEatIdentifier(t *testing.T) { } } +func TestExtractSetStatementsFromHints(t *testing.T) { + t.Parallel() + + parser, err := NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000) + if err != nil { + t.Fatal(err) + } + tests := []struct { + input string + want *ParsedSetStatement + wantErr bool + }{ + { + input: "@{foo='bar'} select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"foo"}}}, + Literals: []Literal{{"bar"}}, + }, + }, + { + input: "/* comment */ @{foo='bar'} select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"foo"}}}, + Literals: []Literal{{"bar"}}, + }, + }, + { + input: "-- comment \n @{foo='bar'} select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"foo"}}}, + Literals: []Literal{{"bar"}}, + }, + }, + { + input: "@{key1='value1', key2='value2'} select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"key1"}}, {Parts: []string{"key2"}}}, + Literals: []Literal{{"value1"}, {"value2"}}, + }, + }, + { + input: "@{ int_key= 5, string_key =\n 'test'} select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"int_key"}}, {Parts: []string{"string_key"}}}, + Literals: []Literal{{"5"}, {"test"}}, + }, + }, + { + input: "@{ foo = 'bar', } select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"foo"}}}, + Literals: []Literal{{"bar"}}, + }, + }, + { + input: "@{ } select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{}, + Literals: []Literal{}, + }, + }, + { + input: "select * from my_table", + want: nil, + }, + { + input: "@{ foo is 'bar' } select * from my_table", + wantErr: true, + }, + { + input: "@{ foo == 'bar' } select * from 
my_table", + wantErr: true, + }, + { + input: "@{ foo 'bar' } select * from my_table", + wantErr: true, + }, + { + input: "@{ 'bar' } select * from my_table", + wantErr: true, + }, + { + input: "@{, foo='bar'} select * from my_table", + wantErr: true, + }, + { + input: "@{foo1='bar1' foo2='bar2'} select * from my_table", + wantErr: true, + }, + { + input: "@{'foo'='bar'} select * from my_table", + wantErr: true, + }, + { + // Quoted tokens are not really supported in normal statement hints, + // but the local parser accepts it for connection variables. + input: "@{`foo`='bar'} select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"foo"}}}, + Literals: []Literal{{"bar"}}, + }, + }, + { + // Note the misplaced backtick AFTER the '='. + input: "@{`foo=`'bar'} select * from my_table", + wantErr: true, + }, + } + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + statement, err := parser.extractSetStatementsFromHints(test.input) + if test.wantErr { + if err == nil { + t.Fatal("missing expected error") + } + } else { + if err != nil { + t.Fatal(err) + } + opts := cmpopts.IgnoreUnexported(ParsedSetStatement{}) + if !cmp.Equal(statement, test.want, opts) { + t.Fatalf("mismatch (-want +got):\n%s", cmp.Diff(test.want, statement, opts)) + } + } + }) + } +} + +func TestExtractSetStatementsFromHintsPostgreSQL(t *testing.T) { + t.Parallel() + + parser, err := NewStatementParser(databasepb.DatabaseDialect_POSTGRESQL, 1000) + if err != nil { + t.Fatal(err) + } + tests := []struct { + input string + want *ParsedSetStatement + wantErr bool + }{ + { + input: "/*@foo='bar'*/ select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"foo"}}}, + Literals: []Literal{{"bar"}}, + }, + }, + { + input: "/* comment */ /*@foo='bar'*/ select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"foo"}}}, + Literals: []Literal{{"bar"}}, + }, + }, + { + 
input: "-- comment \n /*@foo='bar'*/ select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"foo"}}}, + Literals: []Literal{{"bar"}}, + }, + }, + { + input: "/*@key1='value1', key2='value2'*/ select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"key1"}}, {Parts: []string{"key2"}}}, + Literals: []Literal{{"value1"}, {"value2"}}, + }, + }, + { + input: "/*@ int_key= 5, string_key =\n 'test'*/ select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"int_key"}}, {Parts: []string{"string_key"}}}, + Literals: []Literal{{"5"}, {"test"}}, + }, + }, + { + input: "/*@ foo = 'bar', */ select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"foo"}}}, + Literals: []Literal{{"bar"}}, + }, + }, + { + input: "/*@ */ select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{}, + Literals: []Literal{}, + }, + }, + { + input: "select * from my_table", + want: nil, + }, + { + input: "/*@ foo is 'bar' */ select * from my_table", + wantErr: true, + }, + { + input: "/*@ foo == 'bar' */ select * from my_table", + wantErr: true, + }, + { + input: "/*@ foo 'bar' */ select * from my_table", + wantErr: true, + }, + { + input: "/*@ 'bar' */ select * from my_table", + wantErr: true, + }, + { + input: "/*@, foo='bar'*/ select * from my_table", + wantErr: true, + }, + { + input: "/*@foo1='bar1' foo2='bar2'*/ select * from my_table", + wantErr: true, + }, + { + input: "/*@'foo'='bar'*/ select * from my_table", + wantErr: true, + }, + { + // Quoted tokens are not really supported in normal statement hints, + // but the local parser accepts it for connection variables. + input: `/*@"foo"='bar'*/ select * from my_table`, + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"foo"}}}, + Literals: []Literal{{"bar"}}, + }, + }, + { + // Note the misplaced backtick AFTER the '='. 
+ input: "/*@`foo=`'bar'*/ select * from my_table", + wantErr: true, + }, + } + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + statement, err := parser.extractSetStatementsFromHints(test.input) + if test.wantErr { + if err == nil { + t.Fatal("missing expected error") + } + } else { + if err != nil { + t.Fatal(err) + } + opts := cmpopts.IgnoreUnexported(ParsedSetStatement{}) + if !cmp.Equal(statement, test.want, opts) { + t.Fatalf("mismatch (-want +got):\n%s", cmp.Diff(test.want, statement, opts)) + } + } + }) + } +} + func BenchmarkDetectStatementTypeWithoutCache(b *testing.B) { parser, err := NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 0) if err != nil {