diff --git a/examples/type-inference-test/mysql/db.go b/examples/type-inference-test/mysql/db.go new file mode 100644 index 0000000000..2840133cd4 --- /dev/null +++ b/examples/type-inference-test/mysql/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package mysql + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/examples/type-inference-test/mysql/models.go b/examples/type-inference-test/mysql/models.go new file mode 100644 index 0000000000..4fd3a20ffe --- /dev/null +++ b/examples/type-inference-test/mysql/models.go @@ -0,0 +1,16 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package mysql + +import ( + "database/sql" +) + +type Test struct { + ID int32 `json:"id"` + A1 int32 `json:"a1"` + A2 sql.NullInt32 `json:"a2"` + A3 float64 `json:"a3"` +} diff --git a/examples/type-inference-test/mysql/query.sql b/examples/type-inference-test/mysql/query.sql new file mode 100644 index 0000000000..1f1adb24ec --- /dev/null +++ b/examples/type-inference-test/mysql/query.sql @@ -0,0 +1,28 @@ +-- name: ListTest :many +SELECT + (a1 / 1024) a1_float, (a2 / 1024) a2_float, a3 +FROM test; + +-- name: ListTest2 :many +SELECT + COALESCE(CAST(a1 / 1024 AS FLOAT), 0) a1_float, COALESCE(CAST(a2 / 1024 AS FLOAT), 0) a2_float, a3 +FROM test; + +-- name: ListTest3 :many +SELECT + CAST(a1 / 1024 AS FLOAT) a1_float, CAST(a2 / 1024 AS FLOAT) a2_float, a3 +FROM test; + +-- name: ListTest4 :many +SELECT + (a1 + a2) as sum_result, + (a1 * a2) as mult_result, + (a1 - a2) as sub_result, + (a1 % 10) as mod_result +FROM test; + +-- name: ListTest5 :many +SELECT + COALESCE(a1 / 1024, 0) as with_inference, + COALESCE(a2 / 1024, 0) as nullable_inference +FROM test; diff --git a/examples/type-inference-test/mysql/query.sql.go b/examples/type-inference-test/mysql/query.sql.go new file mode 100644 index 0000000000..0deeb605ed --- /dev/null +++ b/examples/type-inference-test/mysql/query.sql.go @@ -0,0 +1,195 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: query.sql + +package mysql + +import ( + "context" + "database/sql" +) + +const listTest = `-- name: ListTest :many +SELECT + (a1 / 1024) a1_float, (a2 / 1024) a2_float, a3 +FROM test +` + +type ListTestRow struct { + A1Float string `json:"a1_float"` + A2Float sql.NullString `json:"a2_float"` + A3 float64 `json:"a3"` +} + +func (q *Queries) ListTest(ctx context.Context) ([]ListTestRow, error) { + rows, err := q.db.QueryContext(ctx, listTest) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListTestRow + for rows.Next() { + var i ListTestRow + if err := rows.Scan(&i.A1Float, &i.A2Float, &i.A3); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listTest2 = `-- name: ListTest2 :many +SELECT + COALESCE(CAST(a1 / 1024 AS FLOAT), 0) a1_float, COALESCE(CAST(a2 / 1024 AS FLOAT), 0) a2_float, a3 +FROM test +` + +type ListTest2Row struct { + A1Float float64 `json:"a1_float"` + A2Float float64 `json:"a2_float"` + A3 float64 `json:"a3"` +} + +func (q *Queries) ListTest2(ctx context.Context) ([]ListTest2Row, error) { + rows, err := q.db.QueryContext(ctx, listTest2) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListTest2Row + for rows.Next() { + var i ListTest2Row + if err := rows.Scan(&i.A1Float, &i.A2Float, &i.A3); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listTest3 = `-- name: ListTest3 :many +SELECT + CAST(a1 / 1024 AS FLOAT) a1_float, CAST(a2 / 1024 AS FLOAT) a2_float, a3 +FROM test +` + +type ListTest3Row struct { + A1Float float64 `json:"a1_float"` + A2Float float64 `json:"a2_float"` + A3 float64 `json:"a3"` +} + +func (q *Queries) ListTest3(ctx context.Context) ([]ListTest3Row, error) { + rows, err := q.db.QueryContext(ctx, listTest3) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListTest3Row + for rows.Next() { + var i ListTest3Row + if err := rows.Scan(&i.A1Float, &i.A2Float, &i.A3); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listTest4 = `-- name: ListTest4 :many +SELECT + (a1 + a2) as sum_result, + (a1 * a2) as mult_result, + (a1 - a2) as sub_result, + (a1 % 10) as mod_result +FROM test +` + +type ListTest4Row struct { + SumResult sql.NullInt32 `json:"sum_result"` + MultResult sql.NullInt32 `json:"mult_result"` + SubResult sql.NullInt32 `json:"sub_result"` + ModResult int32 `json:"mod_result"` +} + +func (q *Queries) ListTest4(ctx context.Context) ([]ListTest4Row, error) { + rows, err := q.db.QueryContext(ctx, listTest4) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListTest4Row + for rows.Next() { + var i ListTest4Row + if err := rows.Scan( + &i.SumResult, + &i.MultResult, + &i.SubResult, + &i.ModResult, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listTest5 = `-- name: ListTest5 :many +SELECT + COALESCE(a1 / 1024, 0) as with_inference, + COALESCE(a2 / 1024, 0) as nullable_inference +FROM test +` + +type ListTest5Row struct { + WithInference string `json:"with_inference"` + NullableInference string `json:"nullable_inference"` +} + +func (q *Queries) ListTest5(ctx context.Context) ([]ListTest5Row, error) { + rows, err := q.db.QueryContext(ctx, listTest5) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListTest5Row + for rows.Next() { + var i ListTest5Row + if err := rows.Scan(&i.WithInference, &i.NullableInference); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/examples/type-inference-test/mysql/schema.sql b/examples/type-inference-test/mysql/schema.sql new file mode 100644 index 0000000000..f81cded734 --- /dev/null +++ b/examples/type-inference-test/mysql/schema.sql @@ -0,0 +1,6 @@ +CREATE TABLE test ( + id INT PRIMARY KEY AUTO_INCREMENT, + a1 INT NOT NULL, + a2 INT NULL, + a3 FLOAT NOT NULL +); diff --git a/examples/type-inference-test/sqlc.yaml b/examples/type-inference-test/sqlc.yaml new file mode 100644 index 0000000000..1d73642807 --- /dev/null +++ b/examples/type-inference-test/sqlc.yaml @@ -0,0 +1,13 @@ +version: "2" +sql: + - engine: "mysql" + queries: "mysql/query.sql" + schema: "mysql/schema.sql" + gen: + go: + package: "mysql" + out: "mysql" + emit_json_tags: true + emit_prepared_queries: false + emit_interface: false + emit_exact_table_names: false diff --git a/internal/compiler/infer_type.go b/internal/compiler/infer_type.go new file mode 100644 index 0000000000..c9a5c4bf70 --- /dev/null +++ b/internal/compiler/infer_type.go @@ -0,0 +1,234 @@ +package compiler + +import ( + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/sql/ast" +) + +// inferExprType recursively analyzes SQL expressions to determine their types. +// +// It handles: +// - Column references (resolved from table schema) +// - Literal constants (intrinsic types) +// - Binary operations (applies type promotion and operator-specific rules) +// - Type casts (respects explicit type annotations) +// +// Examples: +// +// SELECT a1 / 1024 -- infers decimal or float based on operand types +// SELECT COALESCE(a1 / 1024, 0) -- handles nested expressions recursively +// SELECT CAST(a1 AS INT) -- respects explicit casts +// +// Returns nil if type cannot be inferred, allowing fallback to default behavior. +func (c *Compiler) inferExprType(node ast.Node, tables []*Table) *Column { + if node == nil { + return nil + } + + // Only MySQL is supported for now - return nil for other engines + // to maintain existing behavior + if c.conf.Engine != config.EngineMySQL { + return nil + } + + switch n := node.(type) { + case *ast.ColumnRef: + // Try to resolve the column reference + // Create a minimal ResTarget for outputColumnRefs + emptyRes := &ast.ResTarget{} + cols, err := outputColumnRefs(emptyRes, tables, n) + if err != nil || len(cols) == 0 { + return nil + } + return cols[0] + + case *ast.A_Const: + // Infer type from constant value + switch n.Val.(type) { + case *ast.String: + return &Column{DataType: "text", NotNull: true} + case *ast.Integer: + return &Column{DataType: "int", NotNull: true} + case *ast.Float: + return &Column{DataType: "float", NotNull: true} + case *ast.Boolean: + return &Column{DataType: "bool", NotNull: true} + default: + return nil + } + + case *ast.A_Expr: + // Recursively infer types of left and right operands + leftCol := c.inferExprType(n.Lexpr, tables) + rightCol := c.inferExprType(n.Rexpr, tables) + + if leftCol == nil && rightCol == nil { + return nil + } + + // Extract operator name + op := "" + if n.Name != nil && len(n.Name.Items) > 0 { + if str, ok := n.Name.Items[0].(*ast.String); ok { + op = str.Str + } + } + + // Apply database-specific type rules + return c.combineTypes(leftCol, rightCol, op) + + case *ast.TypeCast: + // If there's an explicit cast, use that type + if n.TypeName != nil { + col := toColumn(n.TypeName) + // Check if the casted value is nullable + if constant, ok := n.Arg.(*ast.A_Const); ok { + if _, isNull := constant.Val.(*ast.Null); isNull { + col.NotNull = false + } + } + return col + } + } + + return nil +} + +// combineTypes determines the result type of a binary operation. +// The logic is database-specific and handles operator semantics for each engine. +func (c *Compiler) combineTypes(left, right *Column, operator string) *Column { + // If either operand is unknown, we can't infer the type + if left == nil && right == nil { + return nil + } + + // If one operand is known, use it as a hint + if left == nil { + return right + } + if right == nil { + return left + } + + // Result is nullable if any operand is nullable (SQL NULL propagation rule) + notNull := left.NotNull && right.NotNull + + // Apply database-specific type rules + switch c.conf.Engine { + case config.EngineMySQL: + return combineMySQLTypes(left, right, operator, notNull) + default: + // TODO: Implement type inference for PostgreSQL and SQLite + // For now, use conservative fallback rules + return combineGenericTypes(left, right, operator, notNull) + } +} + +// combineMySQLTypes implements MySQL-specific type inference rules. +// +// Division always returns decimal or float: +// +// SELECT int_col / 1024 -- returns decimal +// SELECT float_col / 1024 -- returns float +// SELECT int_col DIV 1024 -- returns decimal (DIV is MySQL-specific) +// +// Nullability propagates (NULL op anything = NULL): +// +// nullable_col / 1024 -- returns nullable result +// NOT NULL col / 1024 -- returns NOT NULL result +func combineMySQLTypes(left, right *Column, operator string, notNull bool) *Column { + // Handle nil operands + if left == nil && right == nil { + return nil + } + if left == nil { + return right + } + if right == nil { + return left + } + + switch operator { + case "/", "div": + // Division: int/int = decimal, float/anything = float + // Note: "div" is MySQL-specific operator recognized by IsMathematicalOperator() + if isFloatType(left.DataType) || isFloatType(right.DataType) { + return &Column{DataType: "float", NotNull: notNull} + } + return &Column{DataType: "decimal", NotNull: notNull} + + case "%", "mod": + // Modulo: returns integer if both are integers, otherwise decimal + // Note: "mod" is MySQL-specific operator recognized by IsMathematicalOperator() + if isIntegerType(left.DataType) && isIntegerType(right.DataType) { + return &Column{DataType: "int", NotNull: notNull} + } + return &Column{DataType: "decimal", NotNull: notNull} + + case "&", "|", "<<", ">>", "~", "#", "^": + // Bitwise operators: always integer + return &Column{DataType: "int", NotNull: notNull} + + default: + // Arithmetic: standard type promotion + return promoteArithmeticTypes(left, right, notNull) + } +} + +// combineGenericTypes provides conservative fallback for unsupported databases. +// Returns nil to avoid incorrect type assumptions for PostgreSQL, SQLite, etc. +// This ensures fallback to the original behavior (interface{} or default types). +func combineGenericTypes(left, right *Column, operator string, notNull bool) *Column { + // TODO: Implement type inference for PostgreSQL and SQLite + // For now, return nil to use safe defaults + return nil +} + +// promoteArithmeticTypes applies standard type promotion rules for arithmetic. +// This follows the principle: int -> decimal -> float +func promoteArithmeticTypes(left, right *Column, notNull bool) *Column { + dataType := "int" + + // Float takes precedence over all other numeric types + if isFloatType(left.DataType) || isFloatType(right.DataType) { + dataType = "float" + } else if isDecimalType(left.DataType) || isDecimalType(right.DataType) { + // Decimal takes precedence over integer + dataType = "decimal" + } else if isIntegerType(left.DataType) && isIntegerType(right.DataType) { + // Both are integers + dataType = "int" + } + + return &Column{ + DataType: dataType, + NotNull: notNull, + } +} + +// isFloatType checks if a datatype is a floating-point type +func isFloatType(dataType string) bool { + switch dataType { + case "float", "double", "double precision", "real": + return true + } + return false +} + +// isDecimalType checks if a datatype is a decimal/numeric type +func isDecimalType(dataType string) bool { + switch dataType { + case "decimal", "numeric", "dec", "fixed": + return true + } + return false +} + +// isIntegerType checks if a datatype is an integer type +func isIntegerType(dataType string) bool { + switch dataType { + case "int", "integer", "smallint", "bigint", "tinyint", "mediumint": + return true + } + return false +} diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index b0a15e6ac4..bbef40657d 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -150,12 +150,20 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er if res.Name != nil { name = *res.Name } - switch op := astutils.Join(n.Name, ""); { + op := astutils.Join(n.Name, "") + switch { case lang.IsComparisonOperator(op): // TODO: Generate a name for these operations cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true}) case lang.IsMathematicalOperator(op): - cols = append(cols, &Column{Name: name, DataType: "int", NotNull: true}) + // Try to infer the type from operands + if inferredCol := c.inferExprType(n, tables); inferredCol != nil { + inferredCol.Name = name + cols = append(cols, inferredCol) + } else { + // Fallback to previous behavior if inference fails + cols = append(cols, &Column{Name: name, DataType: "int", NotNull: true}) + } default: cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false}) } @@ -233,6 +241,17 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er shouldNotBeNull = true continue } + + // Try to infer the type of the argument + if inferredCol := c.inferExprType(arg, tables); inferredCol != nil { + if firstColumn == nil { + firstColumn = inferredCol + } + shouldNotBeNull = shouldNotBeNull || inferredCol.NotNull + continue + } + + // Fallback to the old logic for simple column references if ref, ok := arg.(*ast.ColumnRef); ok { columns, err := outputColumnRefs(res, tables, ref) if err != nil { @@ -247,6 +266,7 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er } } if firstColumn != nil { + firstColumn.Name = name firstColumn.NotNull = shouldNotBeNull firstColumn.skipTableRequiredCheck = true cols = append(cols, firstColumn) diff --git a/internal/sql/lang/operator.go b/internal/sql/lang/operator.go index cd5ef50e38..1c6d507ec8 100644 --- a/internal/sql/lang/operator.go +++ b/internal/sql/lang/operator.go @@ -23,6 +23,8 @@ func IsMathematicalOperator(s string) bool { case "-": case "*": case "/": + case "div": // TODO: MySQL-specific operator - should be moved to engine-specific logic + case "mod": // TODO: MySQL-specific operator - should be moved to engine-specific logic case "%": case "^": case "|/":