From 3a7b8b8f0b1eaeee787fc38b39c2ac8730e9b9b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Wed, 5 Nov 2025 13:41:08 +0100 Subject: [PATCH] chore: add functions for parsing statement hint values Adds functions for parsing statement hint values from SQL strings. These functions are currently not used by the driver, will eventually be used to allow SQL strings to contain hints that sets connection variables only for the duration of the execution of a statement. That is, the following will eventually be supported: ```sql @{ statement_tag='my_tag', statement_timeout='100ms' } select * from my_table where key=@key ``` The given statement_tag and statement_timeout will only be applied to the current statement. --- parser/simple_parser.go | 65 ++++++-- parser/statement_parser.go | 71 ++++++++- parser/statement_parser_test.go | 259 +++++++++++++++++++++++++++++++- 3 files changed, 378 insertions(+), 17 deletions(-) diff --git a/parser/simple_parser.go b/parser/simple_parser.go index e8fef032..6127c729 100644 --- a/parser/simple_parser.go +++ b/parser/simple_parser.go @@ -19,6 +19,7 @@ import ( "unicode" "unicode/utf8" + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -43,16 +44,29 @@ func (p *simpleParser) hasMoreTokens() bool { // // Any whitespace and/or comments before or between the tokens will be skipped. func (p *simpleParser) eatTokens(tokens []byte) bool { + return p.eatTokensWithWhitespaceOption(tokens, true) +} + +// eatTokens advances the parser by len(tokens) positions and returns true +// if the next len(tokens) bytes are equal to the given bytes. This function +// only works for characters that can be encoded in one byte. +// +// Any whitespace and/or comments before or between the tokens will be interpreted as actual tokens. +func (p *simpleParser) eatTokensOnly(tokens []byte) bool { + return p.eatTokensWithWhitespaceOption(tokens, false) +} + +func (p *simpleParser) eatTokensWithWhitespaceOption(tokens []byte, eatWhitespaces bool) bool { if len(tokens) == 0 { return true } if len(tokens) == 1 { - return p.eatToken(tokens[0]) + return p.eatTokenWithWhitespaceOption(tokens[0], eatWhitespaces) } startPos := p.pos for _, t := range tokens { - if !p.eatToken(t) { + if !p.eatTokenWithWhitespaceOption(t, eatWhitespaces) { p.pos = startPos return false } @@ -369,25 +383,48 @@ func (p *simpleParser) readUnquotedLiteral() string { // position until it encounters a non-whitespace / non-comment. // The position of the parser is updated. func (p *simpleParser) skipWhitespacesAndComments() { - p.pos = p.statementParser.skipWhitespacesAndComments(p.sql, p.pos) + p.skipWhitespacesAndCommentsWithPgHintOption( /*skipPgHints = */ true) +} + +func (p *simpleParser) skipWhitespacesAndCommentsWithPgHintOption(skipPgHints bool) { + p.pos = p.statementParser.skipWhitespacesAndComments(p.sql, p.pos, skipPgHints) } -var statementHintPrefix = []byte{'@', '{'} +var googleSqlStatementHintPrefix = []byte{'@', '{'} +var postgreSqlStatementHintPrefix = []byte{'/', '*', '@'} // skipStatementHint skips any statement hint at the start of the statement. -func (p *simpleParser) skipStatementHint() bool { - if p.eatTokens(statementHintPrefix) { - for ; p.pos < len(p.sql); p.pos++ { - // We don't have to worry about an '}' being inside a statement hint - // key or value, as it is not a valid part of an identifier, and - // statement hints have a fixed set of possible values. - if p.sql[p.pos] == '}' { - p.pos++ - return true +// Returns true if a statement hint was actually found, and the start index of that statement hint. +func (p *simpleParser) skipStatementHint() (bool, int) { + if p.statementParser.Dialect == databasepb.DatabaseDialect_POSTGRESQL { + // Statement hints in PostgreSQL are encoded in comments of the following form: + // /*@ hint_key=hint_value[, hint_key2=hint_value2[,...]] */ + // Skip all other whitespaces and comments, but not comments that contain a PG hint. + p.skipWhitespacesAndCommentsWithPgHintOption( /*skipPgHints=*/ false) + // Check if the next tokens are a PG hint. + if len(p.sql) > p.pos+2 && p.sql[p.pos] == '/' && p.sql[p.pos+1] == '*' && p.sql[p.pos+2] == '@' { + startPos := p.pos + // Move to the end of this comment. + p.pos = p.statementParser.skipMultiLineComment(p.sql, p.pos) + // Note that this also returns true if the multiline comment is not terminated, and we just reached the + // end of the SQL string. + return true, startPos + } + } else { + startPos := p.pos + if p.eatTokens(googleSqlStatementHintPrefix) { + for ; p.pos < len(p.sql); p.pos++ { + // We don't have to worry about an '}' being inside a statement hint + // key or value, as it is not a valid part of an identifier, and + // statement hints have a fixed set of possible values. + if p.sql[p.pos] == '}' { + p.pos++ + return true, startPos + } } } } - return false + return false, 0 } // isMultibyte returns true if the character at the current position diff --git a/parser/statement_parser.go b/parser/statement_parser.go index d86d8295..686f44e5 100644 --- a/parser/statement_parser.go +++ b/parser/statement_parser.go @@ -279,7 +279,11 @@ func (p *StatementParser) ParseParameters(sql string) (string, []string, error) // Skips all whitespaces from the given position and returns the // position of the next non-whitespace character or len(sql) if // the string does not contain any whitespaces after pos. -func (p *StatementParser) skipWhitespacesAndComments(sql []byte, pos int) int { +// +// PostgreSQL hints are encoded as comments in the following form: +// /*@ hint_key=hint_value[, hint_key2=hint_value2[,...]] */ +// The skipPgHints argument indicates whether those comments should also be skipped or not. +func (p *StatementParser) skipWhitespacesAndComments(sql []byte, pos int, skipPgHints bool) int { for pos < len(sql) { c := sql[pos] if isMultibyte(c) { @@ -293,6 +297,10 @@ func (p *StatementParser) skipWhitespacesAndComments(sql []byte, pos int) int { pos = p.skipSingleLineComment(sql, pos+1) } else if c == '/' && len(sql) > pos+1 && sql[pos+1] == '*' { // This is a multi line comment starting with '/*'. + if !skipPgHints && len(sql) > pos+2 && sql[pos+2] == '@' { + // This is a PostgreSQL hint, and we should not skip it. + break + } pos = p.skipMultiLineComment(sql, pos) } else if !isSpace(c) { break @@ -745,7 +753,7 @@ func (p *StatementParser) DetectStatementType(sql string) *StatementInfo { func (p *StatementParser) calculateDetectStatementType(sql string) *StatementInfo { parser := &simpleParser{sql: []byte(sql), statementParser: p} - _ = parser.skipStatementHint() + _, _ = parser.skipStatementHint() keyword := strings.ToUpper(parser.readKeyword()) if isQueryKeyword(keyword) { return &StatementInfo{StatementType: StatementTypeQuery} @@ -770,3 +778,62 @@ func detectDmlKeyword(keyword string) DmlType { } return DmlTypeUnknown } + +func (p *StatementParser) extractSetStatementsFromHints(sql string) (*ParsedSetStatement, error) { + sp := &simpleParser{sql: []byte(sql), statementParser: p} + if ok, startPos := sp.skipStatementHint(); ok { + // Mark the start and end of the statement hint and extract the values in the hint. + endPos := sp.pos + sp.pos = startPos + if p.Dialect == databasepb.DatabaseDialect_POSTGRESQL { + // eatTokensOnly will only look for the following character sequence: '/*@' + // It will not interpret it as a comment. + sp.eatTokensOnly(postgreSqlStatementHintPrefix) + } else { + sp.eatTokens(googleSqlStatementHintPrefix) + } + // The default is that the hint ends with a single '}'. + endIndex := endPos - 1 + if p.Dialect == databasepb.DatabaseDialect_POSTGRESQL { + // The hint ends with '*/' + endIndex = endPos - 2 + } + if endIndex > sp.pos && endIndex < len(sql) { + return p.extractConnectionVariables(sql[sp.pos:endIndex]) + } + } + return nil, nil +} + +func (p *StatementParser) extractConnectionVariables(sql string) (*ParsedSetStatement, error) { + sp := &simpleParser{sql: []byte(sql), statementParser: p} + statement := &ParsedSetStatement{ + Identifiers: make([]Identifier, 0, 2), + Literals: make([]Literal, 0, 2), + } + for { + if !sp.hasMoreTokens() { + break + } + identifier, err := sp.eatIdentifier() + if err != nil { + return nil, err + } + if !sp.eatToken('=') { + return nil, status.Errorf(codes.InvalidArgument, "missing '=' token after %s in hint", identifier) + } + literal, err := sp.eatLiteral() + if err != nil { + return nil, err + } + statement.Identifiers = append(statement.Identifiers, identifier) + statement.Literals = append(statement.Literals, literal) + if !sp.eatToken(',') { + break + } + } + if sp.hasMoreTokens() { + return nil, status.Errorf(codes.InvalidArgument, "unexpected tokens: %s", string(sp.sql[sp.pos:])) + } + return statement, nil +} diff --git a/parser/statement_parser_test.go b/parser/statement_parser_test.go index 2ab64399..eef2ac58 100644 --- a/parser/statement_parser_test.go +++ b/parser/statement_parser_test.go @@ -21,6 +21,7 @@ import ( "cloud.google.com/go/spanner" "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -1744,7 +1745,7 @@ func TestSkipWhitespaces(t *testing.T) { } for _, test := range tests { t.Run(fmt.Sprintf("%s: %s", dialect, test.name), func(t *testing.T) { - pos := p.skipWhitespacesAndComments([]byte(test.input), 0) + pos := p.skipWhitespacesAndComments([]byte(test.input), 0 /*skipPgHints=*/, true) if g, w := test.input[:pos], test.want; g != w { t.Errorf("skip whitespace mismatch\n Got: %q\nWant: %q", g, w) } @@ -2634,6 +2635,262 @@ func TestEatIdentifier(t *testing.T) { } } +func TestExtractSetStatementsFromHints(t *testing.T) { + t.Parallel() + + parser, err := NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000) + if err != nil { + t.Fatal(err) + } + tests := []struct { + input string + want *ParsedSetStatement + wantErr bool + }{ + { + input: "@{foo='bar'} select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"foo"}}}, + Literals: []Literal{{"bar"}}, + }, + }, + { + input: "/* comment */ @{foo='bar'} select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"foo"}}}, + Literals: []Literal{{"bar"}}, + }, + }, + { + input: "-- comment \n @{foo='bar'} select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"foo"}}}, + Literals: []Literal{{"bar"}}, + }, + }, + { + input: "@{key1='value1', key2='value2'} select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"key1"}}, {Parts: []string{"key2"}}}, + Literals: []Literal{{"value1"}, {"value2"}}, + }, + }, + { + input: "@{ int_key= 5, string_key =\n 'test'} select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"int_key"}}, {Parts: []string{"string_key"}}}, + Literals: []Literal{{"5"}, {"test"}}, + }, + }, + { + input: "@{ foo = 'bar', } select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"foo"}}}, + Literals: []Literal{{"bar"}}, + }, + }, + { + input: "@{ } select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{}, + Literals: []Literal{}, + }, + }, + { + input: "select * from my_table", + want: nil, + }, + { + input: "@{ foo is 'bar' } select * from my_table", + wantErr: true, + }, + { + input: "@{ foo == 'bar' } select * from my_table", + wantErr: true, + }, + { + input: "@{ foo 'bar' } select * from my_table", + wantErr: true, + }, + { + input: "@{ 'bar' } select * from my_table", + wantErr: true, + }, + { + input: "@{, foo='bar'} select * from my_table", + wantErr: true, + }, + { + input: "@{foo1='bar1' foo2='bar2'} select * from my_table", + wantErr: true, + }, + { + input: "@{'foo'='bar'} select * from my_table", + wantErr: true, + }, + { + // Quoted tokens are not really supported in normal statement hints, + // but the local parser accepts it for connection variables. + input: "@{`foo`='bar'} select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"foo"}}}, + Literals: []Literal{{"bar"}}, + }, + }, + { + // Note the misplaced backtick AFTER the '='. + input: "@{`foo=`'bar'} select * from my_table", + wantErr: true, + }, + } + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + statement, err := parser.extractSetStatementsFromHints(test.input) + if test.wantErr { + if err == nil { + t.Fatal("missing expected error") + } + } else { + if err != nil { + t.Fatal(err) + } + opts := cmpopts.IgnoreUnexported(ParsedSetStatement{}) + if !cmp.Equal(statement, test.want, opts) { + t.Fatalf("mismatch (-want +got):\n%s", cmp.Diff(test.want, statement, opts)) + } + } + }) + } +} + +func TestExtractSetStatementsFromHintsPostgreSQL(t *testing.T) { + t.Parallel() + + parser, err := NewStatementParser(databasepb.DatabaseDialect_POSTGRESQL, 1000) + if err != nil { + t.Fatal(err) + } + tests := []struct { + input string + want *ParsedSetStatement + wantErr bool + }{ + { + input: "/*@foo='bar'*/ select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"foo"}}}, + Literals: []Literal{{"bar"}}, + }, + }, + { + input: "/* comment */ /*@foo='bar'*/ select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"foo"}}}, + Literals: []Literal{{"bar"}}, + }, + }, + { + input: "-- comment \n /*@foo='bar'*/ select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"foo"}}}, + Literals: []Literal{{"bar"}}, + }, + }, + { + input: "/*@key1='value1', key2='value2'*/ select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"key1"}}, {Parts: []string{"key2"}}}, + Literals: []Literal{{"value1"}, {"value2"}}, + }, + }, + { + input: "/*@ int_key= 5, string_key =\n 'test'*/ select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"int_key"}}, {Parts: []string{"string_key"}}}, + Literals: []Literal{{"5"}, {"test"}}, + }, + }, + { + input: "/*@ foo = 'bar', */ select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"foo"}}}, + Literals: []Literal{{"bar"}}, + }, + }, + { + input: "/*@ */ select * from my_table", + want: &ParsedSetStatement{ + Identifiers: []Identifier{}, + Literals: []Literal{}, + }, + }, + { + input: "select * from my_table", + want: nil, + }, + { + input: "/*@ foo is 'bar' */ select * from my_table", + wantErr: true, + }, + { + input: "/*@ foo == 'bar' */ select * from my_table", + wantErr: true, + }, + { + input: "/*@ foo 'bar' */ select * from my_table", + wantErr: true, + }, + { + input: "/*@ 'bar' */ select * from my_table", + wantErr: true, + }, + { + input: "/*@, foo='bar'*/ select * from my_table", + wantErr: true, + }, + { + input: "/*@foo1='bar1' foo2='bar2'*/ select * from my_table", + wantErr: true, + }, + { + input: "/*@'foo'='bar'/* select * from my_table", + wantErr: true, + }, + { + // Quoted tokens are not really supported in normal statement hints, + // but the local parser accepts it for connection variables. + input: `/*@"foo"='bar'*/ select * from my_table`, + want: &ParsedSetStatement{ + Identifiers: []Identifier{{Parts: []string{"foo"}}}, + Literals: []Literal{{"bar"}}, + }, + }, + { + // Note the misplaced backtick AFTER the '='. + input: "/*@`foo=`'bar'*/ select * from my_table", + wantErr: true, + }, + } + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + statement, err := parser.extractSetStatementsFromHints(test.input) + if test.wantErr { + if err == nil { + t.Fatal("missing expected error") + } + } else { + if err != nil { + t.Fatal(err) + } + opts := cmpopts.IgnoreUnexported(ParsedSetStatement{}) + if !cmp.Equal(statement, test.want, opts) { + t.Fatalf("mismatch (-want +got):\n%s", cmp.Diff(test.want, statement, opts)) + } + } + }) + } +} + func BenchmarkDetectStatementTypeWithoutCache(b *testing.B) { parser, err := NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 0) if err != nil {