Skip to content

Commit a9f7eae

Browse files
kyleconroy and claude authored
feat(mysql): improve AST formatting and add DELETE JOIN support (#4206)
This PR continues the MySQL AST formatting work with several improvements:

**New AST Nodes:**
- `VariableExpr` - MySQL user variables (@var), distinct from sqlc @param
- `IntervalExpr` - MySQL INTERVAL expressions
- `OnDuplicateKeyUpdate` - MySQL ON DUPLICATE KEY UPDATE clause
- `ParenExpr` - Explicit parentheses for expression grouping

**DELETE with JOIN Support:**
- Extended DeleteStmt with Targets and FromClause fields
- Multi-table DELETE now properly formats: `DELETE t1.*, t2.* FROM t1 JOIN t2...`
- Updated compiler/output_columns.go to handle new structure

**Bug Fixes:**
- MySQL @variable now preserved as-is (not treated as sqlc named parameter)
- Column type lengths only output for types where meaningful (varchar, char)
- Fixed sqlc.arg() handling in ON DUPLICATE KEY UPDATE clause

**Documentation:**
- Added CLAUDE.md files documenting AST, astutils, named, rewrite packages
- Added CLAUDE.md for dolphin engine with conversion patterns

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
1 parent 405a905 commit a9f7eae

30 files changed

+1304
-70
lines changed

internal/codegen/golang/mysql_type.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,11 @@ func mysqlType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.C
6464
}
6565
return "sql.NullInt32"
6666

67-
case "bigint":
67+
case "bigint", "bigint unsigned", "bigint signed":
68+
// "bigint unsigned" and "bigint signed" are MySQL CAST types
69+
// Note: We use int64 for CAST AS UNSIGNED to match original behavior,
70+
// even though uint64 would be more semantically correct.
71+
// The Unsigned flag on columns (from table schema) still uses uint64.
6872
if notNull {
6973
if unsigned {
7074
return "uint64"

internal/compiler/output_columns.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,14 @@ func (c *Compiler) sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, erro
482482
list := &ast.List{}
483483
switch n := node.(type) {
484484
case *ast.DeleteStmt:
485-
list = n.Relations
485+
if n.Relations != nil {
486+
list = n.Relations
487+
} else if n.FromClause != nil {
488+
// Multi-table DELETE: walk FromClause to find tables
489+
var tv tableVisitor
490+
astutils.Walk(&tv, n.FromClause)
491+
list = &tv.list
492+
}
486493
case *ast.InsertStmt:
487494
list = &ast.List{
488495
Items: []ast.Node{n.Relation},

internal/endtoend/fmt_test.go

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,30 @@ package main
33
import (
44
"bytes"
55
"fmt"
6+
"io"
67
"os"
78
"path/filepath"
89
"strings"
910
"testing"
1011

1112
"github.com/sqlc-dev/sqlc/internal/config"
1213
"github.com/sqlc-dev/sqlc/internal/debug"
14+
"github.com/sqlc-dev/sqlc/internal/engine/dolphin"
1315
"github.com/sqlc-dev/sqlc/internal/engine/postgresql"
1416
"github.com/sqlc-dev/sqlc/internal/sql/ast"
17+
"github.com/sqlc-dev/sqlc/internal/sql/format"
1518
)
1619

20+
// sqlParser is the engine-independent parsing interface used by TestFormat: it reads SQL from r and returns the parsed statements.
21+
type sqlParser interface {
22+
Parse(r io.Reader) ([]ast.Statement, error)
23+
}
24+
25+
// sqlFormatter is the engine-specific format.Formatter that TestFormat passes to ast.Format.
26+
type sqlFormatter interface {
27+
format.Formatter
28+
}
29+
1730
func TestFormat(t *testing.T) {
1831
t.Parallel()
1932
for _, tc := range FindTests(t, "testdata", "base") {
@@ -36,9 +49,38 @@ func TestFormat(t *testing.T) {
3649
return
3750
}
3851

39-
// For now, only test PostgreSQL since that's the only engine with Format support
4052
engine := conf.SQL[0].Engine
41-
if engine != config.EnginePostgreSQL {
53+
54+
// Select the appropriate parser, formatter, and fingerprint function based on engine
55+
var parse sqlParser
56+
var formatter sqlFormatter
57+
var fingerprint func(string) (string, error)
58+
59+
switch engine {
60+
case config.EnginePostgreSQL:
61+
pgParser := postgresql.NewParser()
62+
parse = pgParser
63+
formatter = pgParser
64+
fingerprint = postgresql.Fingerprint
65+
case config.EngineMySQL:
66+
mysqlParser := dolphin.NewParser()
67+
parse = mysqlParser
68+
formatter = mysqlParser
69+
// For MySQL, we use a "round-trip" fingerprint: parse the SQL, format it,
70+
// and return the formatted string. This tests that our formatting produces
71+
// valid SQL that parses to the same AST structure.
72+
fingerprint = func(sql string) (string, error) {
73+
stmts, err := mysqlParser.Parse(strings.NewReader(sql))
74+
if err != nil {
75+
return "", err
76+
}
77+
if len(stmts) == 0 {
78+
return "", nil
79+
}
80+
return ast.Format(stmts[0].Raw, mysqlParser), nil
81+
}
82+
default:
83+
// Skip unsupported engines
4284
return
4385
}
4486

@@ -68,8 +110,6 @@ func TestFormat(t *testing.T) {
68110
return
69111
}
70112

71-
parse := postgresql.NewParser()
72-
73113
for _, queryFile := range queryFiles {
74114
if _, err := os.Stat(queryFile); os.IsNotExist(err) {
75115
continue
@@ -99,7 +139,7 @@ func TestFormat(t *testing.T) {
99139
}
100140
query := strings.TrimSpace(string(contents[start : start+length]))
101141

102-
expected, err := postgresql.Fingerprint(query)
142+
expected, err := fingerprint(query)
103143
if err != nil {
104144
t.Fatal(err)
105145
}
@@ -109,8 +149,8 @@ func TestFormat(t *testing.T) {
109149
debug.Dump(r, err)
110150
}
111151

112-
out := ast.Format(stmt.Raw, parse)
113-
actual, err := postgresql.Fingerprint(out)
152+
out := ast.Format(stmt.Raw, formatter)
153+
actual, err := fingerprint(out)
114154
if err != nil {
115155
t.Error(err)
116156
}

internal/engine/dolphin/CLAUDE.md

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
# Dolphin Engine (MySQL) - Claude Code Guide
2+
3+
The dolphin engine handles MySQL parsing and AST conversion using the TiDB parser.
4+
5+
## Architecture
6+
7+
### Parser Flow
8+
```
9+
SQL String → TiDB Parser → TiDB AST → sqlc AST → Analysis/Codegen
10+
```
11+
12+
### Key Files
13+
- `convert.go` - Converts TiDB AST nodes to sqlc AST nodes
14+
- `format.go` - MySQL-specific formatting (identifiers, types, parameters)
15+
- `parse.go` - Entry point for parsing MySQL SQL
16+
17+
## TiDB Parser
18+
19+
The TiDB parser (`github.com/pingcap/tidb/pkg/parser`) is used for MySQL parsing:
20+
21+
```go
22+
import (
23+
pcast "github.com/pingcap/tidb/pkg/parser/ast"
24+
"github.com/pingcap/tidb/pkg/parser/mysql"
25+
"github.com/pingcap/tidb/pkg/parser/types"
26+
)
27+
```
28+
29+
### Common TiDB Types
30+
- `pcast.SelectStmt`, `pcast.InsertStmt`, etc. - Statement types
31+
- `pcast.ColumnNameExpr` - Column reference
32+
- `pcast.FuncCallExpr` - Function call
33+
- `pcast.BinaryOperationExpr` - Binary expression
34+
- `pcast.VariableExpr` - MySQL user variable (@var)
35+
- `pcast.Join` - JOIN clause with Left, Right, On, Using
36+
37+
## Conversion Pattern
38+
39+
Each TiDB node type has a corresponding converter method:
40+
41+
```go
42+
func (c *cc) convertSelectStmt(n *pcast.SelectStmt) *ast.SelectStmt {
43+
return &ast.SelectStmt{
44+
FromClause: c.convertTableRefsClause(n.From),
45+
WhereClause: c.convert(n.Where),
46+
// ...
47+
}
48+
}
49+
```
50+
51+
The main `convert()` method dispatches to specific converters:
52+
```go
53+
func (c *cc) convert(node pcast.Node) ast.Node {
54+
switch n := node.(type) {
55+
case *pcast.SelectStmt:
56+
return c.convertSelectStmt(n)
57+
case *pcast.InsertStmt:
58+
return c.convertInsertStmt(n)
59+
// ...
60+
}
61+
}
62+
```
63+
64+
## Key Conversions
65+
66+
### Column References
67+
```go
68+
func (c *cc) convertColumnNameExpr(n *pcast.ColumnNameExpr) *ast.ColumnRef {
69+
var items []ast.Node
70+
if schema := n.Name.Schema.String(); schema != "" {
71+
items = append(items, NewIdentifier(schema))
72+
}
73+
if table := n.Name.Table.String(); table != "" {
74+
items = append(items, NewIdentifier(table))
75+
}
76+
items = append(items, NewIdentifier(n.Name.Name.String()))
77+
return &ast.ColumnRef{Fields: &ast.List{Items: items}}
78+
}
79+
```
80+
81+
### JOINs
82+
```go
83+
func (c *cc) convertJoin(n *pcast.Join) *ast.List {
84+
if n.Right != nil && n.Left != nil {
85+
return &ast.List{
86+
Items: []ast.Node{&ast.JoinExpr{
87+
Jointype: ast.JoinType(n.Tp),
88+
Larg: c.convert(n.Left),
89+
Rarg: c.convert(n.Right),
90+
Quals: c.convert(n.On),
91+
UsingClause: convertUsing(n.Using),
92+
}},
93+
}
94+
}
95+
// No join - just return tables
96+
// ...
97+
}
98+
```
99+
100+
### MySQL User Variables
101+
MySQL user variables (`@var`) are different from sqlc's `@param` syntax:
102+
```go
103+
func (c *cc) convertVariableExpr(n *pcast.VariableExpr) ast.Node {
104+
// Use VariableExpr to preserve as-is (NOT A_Expr which would be treated as sqlc param)
105+
return &ast.VariableExpr{
106+
Name: n.Name,
107+
Location: n.OriginTextPosition(),
108+
}
109+
}
110+
```
111+
112+
### Type Casts (CAST AS)
113+
```go
114+
func (c *cc) convertFuncCastExpr(n *pcast.FuncCastExpr) ast.Node {
115+
typeName := types.TypeStr(n.Tp.GetType())
116+
// Handle UNSIGNED/SIGNED specially
117+
if typeName == "bigint" {
118+
if mysql.HasUnsignedFlag(n.Tp.GetFlag()) {
119+
typeName = "bigint unsigned"
120+
} else {
121+
typeName = "bigint signed"
122+
}
123+
}
124+
return &ast.TypeCast{
125+
Arg: c.convert(n.Expr),
126+
TypeName: &ast.TypeName{Name: typeName},
127+
}
128+
}
129+
```
130+
131+
### Column Definitions
132+
```go
133+
func convertColumnDef(def *pcast.ColumnDef) *ast.ColumnDef {
134+
typeName := &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())}
135+
136+
// Only add Typmods for types where length is meaningful
137+
tp := def.Tp.GetType()
138+
flen := def.Tp.GetFlen()
139+
switch tp {
140+
case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString:
141+
if flen >= 0 {
142+
typeName.Typmods = &ast.List{
143+
Items: []ast.Node{&ast.Integer{Ival: int64(flen)}},
144+
}
145+
}
146+
// Don't add for DATETIME, TIMESTAMP - internal flen is not user-specified
147+
}
148+
// ...
149+
}
150+
```
151+
152+
### Multi-Table DELETE
153+
MySQL supports `DELETE t1, t2 FROM t1 JOIN t2 ...`:
154+
```go
155+
func (c *cc) convertDeleteStmt(n *pcast.DeleteStmt) *ast.DeleteStmt {
156+
if n.IsMultiTable && n.Tables != nil {
157+
// Convert targets (t1.*, t2.*)
158+
targets := &ast.List{}
159+
for _, table := range n.Tables.Tables {
160+
// Build ColumnRef for each target
161+
}
162+
stmt.Targets = targets
163+
164+
// Preserve JOINs in FromClause
165+
stmt.FromClause = c.convertTableRefsClause(n.TableRefs).Items[0]
166+
} else {
167+
// Single-table DELETE
168+
stmt.Relations = c.convertTableRefsClause(n.TableRefs)
169+
}
170+
}
171+
```
172+
173+
## MySQL-Specific Formatting
174+
175+
### format.go
176+
```go
177+
func (p *Parser) TypeName(ns, name string) string {
178+
switch name {
179+
case "bigint unsigned":
180+
return "UNSIGNED"
181+
case "bigint signed":
182+
return "SIGNED"
183+
}
184+
return name
185+
}
186+
187+
func (p *Parser) Param(n int) string {
188+
return "?" // MySQL uses ? for all parameters
189+
}
190+
```
191+
192+
## Common Issues and Solutions
193+
194+
### Issue: Panic in Walk/Apply
195+
**Cause**: New AST node type not handled in `astutils/walk.go` or `astutils/rewrite.go`
196+
**Solution**: Add case for the node type in both files
197+
198+
### Issue: sqlc.arg() not converted in ON DUPLICATE KEY UPDATE
199+
**Cause**: `InsertStmt` case in `rewrite.go` didn't traverse `OnDuplicateKeyUpdate`
200+
**Solution**: Add `a.apply(n, "OnDuplicateKeyUpdate", nil, n.OnDuplicateKeyUpdate)`
201+
202+
### Issue: MySQL @variable being treated as parameter
203+
**Cause**: Converting `VariableExpr` to `A_Expr` with `@` operator
204+
**Solution**: Use `ast.VariableExpr` instead, which is not detected by `named.IsParamSign()`
205+
206+
### Issue: Type length appearing incorrectly (e.g., datetime(39))
207+
**Cause**: Using internal `flen` for all types
208+
**Solution**: Only populate `Typmods` for types where length is user-specified (varchar, char, etc.)
209+
210+
## Testing
211+
212+
### TestFormat
213+
Tests that SQL can be:
214+
1. Parsed
215+
2. Formatted back to SQL
216+
3. Re-parsed
217+
4. Re-formatted to match
218+
219+
### TestReplay
220+
Tests the full sqlc pipeline:
221+
1. Parse schema and queries
222+
2. Analyze
223+
3. Generate code
224+
4. Compare with expected output

0 commit comments

Comments
 (0)