Skip to content

Commit 166398d

Browse files
kyleconroyclaude
andauthored
feat(sqlite): add SQLite support to format tests (#4207)
Implements comprehensive formatting support for SQLite SQL statements, enabling round-trip format testing similar to MySQL. Key changes: **New SQLite Formatter:** - Add format.go implementing Formatter interface for SQLite dialect - Support SQLite-specific named parameter syntax (`:name` instead of `@name`) **AST Formatting Fixes:** - Add `NamedParam` method to Formatter interface for dialect-specific named params - Fix `CollateExpr` missing Format method for COLLATE clause support - Add `DefaultValues` field to InsertStmt for INSERT DEFAULT VALUES syntax - Fix table function argument conversion (e.g., json_each()) **SQLite Parser Improvements:** - Fix BoolExpr Boolop field for AND/OR expressions - Add EXISTS/NOT EXISTS subquery handling via SubLink nodes - Add unary expression support (NOT operator) - Fix NULL literal conversion in COALESCE expressions **Test Updates:** - Add SQLite case to fmt_test.go with case-insensitive fingerprinting - Regenerate expected output for select_exists/select_not_exists tests (EXISTS now correctly returns bool instead of int64) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Claude <noreply@anthropic.com>
1 parent a9f7eae commit 166398d

File tree

13 files changed

+251
-48
lines changed

13 files changed

+251
-48
lines changed

internal/endtoend/fmt_test.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/sqlc-dev/sqlc/internal/debug"
1414
"github.com/sqlc-dev/sqlc/internal/engine/dolphin"
1515
"github.com/sqlc-dev/sqlc/internal/engine/postgresql"
16+
"github.com/sqlc-dev/sqlc/internal/engine/sqlite"
1617
"github.com/sqlc-dev/sqlc/internal/sql/ast"
1718
"github.com/sqlc-dev/sqlc/internal/sql/format"
1819
)
@@ -79,6 +80,22 @@ func TestFormat(t *testing.T) {
7980
}
8081
return ast.Format(stmts[0].Raw, mysqlParser), nil
8182
}
83+
case config.EngineSQLite:
84+
sqliteParser := sqlite.NewParser()
85+
parse = sqliteParser
86+
formatter = sqliteParser
87+
// For SQLite, we use the same "round-trip" fingerprint strategy as MySQL:
88+
// parse the SQL, format it, and return the formatted string.
89+
fingerprint = func(sql string) (string, error) {
90+
stmts, err := sqliteParser.Parse(strings.NewReader(sql))
91+
if err != nil {
92+
return "", err
93+
}
94+
if len(stmts) == 0 {
95+
return "", nil
96+
}
97+
return strings.ToLower(ast.Format(stmts[0].Raw, sqliteParser)), nil
98+
}
8299
default:
83100
// Skip unsupported engines
84101
return

internal/endtoend/testdata/select_exists/sqlite/go/query.sql.go

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/select_not_exists/sqlite/go/query.sql.go

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/engine/dolphin/format.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ func (p *Parser) Param(n int) string {
2929
return "?"
3030
}
3131

32+
// NamedParam returns the named parameter placeholder for the given name.
33+
// MySQL doesn't have native named parameters, so we use ? (positional).
34+
// The actual parameter names are handled by sqlc's rewrite phase.
35+
func (p *Parser) NamedParam(name string) string {
36+
return "?"
37+
}
38+
3239
// Cast returns a type cast expression.
3340
// MySQL uses CAST(expr AS type) syntax.
3441
func (p *Parser) Cast(arg, typeName string) string {

internal/engine/postgresql/reserved.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ func (p *Parser) Param(n int) string {
6464
return fmt.Sprintf("$%d", n)
6565
}
6666

67+
// NamedParam returns the named parameter placeholder for the given name.
68+
// PostgreSQL/sqlc uses @name syntax.
69+
func (p *Parser) NamedParam(name string) string {
70+
return "@" + name
71+
}
72+
6773
// Cast returns a type cast expression.
6874
// PostgreSQL uses expr::type syntax.
6975
func (p *Parser) Cast(arg, typeName string) string {

internal/engine/sqlite/convert.go

Lines changed: 118 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,10 @@ func (c *cc) convertMultiSelect_stmtContext(n *parser.Select_stmtContext) ast.No
514514
limitCount, limitOffset := c.convertLimit_stmtContext(n.Limit_stmt())
515515
selectStmt.LimitCount = limitCount
516516
selectStmt.LimitOffset = limitOffset
517-
selectStmt.WithClause = &ast.WithClause{Ctes: &ctes}
517+
// Only set WithClause if there are CTEs
518+
if len(ctes.Items) > 0 {
519+
selectStmt.WithClause = &ast.WithClause{Ctes: &ctes}
520+
}
518521
return selectStmt
519522
}
520523

@@ -759,6 +762,13 @@ func (c *cc) convertLiteral(n *parser.Expr_literalContext) ast.Node {
759762
Location: n.GetStart().GetStart(),
760763
}
761764
}
765+
766+
if literal.NULL_() != nil {
767+
return &ast.A_Const{
768+
Val: &ast.Null{},
769+
Location: n.GetStart().GetStart(),
770+
}
771+
}
762772
}
763773
return todo("convertLiteral", n)
764774
}
@@ -776,8 +786,14 @@ func (c *cc) convertBinaryNode(n *parser.Expr_binaryContext) ast.Node {
776786
}
777787

778788
func (c *cc) convertBoolNode(n *parser.Expr_boolContext) ast.Node {
789+
var op ast.BoolExprType
790+
if n.AND_() != nil {
791+
op = ast.BoolExprTypeAnd
792+
} else if n.OR_() != nil {
793+
op = ast.BoolExprTypeOr
794+
}
779795
return &ast.BoolExpr{
780-
// TODO: Set op
796+
Boolop: op,
781797
Args: &ast.List{
782798
Items: []ast.Node{
783799
c.convert(n.Expr(0)),
@@ -787,6 +803,49 @@ func (c *cc) convertBoolNode(n *parser.Expr_boolContext) ast.Node {
787803
}
788804
}
789805

806+
func (c *cc) convertUnaryExpr(n *parser.Expr_unaryContext) ast.Node {
807+
op := n.Unary_operator()
808+
if op == nil {
809+
return c.convert(n.Expr())
810+
}
811+
812+
// Get the inner expression
813+
expr := c.convert(n.Expr())
814+
815+
// Check the operator type
816+
if opCtx, ok := op.(*parser.Unary_operatorContext); ok {
817+
if opCtx.NOT_() != nil {
818+
// NOT expression
819+
return &ast.BoolExpr{
820+
Boolop: ast.BoolExprTypeNot,
821+
Args: &ast.List{
822+
Items: []ast.Node{expr},
823+
},
824+
}
825+
}
826+
if opCtx.MINUS() != nil {
827+
// Negative number: -expr
828+
return &ast.A_Expr{
829+
Name: &ast.List{Items: []ast.Node{&ast.String{Str: "-"}}},
830+
Rexpr: expr,
831+
}
832+
}
833+
if opCtx.PLUS() != nil {
834+
// Positive number: +expr (just return expr)
835+
return expr
836+
}
837+
if opCtx.TILDE() != nil {
838+
// Bitwise NOT: ~expr
839+
return &ast.A_Expr{
840+
Name: &ast.List{Items: []ast.Node{&ast.String{Str: "~"}}},
841+
Rexpr: expr,
842+
}
843+
}
844+
}
845+
846+
return expr
847+
}
848+
790849
func (c *cc) convertParam(n *parser.Expr_bindContext) ast.Node {
791850
if n.NUMBERED_BIND_PARAMETER() != nil {
792851
// Parameter numbers start at one
@@ -816,7 +875,52 @@ func (c *cc) convertParam(n *parser.Expr_bindContext) ast.Node {
816875
}
817876

818877
func (c *cc) convertInSelectNode(n *parser.Expr_in_selectContext) ast.Node {
819-
return c.convert(n.Select_stmt())
878+
// Check if this is EXISTS or NOT EXISTS
879+
if n.EXISTS_() != nil {
880+
linkType := ast.EXISTS_SUBLINK
881+
sublink := &ast.SubLink{
882+
SubLinkType: linkType,
883+
Subselect: c.convert(n.Select_stmt()),
884+
}
885+
if n.NOT_() != nil {
886+
// NOT EXISTS is represented as a BoolExpr NOT wrapping the EXISTS
887+
return &ast.BoolExpr{
888+
Boolop: ast.BoolExprTypeNot,
889+
Args: &ast.List{
890+
Items: []ast.Node{sublink},
891+
},
892+
}
893+
}
894+
return sublink
895+
}
896+
897+
// Check if this is an IN/NOT IN expression: expr IN (SELECT ...)
898+
if n.IN_() != nil && len(n.AllExpr()) > 0 {
899+
linkType := ast.ANY_SUBLINK
900+
sublink := &ast.SubLink{
901+
SubLinkType: linkType,
902+
Testexpr: c.convert(n.Expr(0)),
903+
Subselect: c.convert(n.Select_stmt()),
904+
}
905+
if n.NOT_() != nil {
906+
return &ast.A_Expr{
907+
Kind: ast.A_Expr_Kind_OP,
908+
Name: &ast.List{Items: []ast.Node{&ast.String{Str: "NOT IN"}}},
909+
Lexpr: c.convert(n.Expr(0)),
910+
Rexpr: &ast.SubLink{
911+
SubLinkType: ast.EXPR_SUBLINK,
912+
Subselect: c.convert(n.Select_stmt()),
913+
},
914+
}
915+
}
916+
return sublink
917+
}
918+
919+
// Plain subquery in parentheses (SELECT ...)
920+
return &ast.SubLink{
921+
SubLinkType: ast.EXPR_SUBLINK,
922+
Subselect: c.convert(n.Select_stmt()),
923+
}
820924
}
821925

822926
func (c *cc) convertReturning_caluseContext(n parser.IReturning_clauseContext) *ast.List {
@@ -887,12 +991,8 @@ func (c *cc) convertInsert_stmtContext(n *parser.Insert_stmtContext) ast.Node {
887991
}
888992

889993
if hasDefaultValues {
890-
// For DEFAULT VALUES, create an empty select statement
891-
insert.SelectStmt = &ast.SelectStmt{
892-
FromClause: &ast.List{},
893-
TargetList: &ast.List{},
894-
ValuesLists: &ast.List{Items: []ast.Node{&ast.List{}}}, // Single empty values list
895-
}
994+
// For DEFAULT VALUES, set the flag instead of creating an empty values list
995+
insert.DefaultValues = true
896996
} else if n.Select_stmt() != nil {
897997
if ss, ok := c.convert(n.Select_stmt()).(*ast.SelectStmt); ok {
898998
ss.ValuesLists = &ast.List{}
@@ -976,6 +1076,11 @@ func (c *cc) convertTablesOrSubquery(n []parser.ITable_or_subqueryContext) []ast
9761076
tables = append(tables, rv)
9771077
} else if from.Table_function_name() != nil {
9781078
rel := from.Table_function_name().GetText()
1079+
// Convert function arguments
1080+
var args []ast.Node
1081+
for _, expr := range from.AllExpr() {
1082+
args = append(args, c.convert(expr))
1083+
}
9791084
rf := &ast.RangeFunction{
9801085
Functions: &ast.List{
9811086
Items: []ast.Node{
@@ -989,7 +1094,7 @@ func (c *cc) convertTablesOrSubquery(n []parser.ITable_or_subqueryContext) []ast
9891094
},
9901095
},
9911096
Args: &ast.List{
992-
Items: []ast.Node{&ast.TODO{}},
1097+
Items: args,
9931098
},
9941099
Location: from.GetStart().GetStart(),
9951100
},
@@ -1189,6 +1294,9 @@ func (c *cc) convert(node node) ast.Node {
11891294
case *parser.Expr_binaryContext:
11901295
return c.convertBinaryNode(n)
11911296

1297+
case *parser.Expr_unaryContext:
1298+
return c.convertUnaryExpr(n)
1299+
11921300
case *parser.Expr_in_selectContext:
11931301
return c.convertInSelectNode(n)
11941302

internal/engine/sqlite/format.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package sqlite
2+
3+
// QuoteIdent returns a quoted identifier if it needs quoting.
4+
// SQLite uses double quotes for quoting identifiers (SQL standard),
5+
// though backticks are also supported for MySQL compatibility.
6+
func (p *Parser) QuoteIdent(s string) string {
7+
// For now, don't quote - return as-is
8+
return s
9+
}
10+
11+
// TypeName returns the SQL type name for the given namespace and name.
12+
func (p *Parser) TypeName(ns, name string) string {
13+
if ns != "" {
14+
return ns + "." + name
15+
}
16+
return name
17+
}
18+
19+
// Param returns the parameter placeholder for the given number.
20+
// SQLite uses ? for positional parameters.
21+
func (p *Parser) Param(n int) string {
22+
return "?"
23+
}
24+
25+
// NamedParam returns the named parameter placeholder for the given name.
26+
// SQLite uses :name syntax for named parameters.
27+
func (p *Parser) NamedParam(name string) string {
28+
return ":" + name
29+
}
30+
31+
// Cast returns a type cast expression.
32+
// SQLite uses CAST(expr AS type) syntax.
33+
func (p *Parser) Cast(arg, typeName string) string {
34+
return "CAST(" + arg + " AS " + typeName + ")"
35+
}

internal/sql/ast/a_expr.go

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,36 @@ func (n *A_Expr) Pos() int {
1212
return n.Location
1313
}
1414

15+
// isNamedParam returns true if this A_Expr represents a named parameter (@name)
16+
// and extracts the parameter name if so.
17+
func (n *A_Expr) isNamedParam() (string, bool) {
18+
if n.Name == nil || len(n.Name.Items) != 1 {
19+
return "", false
20+
}
21+
s, ok := n.Name.Items[0].(*String)
22+
if !ok || s.Str != "@" {
23+
return "", false
24+
}
25+
if set(n.Lexpr) || !set(n.Rexpr) {
26+
return "", false
27+
}
28+
if nameStr, ok := n.Rexpr.(*String); ok {
29+
return nameStr.Str, true
30+
}
31+
return "", false
32+
}
33+
1534
func (n *A_Expr) Format(buf *TrackedBuffer) {
1635
if n == nil {
1736
return
1837
}
38+
39+
// Check for named parameter first (works regardless of Kind)
40+
if name, ok := n.isNamedParam(); ok {
41+
buf.WriteString(buf.NamedParam(name))
42+
return
43+
}
44+
1945
switch n.Kind {
2046
case A_Expr_Kind_IN:
2147
buf.astFormat(n.Lexpr)
@@ -64,32 +90,8 @@ func (n *A_Expr) Format(buf *TrackedBuffer) {
6490
buf.WriteString(", ")
6591
buf.astFormat(n.Rexpr)
6692
buf.WriteString(")")
67-
case A_Expr_Kind_OP:
68-
// Check if this is a named parameter (@name)
69-
opName := ""
70-
if n.Name != nil && len(n.Name.Items) == 1 {
71-
if s, ok := n.Name.Items[0].(*String); ok {
72-
opName = s.Str
73-
}
74-
}
75-
if opName == "@" && !set(n.Lexpr) && set(n.Rexpr) {
76-
// Named parameter: @name (no space after @)
77-
buf.WriteString("@")
78-
buf.astFormat(n.Rexpr)
79-
} else {
80-
// Standard binary operator
81-
if set(n.Lexpr) {
82-
buf.astFormat(n.Lexpr)
83-
buf.WriteString(" ")
84-
}
85-
buf.astFormat(n.Name)
86-
if set(n.Rexpr) {
87-
buf.WriteString(" ")
88-
buf.astFormat(n.Rexpr)
89-
}
90-
}
9193
default:
92-
// Fallback for other cases
94+
// Standard operator (including A_Expr_Kind_OP)
9395
if set(n.Lexpr) {
9496
buf.astFormat(n.Lexpr)
9597
buf.WriteString(" ")

0 commit comments

Comments
 (0)