Skip to content

Commit f3e56f0

Browse files
authored
chore: move parser to separate package (#513)
* chore: cache parsed statements and remove regex loading Cache parsed statements and remove loading and checking for regex-based statements, as the latter is no an empty list. * chore: move parser to separate package Move all parsing to a separate package and separate parsed statements from executable statements. This makes the parser reusable for other purposes. Updates #461 * chore: cleanup and add comments
1 parent b035014 commit f3e56f0

16 files changed

+1323
-1170
lines changed

.github/workflows/unit-tests.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ jobs:
2222
- name: Run connection state unit tests
2323
run: go test -race -short
2424
working-directory: connectionstate
25+
- name: Run parser unit tests
26+
run: go test -race -short
27+
working-directory: parser
2528

2629
lint:
2730
runs-on: ubuntu-latest

client_side_statement_test.go

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,19 @@ import (
2929
"github.com/google/go-cmp/cmp"
3030
"github.com/google/go-cmp/cmp/cmpopts"
3131
"github.com/googleapis/go-sql-spanner/connectionstate"
32+
"github.com/googleapis/go-sql-spanner/parser"
3233
"google.golang.org/grpc/codes"
3334
"google.golang.org/protobuf/types/known/structpb"
3435
)
3536

3637
func TestStatementExecutor_StartBatchDdl(t *testing.T) {
3738
t.Parallel()
3839

39-
parser, _ := newStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
40+
p, _ := parser.NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
4041
c := &conn{
4142
logger: noopLogger,
4243
state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
43-
parser: parser,
44+
parser: p,
4445
}
4546
ctx := context.Background()
4647

@@ -73,11 +74,11 @@ func TestStatementExecutor_StartBatchDdl(t *testing.T) {
7374
func TestStatementExecutor_StartBatchDml(t *testing.T) {
7475
t.Parallel()
7576

76-
parser, _ := newStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
77+
p, _ := parser.NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
7778
c := &conn{
7879
logger: noopLogger,
7980
state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
80-
parser: parser,
81+
parser: p,
8182
}
8283
ctx := context.Background()
8384

@@ -116,11 +117,11 @@ func TestStatementExecutor_StartBatchDml(t *testing.T) {
116117
func TestStatementExecutor_RetryAbortsInternally(t *testing.T) {
117118
t.Parallel()
118119

119-
parser, _ := newStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
120+
p, _ := parser.NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
120121
c := &conn{
121122
logger: noopLogger,
122123
state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
123-
parser: parser,
124+
parser: p,
124125
}
125126
ctx := context.Background()
126127
for i, test := range []struct {
@@ -178,14 +179,14 @@ func TestStatementExecutor_RetryAbortsInternally(t *testing.T) {
178179
func TestStatementExecutor_AutocommitDmlMode(t *testing.T) {
179180
t.Parallel()
180181

181-
parser, _ := newStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
182+
p, _ := parser.NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
182183
c := &conn{
183184
logger: noopLogger,
184185
connector: &connector{},
185186
state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{
186187
propertyAutocommitDmlMode.Key(): propertyAutocommitDmlMode.CreateTypedInitialValue(Transactional),
187188
}),
188-
parser: parser,
189+
parser: p,
189190
}
190191
_ = c.ResetSession(context.Background())
191192
ctx := context.Background()
@@ -244,11 +245,11 @@ func TestStatementExecutor_AutocommitDmlMode(t *testing.T) {
244245
func TestStatementExecutor_ReadOnlyStaleness(t *testing.T) {
245246
t.Parallel()
246247

247-
parser, _ := newStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
248+
p, _ := parser.NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
248249
c := &conn{
249250
logger: noopLogger,
250251
state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
251-
parser: parser,
252+
parser: p,
252253
}
253254
ctx := context.Background()
254255
for i, test := range []struct {
@@ -319,10 +320,10 @@ func TestStatementExecutor_ReadOnlyStaleness(t *testing.T) {
319320
func TestShowCommitTimestamp(t *testing.T) {
320321
t.Parallel()
321322

322-
parser, _ := newStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
323+
p, _ := parser.NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
323324
c := &conn{
324325
logger: noopLogger,
325-
parser: parser,
326+
parser: p,
326327
state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
327328
}
328329
ctx := context.Background()
@@ -371,11 +372,11 @@ func TestShowCommitTimestamp(t *testing.T) {
371372
func TestStatementExecutor_ExcludeTxnFromChangeStreams(t *testing.T) {
372373
t.Parallel()
373374

374-
parser, _ := newStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
375+
p, _ := parser.NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
375376
c := &conn{
376377
logger: noopLogger,
377378
state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
378-
parser: parser,
379+
parser: p,
379380
}
380381
ctx := context.Background()
381382
for i, test := range []struct {
@@ -433,11 +434,11 @@ func TestStatementExecutor_ExcludeTxnFromChangeStreams(t *testing.T) {
433434
func TestStatementExecutor_MaxCommitDelay(t *testing.T) {
434435
t.Parallel()
435436

436-
parser, _ := newStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
437+
p, _ := parser.NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
437438
c := &conn{
438439
logger: noopLogger,
439440
state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
440-
parser: parser,
441+
parser: p,
441442
}
442443
ctx := context.Background()
443444
for i, test := range []struct {
@@ -505,7 +506,7 @@ func TestStatementExecutor_SetTransactionTag(t *testing.T) {
505506
t.Parallel()
506507

507508
ctx := context.Background()
508-
parser, _ := newStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
509+
p, _ := parser.NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
509510
for i, test := range []struct {
510511
wantValue string
511512
setValue string
@@ -521,7 +522,7 @@ func TestStatementExecutor_SetTransactionTag(t *testing.T) {
521522
c := &conn{
522523
logger: noopLogger,
523524
state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
524-
parser: parser,
525+
parser: p,
525526
}
526527

527528
it, err := c.QueryContext(ctx, "show variable transaction_tag", []driver.NamedValue{})
@@ -583,18 +584,22 @@ func TestStatementExecutor_UsesExecOptions(t *testing.T) {
583584
t.Parallel()
584585

585586
ctx := context.Background()
586-
parser, _ := newStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
587+
p, _ := parser.NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
587588
c := &conn{
588589
logger: noopLogger,
589590
state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
590-
parser: parser,
591+
parser: p,
591592
}
592593

593-
clientStmt, err := c.parser.parseClientSideStatement(c, "show variable read_only_staleness")
594+
clientStmt, err := c.parser.ParseClientSideStatement("show variable read_only_staleness")
594595
if err != nil {
595596
t.Fatal(err)
596597
}
597-
it, err := clientStmt.QueryContext(ctx, &ExecOptions{DecodeOption: DecodeOptionProto, ReturnResultSetMetadata: true, ReturnResultSetStats: true}, []driver.NamedValue{})
598+
execStmt, err := createExecutableStatement(clientStmt)
599+
if err != nil {
600+
t.Fatal(err)
601+
}
602+
it, err := execStmt.queryContext(ctx, c, &ExecOptions{DecodeOption: DecodeOptionProto, ReturnResultSetMetadata: true, ReturnResultSetStats: true})
598603
if err != nil {
599604
t.Fatalf("could not get current staleness value from connection: %v", err)
600605
}

conn.go

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
adminpb "cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
2929
"cloud.google.com/go/spanner/apiv1/spannerpb"
3030
"github.com/googleapis/go-sql-spanner/connectionstate"
31+
"github.com/googleapis/go-sql-spanner/parser"
3132
"google.golang.org/api/iterator"
3233
"google.golang.org/grpc/codes"
3334
"google.golang.org/grpc/status"
@@ -245,7 +246,7 @@ type SpannerConn interface {
245246
var _ SpannerConn = &conn{}
246247

247248
type conn struct {
248-
parser *statementParser
249+
parser *parser.StatementParser
249250
connector *connector
250251
closed bool
251252
client *spanner.Client
@@ -259,7 +260,7 @@ type conn struct {
259260

260261
execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, bound spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator
261262
execSingleQueryTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options *ExecOptions) (rowIterator, *spanner.CommitResponse, error)
262-
execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *statementInfo, options *ExecOptions) (*result, *spanner.CommitResponse, error)
263+
execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (*result, *spanner.CommitResponse, error)
263264
execSingleDMLPartitioned func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options *ExecOptions) (int64, error)
264265

265266
// state contains the current ConnectionState for this connection.
@@ -314,15 +315,15 @@ func (c *conn) setCommitResponse(commitResponse *spanner.CommitResponse) {
314315
_ = propertyCommitTimestamp.SetValue(c.state, &commitResponse.CommitTs, connectionstate.ContextUser)
315316
}
316317

317-
func (c *conn) showConnectionVariable(identifier identifier) (any, bool, error) {
318+
func (c *conn) showConnectionVariable(identifier parser.Identifier) (any, bool, error) {
318319
extension, name, err := toExtensionAndName(identifier)
319320
if err != nil {
320321
return nil, false, err
321322
}
322323
return c.state.GetValue(extension, name)
323324
}
324325

325-
func (c *conn) setConnectionVariable(identifier identifier, value string, local bool) error {
326+
func (c *conn) setConnectionVariable(identifier parser.Identifier, value string, local bool) error {
326327
extension, name, err := toExtensionAndName(identifier)
327328
if err != nil {
328329
return err
@@ -333,15 +334,15 @@ func (c *conn) setConnectionVariable(identifier identifier, value string, local
333334
return c.state.SetValue(extension, name, value, connectionstate.ContextUser)
334335
}
335336

336-
func toExtensionAndName(identifier identifier) (string, string, error) {
337+
func toExtensionAndName(identifier parser.Identifier) (string, string, error) {
337338
var extension string
338339
var name string
339-
if len(identifier.parts) == 1 {
340+
if len(identifier.Parts) == 1 {
340341
extension = ""
341-
name = identifier.parts[0]
342-
} else if len(identifier.parts) == 2 {
343-
extension = identifier.parts[0]
344-
name = identifier.parts[1]
342+
name = identifier.Parts[0]
343+
} else if len(identifier.Parts) == 2 {
344+
extension = identifier.Parts[0]
345+
name = identifier.Parts[1]
345346
} else {
346347
return "", "", status.Errorf(codes.InvalidArgument, "invalid variable name: %s", identifier)
347348
}
@@ -796,7 +797,7 @@ func (c *conn) Prepare(query string) (driver.Stmt, error) {
796797

797798
func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) {
798799
execOptions := c.options( /* reset = */ true)
799-
parsedSQL, args, err := c.parser.parseParameters(query)
800+
parsedSQL, args, err := c.parser.ParseParameters(query)
800801
if err != nil {
801802
return nil, err
802803
}
@@ -805,13 +806,17 @@ func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, err
805806

806807
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
807808
// Execute client side statement if it is one.
808-
clientStmt, err := c.parser.parseClientSideStatement(c, query)
809+
clientStmt, err := c.parser.ParseClientSideStatement(query)
809810
if err != nil {
810811
return nil, err
811812
}
812813
execOptions := c.options( /* reset = */ clientStmt == nil)
813814
if clientStmt != nil {
814-
return clientStmt.QueryContext(ctx, execOptions, args)
815+
execStmt, err := createExecutableStatement(clientStmt)
816+
if err != nil {
817+
return nil, err
818+
}
819+
return execStmt.queryContext(ctx, c, execOptions)
815820
}
816821

817822
return c.queryContext(ctx, query, execOptions, args)
@@ -830,14 +835,14 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
830835
if err != nil {
831836
return nil, err
832837
}
833-
statementType := c.parser.detectStatementType(query)
838+
statementType := c.parser.DetectStatementType(query)
834839
// DDL statements are not supported in QueryContext so fail early.
835-
if statementType.statementType == statementTypeDdl {
840+
if statementType.StatementType == parser.StatementTypeDdl {
836841
return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "QueryContext does not support DDL statements, use ExecContext instead"))
837842
}
838843
var iter rowIterator
839844
if c.tx == nil {
840-
if statementType.statementType == statementTypeDml {
845+
if statementType.StatementType == parser.StatementTypeDml {
841846
// Use a read/write transaction to execute the statement.
842847
var commitResponse *spanner.CommitResponse
843848
iter, commitResponse, err = c.execSingleQueryTransactional(ctx, c.client, stmt, execOptions)
@@ -878,13 +883,17 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
878883

879884
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
880885
// Execute client side statement if it is one.
881-
stmt, err := c.parser.parseClientSideStatement(c, query)
886+
stmt, err := c.parser.ParseClientSideStatement(query)
882887
if err != nil {
883888
return nil, err
884889
}
885890
execOptions := c.options( /*reset = */ stmt == nil)
886891
if stmt != nil {
887-
return stmt.ExecContext(ctx, execOptions, args)
892+
execStmt, err := createExecutableStatement(stmt)
893+
if err != nil {
894+
return nil, err
895+
}
896+
return execStmt.execContext(ctx, c, execOptions)
888897
}
889898
return c.execContext(ctx, query, execOptions, args)
890899
}
@@ -893,9 +902,9 @@ func (c *conn) execContext(ctx context.Context, query string, execOptions *ExecO
893902
// Clear the commit timestamp of this connection before we execute the statement.
894903
c.clearCommitResponse()
895904

896-
statementInfo := c.parser.detectStatementType(query)
905+
statementInfo := c.parser.DetectStatementType(query)
897906
// Use admin API if DDL statement is provided.
898-
if statementInfo.statementType == statementTypeDdl {
907+
if statementInfo.StatementType == parser.StatementTypeDdl {
899908
// Spanner does not support DDL in transactions, and although it is technically possible to execute DDL
900909
// statements while a transaction is active, we return an error to avoid any confusion whether the DDL
901910
// statement is executed as part of the active transaction or not.
@@ -1297,7 +1306,7 @@ func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement s
12971306

12981307
var errInvalidDmlForExecContext = spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "Exec and ExecContext can only be used with INSERT statements with a THEN RETURN clause that return exactly one row with one column of type INT64. Use Query or QueryContext for DML statements other than INSERT and/or with THEN RETURN clauses that return other/more data."))
12991308

1300-
func execInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *statementInfo, options *ExecOptions) (*result, *spanner.CommitResponse, error) {
1309+
func execInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (*result, *spanner.CommitResponse, error) {
13011310
var res *result
13021311
options.QueryOptions.LastStatement = true
13031312
fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error {
@@ -1315,7 +1324,7 @@ func execInNewRWTransaction(ctx context.Context, c *spanner.Client, statement sp
13151324
return res, &resp, nil
13161325
}
13171326

1318-
func execTransactionalDML(ctx context.Context, tx spannerTransaction, statement spanner.Statement, statementInfo *statementInfo, options spanner.QueryOptions) (*result, error) {
1327+
func execTransactionalDML(ctx context.Context, tx spannerTransaction, statement spanner.Statement, statementInfo *parser.StatementInfo, options spanner.QueryOptions) (*result, error) {
13191328
var rowsAffected int64
13201329
var lastInsertId int64
13211330
var hasLastInsertId bool
@@ -1327,7 +1336,7 @@ func execTransactionalDML(ctx context.Context, tx spannerTransaction, statement
13271336
}
13281337
if len(it.Metadata.RowType.Fields) != 0 && !(len(it.Metadata.RowType.Fields) == 1 &&
13291338
it.Metadata.RowType.Fields[0].Type.Code == spannerpb.TypeCode_INT64 &&
1330-
statementInfo.dmlType == dmlTypeInsert) {
1339+
statementInfo.DmlType == parser.DmlTypeInsert) {
13311340
return nil, errInvalidDmlForExecContext
13321341
}
13331342
if err != iterator.Done {

conn_with_mockserver_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1584,7 +1584,7 @@ func getDialect(c *sql.Conn) (dialect databasepb.DatabaseDialect) {
15841584
_ = c.Raw(func(driverConn any) error {
15851585
sc, _ := driverConn.(SpannerConn)
15861586
conn := sc.(*conn)
1587-
dialect = conn.parser.dialect
1587+
dialect = conn.parser.Dialect
15881588
return nil
15891589
})
15901590
return

driver.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ import (
4040
"github.com/google/uuid"
4141
"github.com/googleapis/gax-go/v2"
4242
"github.com/googleapis/go-sql-spanner/connectionstate"
43+
"github.com/googleapis/go-sql-spanner/parser"
4344
"google.golang.org/api/iterator"
4445
"google.golang.org/api/option"
4546
"google.golang.org/api/option/internaloption"
@@ -529,7 +530,7 @@ type connector struct {
529530
adminClient *adminapi.DatabaseAdminClient
530531
adminClientErr error
531532
connCount int32
532-
parser *statementParser
533+
parser *parser.StatementParser
533534
}
534535

535536
func newOrCachedConnector(d *Driver, dsn string) (*connector, error) {
@@ -732,7 +733,7 @@ func openDriverConn(ctx context.Context, c *connector) (driver.Conn, error) {
732733
logger := c.logger.With("connId", connId)
733734
connectionStateType := c.connectorConfig.ConnectionStateType
734735
if connectionStateType == connectionstate.TypeDefault {
735-
if c.parser.dialect == databasepb.DatabaseDialect_POSTGRESQL {
736+
if c.parser.Dialect == databasepb.DatabaseDialect_POSTGRESQL {
736737
connectionStateType = connectionstate.TypeTransactional
737738
} else {
738739
connectionStateType = connectionstate.TypeNonTransactional
@@ -806,7 +807,7 @@ func (c *connector) increaseConnCount(ctx context.Context, databaseName string,
806807
} else if c.connectorConfig.StatementCacheSize == 0 {
807808
cacheSize = defaultStatementCacheSize
808809
}
809-
c.parser, err = newStatementParser(dialect, cacheSize)
810+
c.parser, err = parser.NewStatementParser(dialect, cacheSize)
810811
if err != nil {
811812
closeClient()
812813
return err

0 commit comments

Comments
 (0)