Skip to content

Commit 351884f

Browse files
author
James Cor
committed
refactoring and fixing tests
1 parent 5694c35 commit 351884f

File tree

11 files changed

+136
-98
lines changed

11 files changed

+136
-98
lines changed

server/handler.go

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ func (h *Handler) doQuery(
496496
} else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) {
497497
r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf)
498498
} else if ri2, ok := rowIter.(sql.RowIter2); ok && ri2.IsRowIter2(sqlCtx) {
499-
r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, ri2, resultFields, callback, more)
499+
r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, schema, ri2, resultFields, buf, callback, more)
500500
} else {
501501
r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf)
502502
}
@@ -770,14 +770,13 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
770770
return r, processedAtLeastOneBatch, nil
771771
}
772772

773-
func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sql.RowIter2, resultFields []*querypb.Field, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, bool, error) {
773+
func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.RowIter2, resultFields []*querypb.Field, buf *sql.ByteBuffer, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, bool, error) {
774774
defer trace.StartRegion(ctx, "Handler.resultForDefaultIter2").End()
775775

776776
eg, ctx := ctx.NewErrgroup()
777777
pan2err := func(err *error) {
778778
if recoveredPanic := recover(); recoveredPanic != nil {
779-
stack := debug.Stack()
780-
wrappedErr := fmt.Errorf("handler caught panic: %v\n%s", recoveredPanic, stack)
779+
wrappedErr := fmt.Errorf("handler caught panic: %v\n%s", recoveredPanic, debug.Stack())
781780
*err = goerrors.Join(*err, wrappedErr)
782781
}
783782
}
@@ -868,24 +867,9 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq
868867
if !ok {
869868
return nil
870869
}
871-
resRow := make([]sqltypes.Value, len(row))
872-
for i, v := range row {
873-
if v.Val != nil || v.WrappedVal == nil {
874-
resRow[i] = sqltypes.MakeTrusted(v.Typ, v.Val)
875-
continue
876-
}
877-
dVal, err := v.WrappedVal.UnwrapAny(ctx)
878-
if err != nil {
879-
return err
880-
}
881-
switch dVal := dVal.(type) {
882-
case []byte:
883-
resRow[i] = sqltypes.MakeTrusted(v.Typ, dVal)
884-
case string:
885-
resRow[i] = sqltypes.MakeTrusted(v.Typ, []byte(dVal))
886-
default:
887-
panic(fmt.Sprintf("unexpected type %T", dVal))
888-
}
870+
resRow, err := RowValueToSQLValues(ctx, schema, row, buf)
871+
if err != nil {
872+
return err
889873
}
890874
ctx.GetLogger().Tracef("spooling result row %s", resRow)
891875
res.Rows = append(res.Rows, resRow)
@@ -1187,6 +1171,35 @@ func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, projs []sql.Express
11871171
return outVals, nil
11881172
}
11891173

1174+
func RowValueToSQLValues(ctx *sql.Context, sch sql.Schema, row sql.ValueRow, buf *sql.ByteBuffer) ([]sqltypes.Value, error) {
1175+
if len(sch) == 0 {
1176+
return []sqltypes.Value{}, nil
1177+
}
1178+
var err error
1179+
outVals := make([]sqltypes.Value, len(sch))
1180+
for i, col := range sch {
1181+
// TODO: remove this check once all Types implement this
1182+
valType, ok := col.Type.(sql.Type2)
1183+
if !ok {
1184+
outVals[i] = sqltypes.MakeTrusted(row[i].Typ, row[i].Val)
1185+
continue
1186+
}
1187+
if buf == nil {
1188+
outVals[i], err = valType.ToSQLValue(ctx, row[i], nil)
1189+
if err != nil {
1190+
return nil, err
1191+
}
1192+
continue
1193+
}
1194+
outVals[i], err = valType.ToSQLValue(ctx, row[i], buf.Get())
1195+
if err != nil {
1196+
return nil, err
1197+
}
1198+
buf.Grow(outVals[i].Len())
1199+
}
1200+
return outVals, nil
1201+
}
1202+
11901203
func schemaToFields(ctx *sql.Context, s sql.Schema) []*querypb.Field {
11911204
charSetResults := ctx.GetCharacterSetResults()
11921205
fields := make([]*querypb.Field, len(s))

sql/expression/sort.go

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -102,56 +102,6 @@ func (s *Sorter2) Swap(i, j int) {
102102
s.Rows[i], s.Rows[j] = s.Rows[j], s.Rows[i]
103103
}
104104

105-
func (s *Sorter2) Less(i, j int) bool {
106-
if s.LastError != nil {
107-
return false
108-
}
109-
110-
a := s.Rows[i]
111-
b := s.Rows[j]
112-
for _, sf := range s.SortFields {
113-
typ := sf.Column2.Type2()
114-
av, err := sf.Column2.Eval2(s.Ctx, a)
115-
if err != nil {
116-
s.LastError = sql.ErrUnableSort.Wrap(err)
117-
return false
118-
}
119-
120-
bv, err := sf.Column2.Eval2(s.Ctx, b)
121-
if err != nil {
122-
s.LastError = sql.ErrUnableSort.Wrap(err)
123-
return false
124-
}
125-
126-
if sf.Order == sql.Descending {
127-
av, bv = bv, av
128-
}
129-
130-
if av.IsNull() && bv.IsNull() {
131-
continue
132-
} else if av.IsNull() {
133-
return sf.NullOrdering == sql.NullsFirst
134-
} else if bv.IsNull() {
135-
return sf.NullOrdering != sql.NullsFirst
136-
}
137-
138-
cmp, err := typ.Compare2(av, bv)
139-
if err != nil {
140-
s.LastError = err
141-
return false
142-
}
143-
144-
switch cmp {
145-
case -1:
146-
return true
147-
case 1:
148-
return false
149-
}
150-
}
151-
152-
return false
153-
}
154-
155105
// TopRowsHeap implements heap.Interface based on Sorter. It inverts the Less()
156106
// function so that it can be used to implement TopN. heap.Push() rows into it,
157107
// and if Len() > MAX; heap.Pop() the current min row. Then, at the end of

sql/row_frame.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ type ValueBytes []byte
3030
// Value is a logical index into a ValueRow. For efficiency reasons, use sparingly.
3131
type Value struct {
3232
Val ValueBytes
33-
WrappedVal AnyWrapper
33+
WrappedVal BytesWrapper
3434
Typ querypb.Type // TODO: consider sqltypes.Type instead
3535
}
3636

sql/type.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -294,12 +294,7 @@ func IsDecimalType(t Type) bool {
294294

295295
type Type2 interface {
296296
Type
297-
// Compare2 returns an integer comparing two Values.
298-
Compare2(Value, Value) (int, error)
299-
// Convert2 converts a value of a compatible type.
300-
Convert2(Value) (Value, error)
301-
// Zero2 returns the zero Value for this type.
302-
Zero2() Value
297+
ToSQLValue(*Context, Value, []byte) (sqltypes.Value, error)
303298
}
304299

305300
// SpatialColumnType is a node that contains a reference to all spatial types.

sql/types/bit.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ import (
1818
"context"
1919
"encoding/binary"
2020
"fmt"
21+
"github.com/dolthub/go-mysql-server/sql/values"
2122
"reflect"
23+
"strconv"
2224

2325
"github.com/dolthub/vitess/go/sqltypes"
2426
"github.com/dolthub/vitess/go/vt/proto/query"
@@ -211,6 +213,17 @@ func (t BitType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va
211213
return sqltypes.MakeTrusted(sqltypes.Bit, val), nil
212214
}
213215

216+
// ToSQLValue implements Type2 interface.
217+
func (t BitType_) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) {
218+
if v.IsNull() {
219+
return sqltypes.NULL, nil
220+
}
221+
// Assume this is uint64
222+
x := values.ReadUint64(v.Val)
223+
dest = strconv.AppendUint(dest, x, 10)
224+
return sqltypes.MakeTrusted(sqltypes.Bit, dest), nil
225+
}
226+
214227
// String implements Type interface.
215228
func (t BitType_) String() string {
216229
return fmt.Sprintf("bit(%v)", t.numOfBits)

sql/types/datetime.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package types
1717
import (
1818
"context"
1919
"fmt"
20+
"github.com/dolthub/go-mysql-server/sql/values"
2021
"math"
2122
"reflect"
2223
"time"
@@ -474,6 +475,31 @@ func (t datetimeType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltype
474475
return sqltypes.MakeTrusted(typ, valBytes), nil
475476
}
476477

478+
func (t datetimeType) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) {
479+
if v.IsNull() {
480+
return sqltypes.NULL, nil
481+
}
482+
switch t.baseType {
483+
case sqltypes.Date:
484+
// TODO: move this to values
485+
x := values.ReadUint32(v.Val)
486+
y := x >> 16
487+
m := (x & (255 << 8)) >> 8
488+
d := x & 255
489+
t := time.Date(int(y), time.Month(m), int(d), 0, 0, 0, 0, time.UTC)
490+
dest = t.AppendFormat(dest, sql.TimestampDatetimeLayout)
491+
492+
case sqltypes.Datetime, sqltypes.Timestamp:
493+
x := values.ReadInt64(v.Val)
494+
t := time.UnixMicro(x).UTC()
495+
dest = t.AppendFormat(dest, sql.TimestampDatetimeLayout)
496+
497+
default:
498+
return sqltypes.Value{}, sql.ErrInvalidBaseType.New(t.baseType.String(), "datetime")
499+
}
500+
return sqltypes.MakeTrusted(t.baseType, dest), nil
501+
}
502+
477503
func (t datetimeType) String() string {
478504
switch t.baseType {
479505
case sqltypes.Date:

sql/types/number.go

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -867,55 +867,54 @@ func (t NumberTypeImpl_) Zero2() sql.Value {
867867
}
868868
}
869869

870-
// SQL2 implements Type2 interface.
871-
func (t NumberTypeImpl_) SQL2(v sql.Value) (sqltypes.Value, error) {
870+
// ToSQLValue implements Type2 interface.
871+
func (t NumberTypeImpl_) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) {
872872
if v.IsNull() {
873873
return sqltypes.NULL, nil
874874
}
875875

876-
var val []byte
877876
switch t.baseType {
878877
case sqltypes.Int8:
879878
x := values.ReadInt8(v.Val)
880-
val = []byte(strconv.FormatInt(int64(x), 10))
879+
dest = strconv.AppendInt(dest, int64(x), 10)
881880
case sqltypes.Int16:
882881
x := values.ReadInt16(v.Val)
883-
val = []byte(strconv.FormatInt(int64(x), 10))
882+
dest = strconv.AppendInt(dest, int64(x), 10)
884883
case sqltypes.Int24:
885884
x := values.ReadInt24(v.Val)
886-
val = []byte(strconv.FormatInt(int64(x), 10))
885+
dest = strconv.AppendInt(dest, int64(x), 10)
887886
case sqltypes.Int32:
888887
x := values.ReadInt32(v.Val)
889-
val = []byte(strconv.FormatInt(int64(x), 10))
888+
dest = strconv.AppendInt(dest, int64(x), 10)
890889
case sqltypes.Int64:
891890
x := values.ReadInt64(v.Val)
892-
val = []byte(strconv.FormatInt(x, 10))
891+
dest = strconv.AppendInt(dest, x, 10)
893892
case sqltypes.Uint8:
894893
x := values.ReadUint8(v.Val)
895-
val = []byte(strconv.FormatUint(uint64(x), 10))
894+
dest = strconv.AppendUint(dest, uint64(x), 10)
896895
case sqltypes.Uint16:
897896
x := values.ReadUint16(v.Val)
898-
val = []byte(strconv.FormatUint(uint64(x), 10))
897+
dest = strconv.AppendUint(dest, uint64(x), 10)
899898
case sqltypes.Uint24:
900899
x := values.ReadUint24(v.Val)
901-
val = []byte(strconv.FormatUint(uint64(x), 10))
900+
dest = strconv.AppendUint(dest, uint64(x), 10)
902901
case sqltypes.Uint32:
903902
x := values.ReadUint32(v.Val)
904-
val = []byte(strconv.FormatUint(uint64(x), 10))
903+
dest = strconv.AppendUint(dest, uint64(x), 10)
905904
case sqltypes.Uint64:
906905
x := values.ReadUint64(v.Val)
907-
val = []byte(strconv.FormatUint(x, 10))
906+
dest = strconv.AppendUint(dest, x, 10)
908907
case sqltypes.Float32:
909908
x := values.ReadFloat32(v.Val)
910-
val = []byte(strconv.FormatFloat(float64(x), 'f', -1, 32))
909+
dest = strconv.AppendFloat(dest, float64(x), 'f', -1, 32)
911910
case sqltypes.Float64:
912911
x := values.ReadFloat64(v.Val)
913-
val = []byte(strconv.FormatFloat(x, 'f', -1, 64))
912+
dest = strconv.AppendFloat(dest, x, 'f', -1, 64)
914913
default:
915914
panic(sql.ErrInvalidBaseType.New(t.baseType.String(), "number"))
916915
}
917916

918-
return sqltypes.MakeTrusted(t.baseType, val), nil
917+
return sqltypes.MakeTrusted(t.baseType, dest), nil
919918
}
920919

921920
// String implements Type interface.

sql/types/strings.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,26 @@ func (t StringType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.
790790
return sqltypes.MakeTrusted(t.baseType, val), nil
791791
}
792792

793+
// ToSQLValue implements ValueType interface.
794+
func (t StringType) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) {
795+
if v.IsNull() {
796+
return sqltypes.NULL, nil
797+
}
798+
799+
// TODO: collations
800+
// TODO: deal with casting numbers?
801+
// No need to use dest buffer as we have already allocated []byte
802+
var err error
803+
if v.Val == nil && v.WrappedVal != nil {
804+
v.Val, err = v.WrappedVal.Unwrap(ctx)
805+
if err != nil {
806+
return sqltypes.Value{}, err
807+
}
808+
}
809+
810+
return sqltypes.MakeTrusted(t.baseType, v.Val), nil
811+
}
812+
793813
// String implements Type interface.
794814
func (t StringType) String() string {
795815
return t.StringWithTableCollation(sql.Collation_Default)

sql/types/time.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package types
1616

1717
import (
1818
"context"
19+
"github.com/dolthub/go-mysql-server/sql/values"
1920
"math"
2021
"reflect"
2122
"strconv"
@@ -267,6 +268,16 @@ func (t TimespanType_) SQL(_ *sql.Context, dest []byte, v interface{}) (sqltypes
267268
return sqltypes.MakeTrusted(sqltypes.Time, val), nil
268269
}
269270

271+
func (t TimespanType_) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) {
272+
if v.IsNull() {
273+
return sqltypes.NULL, nil
274+
}
275+
x := values.ReadInt64(v.Val)
276+
// TODO: write version of this that takes advantage of dest
277+
v.Val = Timespan(x).Bytes()
278+
return sqltypes.MakeTrusted(sqltypes.Time, v.Val), nil
279+
}
280+
270281
// String implements Type interface.
271282
func (t TimespanType_) String() string {
272283
return "time(6)"

sql/types/year.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package types
1616

1717
import (
1818
"context"
19+
"github.com/dolthub/go-mysql-server/sql/values"
1920
"reflect"
2021
"strconv"
2122
"time"
@@ -171,6 +172,15 @@ func (t YearType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.V
171172
return sqltypes.MakeTrusted(sqltypes.Year, val), nil
172173
}
173174

175+
func (t YearType_) ToSQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) {
176+
if v.IsNull() {
177+
return sqltypes.NULL, nil
178+
}
179+
x := values.ReadUint8(v.Val)
180+
dest = strconv.AppendInt(dest, int64(x), 10)
181+
return sqltypes.MakeTrusted(sqltypes.Year, dest), nil
182+
}
183+
174184
// String implements Type interface.
175185
func (t YearType_) String() string {
176186
return "year"

0 commit comments

Comments
 (0)