Skip to content

Commit a680b1b

Browse files
authored
Massive Optional + Functions update (#13)
* Massive Functions update * Fixed ydb_type nullable problem and added new funcs to ydb catalog * Update internal/engine/ydb/convert.go * Removed comment from query.go to extractBaseType method
1 parent 1379e43 commit a680b1b

File tree

23 files changed

+5114
-488
lines changed

23 files changed

+5114
-488
lines changed

internal/codegen/golang/query.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,9 @@ func (v QueryValue) YDBParamMapEntries() string {
297297

298298
// ydbBuilderMethodForColumnType maps a YDB column data type to a ParamsBuilder method name.
299299
func ydbBuilderMethodForColumnType(dbType string) string {
300-
switch strings.ToLower(dbType) {
300+
baseType := extractBaseType(strings.ToLower(dbType))
301+
302+
switch baseType {
301303
case "bool":
302304
return "Bool"
303305
case "uint64":

internal/codegen/golang/ydb_type.go

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@ import (
1212

1313
func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string {
1414
columnType := strings.ToLower(sdk.DataType(col.Type))
15-
notNull := col.NotNull || col.IsArray
15+
notNull := (col.NotNull || col.IsArray) && !isNullableType(columnType)
1616
emitPointersForNull := options.EmitPointersForNullTypes
1717

18+
columnType = extractBaseType(columnType)
19+
1820
// https://ydb.tech/docs/ru/yql/reference/types/
1921
// ydb-go-sdk doesn't support sql.Null* yet
2022
switch columnType {
@@ -49,7 +51,7 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col
4951
}
5052
// return "sql.NullInt16"
5153
return "*int16"
52-
case "int", "int32": //ydb doesn't have int type, but we need it to support untyped constants
54+
case "int", "int32": //ydb doesn't have int type, but we need it to support untyped constants
5355
if notNull {
5456
return "int32"
5557
}
@@ -159,7 +161,7 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col
159161
return "*string"
160162
}
161163
return "*string"
162-
164+
163165
case "date", "date32", "datetime", "timestamp", "tzdate", "tztimestamp", "tzdatetime":
164166
if notNull {
165167
return "time.Time"
@@ -185,3 +187,18 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col
185187
}
186188

187189
}
190+
191+
// This function extracts the base type from optional types
192+
func extractBaseType(typeStr string) string {
193+
if strings.HasPrefix(typeStr, "optional<") && strings.HasSuffix(typeStr, ">") {
194+
return strings.TrimSuffix(strings.TrimPrefix(typeStr, "optional<"), ">")
195+
}
196+
if strings.HasSuffix(typeStr, "?") {
197+
return strings.TrimSuffix(typeStr, "?")
198+
}
199+
return typeStr
200+
}
201+
202+
func isNullableType(typeStr string) bool {
203+
return strings.HasPrefix(typeStr, "optional<") && strings.HasSuffix(typeStr, ">") || strings.HasSuffix(typeStr, "?")
204+
}

internal/engine/ydb/convert.go

Lines changed: 132 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ydb
22

33
import (
4+
"fmt"
45
"log"
56
"strconv"
67
"strings"
@@ -1787,7 +1788,15 @@ func (c *cc) VisitType_name_or_bind(n *parser.Type_name_or_bindContext) interfac
17871788
}
17881789
return typeName
17891790
} else if b := n.Bind_parameter(); b != nil {
1790-
return &ast.TypeName{Name: "BIND:" + identifier(parseAnIdOrType(b.An_id_or_type()))}
1791+
param, ok := b.Accept(c).(ast.Node)
1792+
if !ok {
1793+
return todo("VisitType_name_or_bind", b)
1794+
}
1795+
return &ast.TypeName{
1796+
Names: &ast.List{
1797+
Items: []ast.Node{param},
1798+
},
1799+
}
17911800
}
17921801
return todo("VisitType_name_or_bind", n)
17931802
}
@@ -1797,6 +1806,8 @@ func (c *cc) VisitType_name(n *parser.Type_nameContext) interface{} {
17971806
return todo("VisitType_name", n)
17981807
}
17991808

1809+
questionCount := len(n.AllQUESTION())
1810+
18001811
if composite := n.Type_name_composite(); composite != nil {
18011812
typeName, ok := composite.Accept(c).(ast.Node)
18021813
if !ok {
@@ -1815,8 +1826,12 @@ func (c *cc) VisitType_name(n *parser.Type_nameContext) interface{} {
18151826
if !ok {
18161827
return todo("VisitType_name", decimal.Integer_or_bind(1))
18171828
}
1829+
name := "decimal"
1830+
if questionCount > 0 {
1831+
name = name + "?"
1832+
}
18181833
return &ast.TypeName{
1819-
Name: "decimal",
1834+
Name: name,
18201835
TypeOid: 0,
18211836
Names: &ast.List{
18221837
Items: []ast.Node{
@@ -1829,12 +1844,17 @@ func (c *cc) VisitType_name(n *parser.Type_nameContext) interface{} {
18291844
}
18301845

18311846
if simple := n.Type_name_simple(); simple != nil {
1847+
name := simple.GetText()
1848+
if questionCount > 0 {
1849+
name = name + "?"
1850+
}
18321851
return &ast.TypeName{
1833-
Name: simple.GetText(),
1852+
Name: name,
18341853
TypeOid: 0,
18351854
}
18361855
}
18371856

1857+
// todo: handle multiple ? suffixes
18381858
return todo("VisitType_name", n)
18391859
}
18401860

@@ -1868,19 +1888,7 @@ func (c *cc) VisitType_name_composite(n *parser.Type_name_compositeContext) inte
18681888
}
18691889

18701890
if opt := n.Type_name_optional(); opt != nil {
1871-
if typeName := opt.Type_name_or_bind(); typeName != nil {
1872-
tn, ok := typeName.Accept(c).(ast.Node)
1873-
if !ok {
1874-
return todo("VisitType_name_composite", typeName)
1875-
}
1876-
return &ast.TypeName{
1877-
Name: "Optional",
1878-
TypeOid: 0,
1879-
Names: &ast.List{
1880-
Items: []ast.Node{tn},
1881-
},
1882-
}
1883-
}
1891+
return opt.Accept(c)
18841892
}
18851893

18861894
if tuple := n.Type_name_tuple(); tuple != nil {
@@ -2025,6 +2033,27 @@ func (c *cc) VisitType_name_composite(n *parser.Type_name_compositeContext) inte
20252033
return todo("VisitType_name_composite", n)
20262034
}
20272035

2036+
func (c *cc) VisitType_name_optional(n *parser.Type_name_optionalContext) interface{} {
2037+
if n == nil || n.Type_name_or_bind() == nil {
2038+
return todo("VisitType_name_optional", n)
2039+
}
2040+
2041+
tn, ok := n.Type_name_or_bind().Accept(c).(ast.Node)
2042+
if !ok {
2043+
return todo("VisitType_name_optional", n.Type_name_or_bind())
2044+
}
2045+
innerTypeName, ok := tn.(*ast.TypeName)
2046+
if !ok {
2047+
return todo("VisitType_name_optional", n.Type_name_or_bind())
2048+
}
2049+
name := fmt.Sprintf("Optional<%s>", innerTypeName.Name)
2050+
return &ast.TypeName{
2051+
Name: name,
2052+
TypeOid: 0,
2053+
Names: &ast.List{},
2054+
}
2055+
}
2056+
20282057
func (c *cc) VisitSql_stmt_core(n *parser.Sql_stmt_coreContext) interface{} {
20292058
if n == nil {
20302059
return todo("VisitSql_stmt_core", n)
@@ -2799,13 +2828,28 @@ func (c *cc) handleInvokeSuffix(base ast.Node, invokeCtx *parser.Invoke_exprCont
27992828
}
28002829
funcName := strings.Join(nameParts, ".")
28012830

2802-
if funcName == "coalesce" {
2831+
if funcName == "coalesce" || funcName == "nvl" {
28032832
return &ast.CoalesceExpr{
28042833
Args: funcCall.Args,
28052834
Location: baseNode.Location,
28062835
}
28072836
}
28082837

2838+
if funcName == "greatest" || funcName == "max_of" {
2839+
return &ast.MinMaxExpr{
2840+
Op: ast.MinMaxOp(1),
2841+
Args: funcCall.Args,
2842+
Location: baseNode.Location,
2843+
}
2844+
}
2845+
if funcName == "least" || funcName == "min_of" {
2846+
return &ast.MinMaxExpr{
2847+
Op: ast.MinMaxOp(2),
2848+
Args: funcCall.Args,
2849+
Location: baseNode.Location,
2850+
}
2851+
}
2852+
28092853
funcCall.Func = &ast.FuncName{Name: funcName}
28102854
funcCall.Funcname.Items = append(funcCall.Funcname.Items, &ast.String{Str: funcName})
28112855

@@ -2816,15 +2860,12 @@ func (c *cc) handleInvokeSuffix(base ast.Node, invokeCtx *parser.Invoke_exprCont
28162860
}
28172861
}
28182862

2819-
stmt := &ast.RecursiveFuncCall{
2820-
Func: base,
2821-
Funcname: funcCall.Funcname,
2822-
AggStar: funcCall.AggStar,
2823-
Location: funcCall.Location,
2824-
Args: funcCall.Args,
2825-
AggDistinct: funcCall.AggDistinct,
2863+
stmt := &ast.FuncExpr{
2864+
Xpr: base,
2865+
Args: funcCall.Args,
2866+
Location: funcCall.Location,
28262867
}
2827-
stmt.Funcname.Items = append(stmt.Funcname.Items, base)
2868+
28282869
return stmt
28292870
}
28302871

@@ -2943,16 +2984,42 @@ func (c *cc) VisitId_expr(n *parser.Id_exprContext) interface{} {
29432984
if n == nil {
29442985
return todo("VisitId_expr", n)
29452986
}
2987+
2988+
ref := &ast.ColumnRef{
2989+
Fields: &ast.List{},
2990+
Location: c.pos(n.GetStart()),
2991+
}
2992+
29462993
if id := n.Identifier(); id != nil {
2947-
return &ast.ColumnRef{
2948-
Fields: &ast.List{
2949-
Items: []ast.Node{
2950-
NewIdentifier(id.GetText()),
2951-
},
2952-
},
2953-
Location: c.pos(id.GetStart()),
2954-
}
2994+
ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(id.GetText()))
2995+
return ref
2996+
}
2997+
2998+
if keyword := n.Keyword_compat(); keyword != nil {
2999+
ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(keyword.GetText()))
3000+
return ref
3001+
}
3002+
3003+
if keyword := n.Keyword_alter_uncompat(); keyword != nil {
3004+
ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(keyword.GetText()))
3005+
return ref
3006+
}
3007+
3008+
if keyword := n.Keyword_in_uncompat(); keyword != nil {
3009+
ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(keyword.GetText()))
3010+
return ref
3011+
}
3012+
3013+
if keyword := n.Keyword_window_uncompat(); keyword != nil {
3014+
ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(keyword.GetText()))
3015+
return ref
3016+
}
3017+
3018+
if keyword := n.Keyword_hint_uncompat(); keyword != nil {
3019+
ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(keyword.GetText()))
3020+
return ref
29553021
}
3022+
29563023
return todo("VisitId_expr", n)
29573024
}
29583025

@@ -2979,12 +3046,44 @@ func (c *cc) VisitAtom_expr(n *parser.Atom_exprContext) interface{} {
29793046
return todo("VisitAtom_expr", n.Bind_parameter())
29803047
}
29813048
return expr
3049+
case n.Cast_expr() != nil:
3050+
expr, ok := n.Cast_expr().Accept(c).(ast.Node)
3051+
if !ok {
3052+
return todo("VisitAtom_expr", n.Cast_expr())
3053+
}
3054+
return expr
29823055
// TODO: check other cases
29833056
default:
29843057
return todo("VisitAtom_expr", n)
29853058
}
29863059
}
29873060

3061+
func (c *cc) VisitCast_expr(n *parser.Cast_exprContext) interface{} {
3062+
if n == nil || n.CAST() == nil || n.Expr() == nil || n.AS() == nil || n.Type_name_or_bind() == nil {
3063+
return todo("VisitCast_expr", n)
3064+
}
3065+
3066+
expr, ok := n.Expr().Accept(c).(ast.Node)
3067+
if !ok {
3068+
return todo("VisitCast_expr", n.Expr())
3069+
}
3070+
3071+
temp, ok := n.Type_name_or_bind().Accept(c).(ast.Node)
3072+
if !ok {
3073+
return todo("VisitCast_expr", n.Type_name_or_bind())
3074+
}
3075+
typeName, ok := temp.(*ast.TypeName)
3076+
if !ok {
3077+
return todo("VisitCast_expr", n.Type_name_or_bind())
3078+
}
3079+
3080+
return &ast.TypeCast{
3081+
Arg: expr,
3082+
TypeName: typeName,
3083+
Location: c.pos(n.GetStart()),
3084+
}
3085+
}
3086+
29883087
func (c *cc) VisitLiteral_value(n *parser.Literal_valueContext) interface{} {
29893088
if n == nil {
29903089
return todo("VisitLiteral_value", n)

0 commit comments

Comments
 (0)