Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 51 additions & 14 deletions parser/simple_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"unicode"
"unicode/utf8"

"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
Expand All @@ -43,16 +44,29 @@ func (p *simpleParser) hasMoreTokens() bool {
//
// Any whitespace and/or comments before or between the tokens will be skipped.
func (p *simpleParser) eatTokens(tokens []byte) bool {
return p.eatTokensWithWhitespaceOption(tokens, true)
}

// eatTokens advances the parser by len(tokens) positions and returns true
// if the next len(tokens) bytes are equal to the given bytes. This function
// only works for characters that can be encoded in one byte.
//
// Any whitespace and/or comments before or between the tokens will be interpreted as actual tokens.
func (p *simpleParser) eatTokensOnly(tokens []byte) bool {
return p.eatTokensWithWhitespaceOption(tokens, false)
}

func (p *simpleParser) eatTokensWithWhitespaceOption(tokens []byte, eatWhitespaces bool) bool {
if len(tokens) == 0 {
return true
}
if len(tokens) == 1 {
return p.eatToken(tokens[0])
return p.eatTokenWithWhitespaceOption(tokens[0], eatWhitespaces)
}

startPos := p.pos
for _, t := range tokens {
if !p.eatToken(t) {
if !p.eatTokenWithWhitespaceOption(t, eatWhitespaces) {
p.pos = startPos
return false
}
Expand Down Expand Up @@ -369,25 +383,48 @@ func (p *simpleParser) readUnquotedLiteral() string {
// position until it encounters a non-whitespace / non-comment.
// The position of the parser is updated.
func (p *simpleParser) skipWhitespacesAndComments() {
p.pos = p.statementParser.skipWhitespacesAndComments(p.sql, p.pos)
p.skipWhitespacesAndCommentsWithPgHintOption( /*skipPgHints = */ true)
}

func (p *simpleParser) skipWhitespacesAndCommentsWithPgHintOption(skipPgHints bool) {
p.pos = p.statementParser.skipWhitespacesAndComments(p.sql, p.pos, skipPgHints)
}

var statementHintPrefix = []byte{'@', '{'}
var googleSqlStatementHintPrefix = []byte{'@', '{'}
var postgreSqlStatementHintPrefix = []byte{'/', '*', '@'}

// skipStatementHint skips any statement hint at the start of the statement.
func (p *simpleParser) skipStatementHint() bool {
if p.eatTokens(statementHintPrefix) {
for ; p.pos < len(p.sql); p.pos++ {
// We don't have to worry about an '}' being inside a statement hint
// key or value, as it is not a valid part of an identifier, and
// statement hints have a fixed set of possible values.
if p.sql[p.pos] == '}' {
p.pos++
return true
// Returns true if a statement hint was actually found, and the start index of that statement hint.
func (p *simpleParser) skipStatementHint() (bool, int) {
if p.statementParser.Dialect == databasepb.DatabaseDialect_POSTGRESQL {
// Statement hints in PostgreSQL are encoded in comments of the following form:
// /*@ hint_key=hint_value[, hint_key2=hint_value2[,...]] */
// Skip all other whitespaces and comments, but not comments that contain a PG hint.
p.skipWhitespacesAndCommentsWithPgHintOption( /*skipPgHints=*/ false)
// Check if the next tokens are a PG hint.
if len(p.sql) > p.pos+2 && p.sql[p.pos] == '/' && p.sql[p.pos+1] == '*' && p.sql[p.pos+2] == '@' {
startPos := p.pos
// Move to the end of this comment.
p.pos = p.statementParser.skipMultiLineComment(p.sql, p.pos)
// Note that this also returns true if the multiline comment is not terminated, and we just reached the
// end of the SQL string.
return true, startPos
}
} else {
startPos := p.pos
if p.eatTokens(googleSqlStatementHintPrefix) {
for ; p.pos < len(p.sql); p.pos++ {
// We don't have to worry about an '}' being inside a statement hint
// key or value, as it is not a valid part of an identifier, and
// statement hints have a fixed set of possible values.
if p.sql[p.pos] == '}' {
p.pos++
return true, startPos
}
}
}
}
return false
return false, 0
}

// isMultibyte returns true if the character at the current position
Expand Down
71 changes: 69 additions & 2 deletions parser/statement_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,11 @@ func (p *StatementParser) ParseParameters(sql string) (string, []string, error)
// Skips all whitespaces from the given position and returns the
// position of the next non-whitespace character or len(sql) if
// the string does not contain any whitespaces after pos.
func (p *StatementParser) skipWhitespacesAndComments(sql []byte, pos int) int {
//
// PostgreSQL hints are encoded as comments in the following form:
// /*@ hint_key=hint_value[, hint_key2=hint_value2[,...]] */
// The skipPgHints argument indicates whether those comments should also be skipped or not.
func (p *StatementParser) skipWhitespacesAndComments(sql []byte, pos int, skipPgHints bool) int {
for pos < len(sql) {
c := sql[pos]
if isMultibyte(c) {
Expand All @@ -293,6 +297,10 @@ func (p *StatementParser) skipWhitespacesAndComments(sql []byte, pos int) int {
pos = p.skipSingleLineComment(sql, pos+1)
} else if c == '/' && len(sql) > pos+1 && sql[pos+1] == '*' {
// This is a multi line comment starting with '/*'.
if !skipPgHints && len(sql) > pos+2 && sql[pos+2] == '@' {
// This is a PostgreSQL hint, and we should not skip it.
break
}
pos = p.skipMultiLineComment(sql, pos)
} else if !isSpace(c) {
break
Expand Down Expand Up @@ -745,7 +753,7 @@ func (p *StatementParser) DetectStatementType(sql string) *StatementInfo {

func (p *StatementParser) calculateDetectStatementType(sql string) *StatementInfo {
parser := &simpleParser{sql: []byte(sql), statementParser: p}
_ = parser.skipStatementHint()
_, _ = parser.skipStatementHint()
keyword := strings.ToUpper(parser.readKeyword())
if isQueryKeyword(keyword) {
return &StatementInfo{StatementType: StatementTypeQuery}
Expand All @@ -770,3 +778,62 @@ func detectDmlKeyword(keyword string) DmlType {
}
return DmlTypeUnknown
}

func (p *StatementParser) extractSetStatementsFromHints(sql string) (*ParsedSetStatement, error) {
sp := &simpleParser{sql: []byte(sql), statementParser: p}
if ok, startPos := sp.skipStatementHint(); ok {
// Mark the start and end of the statement hint and extract the values in the hint.
endPos := sp.pos
sp.pos = startPos
if p.Dialect == databasepb.DatabaseDialect_POSTGRESQL {
// eatTokensOnly will only look for the following character sequence: '/*@'
// It will not interpret it as a comment.
sp.eatTokensOnly(postgreSqlStatementHintPrefix)
} else {
sp.eatTokens(googleSqlStatementHintPrefix)
}
// The default is that the hint ends with a single '}'.
endIndex := endPos - 1
if p.Dialect == databasepb.DatabaseDialect_POSTGRESQL {
// The hint ends with '*/'
endIndex = endPos - 2
}
if endIndex > sp.pos && endIndex < len(sql) {
return p.extractConnectionVariables(sql[sp.pos:endIndex])
}
}
return nil, nil
}

func (p *StatementParser) extractConnectionVariables(sql string) (*ParsedSetStatement, error) {
sp := &simpleParser{sql: []byte(sql), statementParser: p}
statement := &ParsedSetStatement{
Identifiers: make([]Identifier, 0, 2),
Literals: make([]Literal, 0, 2),
}
for {
if !sp.hasMoreTokens() {
break
}
identifier, err := sp.eatIdentifier()
if err != nil {
return nil, err
}
if !sp.eatToken('=') {
return nil, status.Errorf(codes.InvalidArgument, "missing '=' token after %s in hint", identifier)
}
literal, err := sp.eatLiteral()
if err != nil {
return nil, err
}
statement.Identifiers = append(statement.Identifiers, identifier)
statement.Literals = append(statement.Literals, literal)
if !sp.eatToken(',') {
break
}
}
if sp.hasMoreTokens() {
return nil, status.Errorf(codes.InvalidArgument, "unexpected tokens: %s", string(sp.sql[sp.pos:]))
}
return statement, nil
}
Loading
Loading