|
| 1 | +package compiler |
| 2 | + |
| 3 | +import ( |
| 4 | + "github.com/sqlc-dev/sqlc/internal/config" |
| 5 | + "github.com/sqlc-dev/sqlc/internal/sql/ast" |
| 6 | +) |
| 7 | + |
| 8 | +// inferExprType recursively analyzes SQL expressions to determine their types. |
| 9 | +// |
| 10 | +// It handles: |
| 11 | +// - Column references (resolved from table schema) |
| 12 | +// - Literal constants (intrinsic types) |
| 13 | +// - Binary operations (applies type promotion and operator-specific rules) |
| 14 | +// - Type casts (respects explicit type annotations) |
| 15 | +// |
| 16 | +// Examples: |
| 17 | +// |
| 18 | +// SELECT a1 / 1024 -- infers decimal or float based on operand types |
| 19 | +// SELECT COALESCE(a1 / 1024, 0) -- handles nested expressions recursively |
| 20 | +// SELECT CAST(a1 AS INT) -- respects explicit casts |
| 21 | +// |
| 22 | +// Returns nil if type cannot be inferred, allowing fallback to default behavior. |
| 23 | +func (c *Compiler) inferExprType(node ast.Node, tables []*Table) *Column { |
| 24 | + if node == nil { |
| 25 | + return nil |
| 26 | + } |
| 27 | + |
| 28 | + switch n := node.(type) { |
| 29 | + case *ast.ColumnRef: |
| 30 | + // Try to resolve the column reference |
| 31 | + // Create a minimal ResTarget for outputColumnRefs |
| 32 | + emptyRes := &ast.ResTarget{} |
| 33 | + cols, err := outputColumnRefs(emptyRes, tables, n) |
| 34 | + if err != nil || len(cols) == 0 { |
| 35 | + return nil |
| 36 | + } |
| 37 | + return cols[0] |
| 38 | + |
| 39 | + case *ast.A_Const: |
| 40 | + // Infer type from constant value |
| 41 | + switch n.Val.(type) { |
| 42 | + case *ast.String: |
| 43 | + return &Column{DataType: "text", NotNull: true} |
| 44 | + case *ast.Integer: |
| 45 | + return &Column{DataType: "int", NotNull: true} |
| 46 | + case *ast.Float: |
| 47 | + return &Column{DataType: "float", NotNull: true} |
| 48 | + case *ast.Boolean: |
| 49 | + return &Column{DataType: "bool", NotNull: true} |
| 50 | + default: |
| 51 | + return nil |
| 52 | + } |
| 53 | + |
| 54 | + case *ast.A_Expr: |
| 55 | + // Recursively infer types of left and right operands |
| 56 | + leftCol := c.inferExprType(n.Lexpr, tables) |
| 57 | + rightCol := c.inferExprType(n.Rexpr, tables) |
| 58 | + |
| 59 | + if leftCol == nil && rightCol == nil { |
| 60 | + return nil |
| 61 | + } |
| 62 | + |
| 63 | + // Extract operator name |
| 64 | + op := "" |
| 65 | + if n.Name != nil && len(n.Name.Items) > 0 { |
| 66 | + if str, ok := n.Name.Items[0].(*ast.String); ok { |
| 67 | + op = str.Str |
| 68 | + } |
| 69 | + } |
| 70 | + |
| 71 | + // Apply database-specific type rules |
| 72 | + return c.combineTypes(leftCol, rightCol, op) |
| 73 | + |
| 74 | + case *ast.TypeCast: |
| 75 | + // If there's an explicit cast, use that type |
| 76 | + if n.TypeName != nil { |
| 77 | + col := toColumn(n.TypeName) |
| 78 | + // Check if the casted value is nullable |
| 79 | + if constant, ok := n.Arg.(*ast.A_Const); ok { |
| 80 | + if _, isNull := constant.Val.(*ast.Null); isNull { |
| 81 | + col.NotNull = false |
| 82 | + } |
| 83 | + } |
| 84 | + return col |
| 85 | + } |
| 86 | + } |
| 87 | + |
| 88 | + return nil |
| 89 | +} |
| 90 | + |
| 91 | +// combineTypes determines the result type of a binary operation. |
| 92 | +// The logic is database-specific and handles operator semantics for each engine. |
| 93 | +func (c *Compiler) combineTypes(left, right *Column, operator string) *Column { |
| 94 | + // If either operand is unknown, we can't infer the type |
| 95 | + if left == nil && right == nil { |
| 96 | + return nil |
| 97 | + } |
| 98 | + |
| 99 | + // If one operand is known, use it as a hint |
| 100 | + if left == nil { |
| 101 | + return right |
| 102 | + } |
| 103 | + if right == nil { |
| 104 | + return left |
| 105 | + } |
| 106 | + |
| 107 | + // Result is nullable if any operand is nullable (SQL NULL propagation rule) |
| 108 | + notNull := left.NotNull && right.NotNull |
| 109 | + |
| 110 | + // Apply database-specific type rules |
| 111 | + switch c.conf.Engine { |
| 112 | + case config.EngineMySQL: |
| 113 | + return combineMySQLTypes(left, right, operator, notNull) |
| 114 | + default: |
| 115 | + // TODO: Implement type inference for PostgreSQL and SQLite |
| 116 | + // For now, use conservative fallback rules |
| 117 | + return combineGenericTypes(left, right, operator, notNull) |
| 118 | + } |
| 119 | +} |
| 120 | + |
| 121 | +// combineMySQLTypes implements MySQL-specific type inference rules. |
| 122 | +// |
| 123 | +// Division always returns decimal or float: |
| 124 | +// |
| 125 | +// SELECT int_col / 1024 -- returns decimal |
| 126 | +// SELECT float_col / 1024 -- returns float |
| 127 | +// SELECT int_col DIV 1024 -- returns decimal (DIV is MySQL-specific) |
| 128 | +// |
| 129 | +// Nullability propagates (NULL op anything = NULL): |
| 130 | +// |
| 131 | +// nullable_col / 1024 -- returns nullable result |
| 132 | +// NOT NULL col / 1024 -- returns NOT NULL result |
| 133 | +func combineMySQLTypes(left, right *Column, operator string, notNull bool) *Column { |
| 134 | + // Handle nil operands |
| 135 | + if left == nil && right == nil { |
| 136 | + return nil |
| 137 | + } |
| 138 | + if left == nil { |
| 139 | + return right |
| 140 | + } |
| 141 | + if right == nil { |
| 142 | + return left |
| 143 | + } |
| 144 | + |
| 145 | + // Normalize MySQL-specific operators |
| 146 | + switch operator { |
| 147 | + case "div": |
| 148 | + operator = "/" // DIV is MySQL's integer division, but for type inference treat as division |
| 149 | + case "mod": |
| 150 | + operator = "%" // MOD is alias for % |
| 151 | + } |
| 152 | + |
| 153 | + switch operator { |
| 154 | + case "/": |
| 155 | + // Division: int/int = decimal, float/anything = float |
| 156 | + if isFloatType(left.DataType) || isFloatType(right.DataType) { |
| 157 | + return &Column{DataType: "float", NotNull: notNull} |
| 158 | + } |
| 159 | + return &Column{DataType: "decimal", NotNull: notNull} |
| 160 | + |
| 161 | + case "%": |
| 162 | + // Modulo: returns integer if both are integers, otherwise decimal |
| 163 | + if isIntegerType(left.DataType) && isIntegerType(right.DataType) { |
| 164 | + return &Column{DataType: "int", NotNull: notNull} |
| 165 | + } |
| 166 | + return &Column{DataType: "decimal", NotNull: notNull} |
| 167 | + |
| 168 | + case "&", "|", "<<", ">>", "~", "#", "^": |
| 169 | + // Bitwise operators: always integer |
| 170 | + return &Column{DataType: "int", NotNull: notNull} |
| 171 | + |
| 172 | + default: |
| 173 | + // Arithmetic: standard type promotion |
| 174 | + return promoteArithmeticTypes(left, right, notNull) |
| 175 | + } |
| 176 | +} |
| 177 | + |
| 178 | +// combineGenericTypes provides conservative fallback for unsupported databases. |
| 179 | +// Returns nil to avoid incorrect type assumptions for PostgreSQL, SQLite, etc. |
| 180 | +// This ensures fallback to the original behavior (interface{} or default types). |
| 181 | +func combineGenericTypes(left, right *Column, operator string, notNull bool) *Column { |
| 182 | + // TODO: Implement type inference for PostgreSQL and SQLite |
| 183 | + // For now, return nil to use safe defaults |
| 184 | + return nil |
| 185 | +} |
| 186 | + |
| 187 | +// promoteArithmeticTypes applies standard type promotion rules for arithmetic. |
| 188 | +// This follows the principle: int -> decimal -> float |
| 189 | +func promoteArithmeticTypes(left, right *Column, notNull bool) *Column { |
| 190 | + dataType := "int" |
| 191 | + |
| 192 | + // Float takes precedence over all other numeric types |
| 193 | + if isFloatType(left.DataType) || isFloatType(right.DataType) { |
| 194 | + dataType = "float" |
| 195 | + } else if isDecimalType(left.DataType) || isDecimalType(right.DataType) { |
| 196 | + // Decimal takes precedence over integer |
| 197 | + dataType = "decimal" |
| 198 | + } else if isIntegerType(left.DataType) && isIntegerType(right.DataType) { |
| 199 | + // Both are integers |
| 200 | + dataType = "int" |
| 201 | + } |
| 202 | + |
| 203 | + return &Column{ |
| 204 | + DataType: dataType, |
| 205 | + NotNull: notNull, |
| 206 | + } |
| 207 | +} |
| 208 | + |
// isFloatType reports whether dataType names a floating-point type. The match
// is exact and case-sensitive against "float", "double", "double precision"
// and "real".
func isFloatType(dataType string) bool {
	return dataType == "float" ||
		dataType == "double" ||
		dataType == "double precision" ||
		dataType == "real"
}
| 217 | + |
// isDecimalType reports whether dataType names a decimal/numeric type. The
// match is exact and case-sensitive against "decimal", "numeric", "dec" and
// "fixed".
func isDecimalType(dataType string) bool {
	for _, name := range [...]string{"decimal", "numeric", "dec", "fixed"} {
		if dataType == name {
			return true
		}
	}
	return false
}
| 226 | + |
// isIntegerType reports whether dataType names an integer type. The match is
// exact and case-sensitive against "tinyint", "smallint", "mediumint", "int",
// "integer" and "bigint".
func isIntegerType(dataType string) bool {
	matched := false
	switch dataType {
	case "tinyint", "smallint", "mediumint", "int", "integer", "bigint":
		matched = true
	}
	return matched
}
0 commit comments