Skip to content

Commit 206f30e

Browse files
author
James Cor
committed
row2 changes
1 parent bbd0659 commit 206f30e

File tree

16 files changed

+470
-243
lines changed

16 files changed

+470
-243
lines changed

server/handler.go

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,8 @@ func (h *Handler) doQuery(
495495
r, err = resultForEmptyIter(sqlCtx, rowIter, resultFields)
496496
} else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) {
497497
r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf)
498+
} else if ri2, ok := rowIter.(sql.RowIter2); ok && ri2.IsRowIter2(sqlCtx) {
499+
r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, ri2, resultFields, callback, more)
498500
} else {
499501
r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf)
500502
}
@@ -768,6 +770,135 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
768770
return r, processedAtLeastOneBatch, nil
769771
}
770772

773+
func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sql.RowIter2, resultFields []*querypb.Field, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, bool, error) {
774+
defer trace.StartRegion(ctx, "Handler.resultForDefaultIter2").End()
775+
776+
eg, ctx := ctx.NewErrgroup()
777+
pan2err := func(err *error) {
778+
if recoveredPanic := recover(); recoveredPanic != nil {
779+
stack := debug.Stack()
780+
wrappedErr := fmt.Errorf("handler caught panic: %v\n%s", recoveredPanic, stack)
781+
*err = goerrors.Join(*err, wrappedErr)
782+
}
783+
}
784+
785+
// TODO: poll for closed connections should obviously also run even if
786+
// we're doing something with an OK result or a single row result, etc.
787+
// This should be in the caller.
788+
pollCtx, cancelF := ctx.NewSubContext()
789+
eg.Go(func() (err error) {
790+
defer pan2err(&err)
791+
return h.pollForClosedConnection(pollCtx, c)
792+
})
793+
794+
// Default waitTime is one minute if there is no timeout configured, in which case
795+
// it will loop to iterate again unless the socket died by the OS timeout or other problems.
796+
// If there is a timeout, it will be enforced to ensure that Vitess has a chance to
797+
// call Handler.CloseConnection()
798+
waitTime := 1 * time.Minute
799+
if h.readTimeout > 0 {
800+
waitTime = h.readTimeout
801+
}
802+
timer := time.NewTimer(waitTime)
803+
defer timer.Stop()
804+
805+
wg := sync.WaitGroup{}
806+
wg.Add(2)
807+
808+
// TODO: send results instead of rows?
809+
// Read rows from iter and send them off
810+
var rowChan = make(chan sql.Row2, 512)
811+
eg.Go(func() (err error) {
812+
defer pan2err(&err)
813+
defer wg.Done()
814+
defer close(rowChan)
815+
for {
816+
select {
817+
case <-ctx.Done():
818+
return context.Cause(ctx)
819+
default:
820+
row, err := iter.Next2(ctx)
821+
if err == io.EOF {
822+
return nil
823+
}
824+
if err != nil {
825+
return err
826+
}
827+
select {
828+
case rowChan <- row:
829+
case <-ctx.Done():
830+
return nil
831+
}
832+
}
833+
}
834+
})
835+
836+
var res *sqltypes.Result
837+
var processedAtLeastOneBatch bool
838+
eg.Go(func() (err error) {
839+
defer pan2err(&err)
840+
defer cancelF()
841+
defer wg.Done()
842+
for {
843+
if res == nil {
844+
res = &sqltypes.Result{
845+
Fields: resultFields,
846+
Rows: make([][]sqltypes.Value, 0, rowsBatch),
847+
}
848+
}
849+
if res.RowsAffected == rowsBatch {
850+
if err := callback(res, more); err != nil {
851+
return err
852+
}
853+
res = nil
854+
processedAtLeastOneBatch = true
855+
continue
856+
}
857+
858+
select {
859+
case <-ctx.Done():
860+
return context.Cause(ctx)
861+
case <-timer.C:
862+
if h.readTimeout != 0 {
863+
// Cancel and return so Vitess can call the CloseConnection callback
864+
ctx.GetLogger().Tracef("connection timeout")
865+
return ErrRowTimeout.New()
866+
}
867+
case row, ok := <-rowChan:
868+
if !ok {
869+
return nil
870+
}
871+
ctx.GetLogger().Tracef("spooling result row %s", row)
872+
res.Rows = append(res.Rows, row)
873+
res.RowsAffected++
874+
if !timer.Stop() {
875+
<-timer.C
876+
}
877+
}
878+
timer.Reset(waitTime)
879+
}
880+
})
881+
882+
// Close() kills this PID in the process list,
883+
// wait until all rows have be sent over the wire
884+
eg.Go(func() (err error) {
885+
defer pan2err(&err)
886+
wg.Wait()
887+
return iter.Close(ctx)
888+
})
889+
890+
err := eg.Wait()
891+
if err != nil {
892+
ctx.GetLogger().WithError(err).Warn("error running query")
893+
if verboseErrorLogging {
894+
fmt.Printf("Err: %+v", err)
895+
}
896+
return nil, false, err
897+
}
898+
899+
return res, processedAtLeastOneBatch, nil
900+
}
901+
771902
// See https://dev.mysql.com/doc/internals/en/status-flags.html
772903
func setConnStatusFlags(ctx *sql.Context, c *mysql.Conn) error {
773904
ok, err := isSessionAutocommit(ctx)

sql/convert_value.go

Lines changed: 20 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -3,98 +3,46 @@ package sql
33
import (
44
"fmt"
55

6-
"github.com/dolthub/vitess/go/vt/proto/query"
7-
86
"github.com/dolthub/go-mysql-server/sql/values"
7+
8+
"github.com/dolthub/vitess/go/sqltypes"
9+
"github.com/dolthub/vitess/go/vt/proto/query"
910
)
1011

1112
// ConvertToValue converts the interface to a sql value.
12-
func ConvertToValue(v interface{}) (Value, error) {
13+
func ConvertToValue(v interface{}) (sqltypes.Value, error) {
1314
switch v := v.(type) {
1415
case nil:
15-
return Value{
16-
Typ: query.Type_NULL_TYPE,
17-
Val: nil,
18-
}, nil
16+
return sqltypes.MakeTrusted(query.Type_NULL_TYPE, nil), nil
1917
case int:
20-
return Value{
21-
Typ: query.Type_INT64,
22-
Val: values.WriteInt64(make([]byte, values.Int64Size), int64(v)),
23-
}, nil
18+
return sqltypes.MakeTrusted(query.Type_INT64, values.WriteInt64(make([]byte, values.Int64Size), int64(v))), nil
2419
case int8:
25-
return Value{
26-
Typ: query.Type_INT8,
27-
Val: values.WriteInt8(make([]byte, values.Int8Size), v),
28-
}, nil
20+
return sqltypes.MakeTrusted(query.Type_INT8, values.WriteInt8(make([]byte, values.Int8Size), v)), nil
2921
case int16:
30-
return Value{
31-
Typ: query.Type_INT16,
32-
Val: values.WriteInt16(make([]byte, values.Int16Size), v),
33-
}, nil
22+
return sqltypes.MakeTrusted(query.Type_INT16, values.WriteInt16(make([]byte, values.Int16Size), v)), nil
3423
case int32:
35-
return Value{
36-
Typ: query.Type_INT32,
37-
Val: values.WriteInt32(make([]byte, values.Int32Size), v),
38-
}, nil
24+
return sqltypes.MakeTrusted(query.Type_INT32, values.WriteInt32(make([]byte, values.Int32Size), v)), nil
3925
case int64:
40-
return Value{
41-
Typ: query.Type_INT64,
42-
Val: values.WriteInt64(make([]byte, values.Int64Size), v),
43-
}, nil
26+
return sqltypes.MakeTrusted(query.Type_INT64, values.WriteInt64(make([]byte, values.Int64Size), v)), nil
4427
case uint:
45-
return Value{
46-
Typ: query.Type_UINT64,
47-
Val: values.WriteUint64(make([]byte, values.Uint64Size), uint64(v)),
48-
}, nil
28+
return sqltypes.MakeTrusted(query.Type_UINT64, values.WriteUint64(make([]byte, values.Uint64Size), uint64(v))), nil
4929
case uint8:
50-
return Value{
51-
Typ: query.Type_UINT8,
52-
Val: values.WriteUint8(make([]byte, values.Uint8Size), v),
53-
}, nil
30+
return sqltypes.MakeTrusted(query.Type_UINT8, values.WriteUint8(make([]byte, values.Uint8Size), v)), nil
5431
case uint16:
55-
return Value{
56-
Typ: query.Type_UINT16,
57-
Val: values.WriteUint16(make([]byte, values.Uint16Size), v),
58-
}, nil
32+
return sqltypes.MakeTrusted(query.Type_UINT16, values.WriteUint16(make([]byte, values.Uint16Size), v)), nil
5933
case uint32:
60-
return Value{
61-
Typ: query.Type_UINT32,
62-
Val: values.WriteUint32(make([]byte, values.Uint32Size), v),
63-
}, nil
34+
return sqltypes.MakeTrusted(query.Type_UINT32, values.WriteUint32(make([]byte, values.Uint32Size), v)), nil
6435
case uint64:
65-
return Value{
66-
Typ: query.Type_UINT64,
67-
Val: values.WriteUint64(make([]byte, values.Uint64Size), v),
68-
}, nil
36+
return sqltypes.MakeTrusted(query.Type_UINT64, values.WriteUint64(make([]byte, values.Uint64Size), v)), nil
6937
case float32:
70-
return Value{
71-
Typ: query.Type_FLOAT32,
72-
Val: values.WriteFloat32(make([]byte, values.Float32Size), v),
73-
}, nil
38+
return sqltypes.MakeTrusted(query.Type_FLOAT32, values.WriteFloat32(make([]byte, values.Float32Size), v)), nil
7439
case float64:
75-
return Value{
76-
Typ: query.Type_FLOAT64,
77-
Val: values.WriteFloat64(make([]byte, values.Float64Size), v),
78-
}, nil
40+
return sqltypes.MakeTrusted(query.Type_FLOAT64, values.WriteFloat64(make([]byte, values.Float64Size), v)), nil
7941
case string:
80-
return Value{
81-
Typ: query.Type_VARCHAR,
82-
Val: values.WriteString(make([]byte, len(v)), v, values.ByteOrderCollation),
83-
}, nil
42+
return sqltypes.MakeTrusted(query.Type_VARCHAR, values.WriteString(make([]byte, len(v)), v, values.ByteOrderCollation)), nil
8443
case []byte:
85-
return Value{
86-
Typ: query.Type_BLOB,
87-
Val: values.WriteBytes(make([]byte, len(v)), v, values.ByteOrderCollation),
88-
}, nil
44+
return sqltypes.MakeTrusted(query.Type_BLOB, values.WriteBytes(make([]byte, len(v)), v, values.ByteOrderCollation)), nil
8945
default:
90-
return Value{}, fmt.Errorf("type %T not implemented", v)
91-
}
92-
}
93-
94-
func MustConvertToValue(v interface{}) Value {
95-
ret, err := ConvertToValue(v)
96-
if err != nil {
97-
panic(err)
46+
return sqltypes.Value{}, fmt.Errorf("type %T not implemented", v)
9847
}
99-
return ret
10048
}

sql/core.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"time"
2727
"unsafe"
2828

29+
"github.com/dolthub/vitess/go/sqltypes"
2930
"github.com/shopspring/decimal"
3031
"gopkg.in/src-d/go-errors.v1"
3132

@@ -464,9 +465,10 @@ func DebugString(nodeOrExpression interface{}) string {
464465
type Expression2 interface {
465466
Expression
466467
// Eval2 evaluates the given row frame and returns a result.
467-
Eval2(ctx *Context, row Row2) (Value, error)
468+
Eval2(ctx *Context, row Row2) (sqltypes.Value, error)
468469
// Type2 returns the expression type.
469470
Type2() Type2
471+
IsExpr2() bool
470472
}
471473

472474
var SystemVariables SystemVariableRegistry

sql/expression/comparison.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ package expression
1717
import (
1818
"fmt"
1919

20+
"github.com/dolthub/vitess/go/sqltypes"
21+
querypb "github.com/dolthub/vitess/go/vt/proto/query"
2022
errors "gopkg.in/src-d/go-errors.v1"
2123

2224
"github.com/dolthub/go-mysql-server/sql"
@@ -492,6 +494,7 @@ type GreaterThan struct {
492494
}
493495

494496
var _ sql.Expression = (*GreaterThan)(nil)
497+
var _ sql.Expression2 = (*GreaterThan)(nil)
495498
var _ sql.CollationCoercible = (*GreaterThan)(nil)
496499

497500
// NewGreaterThan creates a new GreaterThan expression.
@@ -518,6 +521,65 @@ func (gt *GreaterThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
518521
return result == 1, nil
519522
}
520523

524+
func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.Row2) (sqltypes.Value, error) {
525+
l, ok := gt.Left().(sql.Expression2)
526+
if !ok {
527+
panic(fmt.Sprintf("%T does not implement sql.Expression2", gt.Left()))
528+
}
529+
r, ok := gt.Right().(sql.Expression2)
530+
if !ok {
531+
panic(fmt.Sprintf("%T does not implement sql.Expression2", gt.Right()))
532+
}
533+
534+
lv, err := l.Eval2(ctx, row)
535+
if err != nil {
536+
return sqltypes.Value{}, err
537+
}
538+
rv, err := r.Eval2(ctx, row)
539+
if err != nil {
540+
return sqltypes.Value{}, err
541+
}
542+
543+
// TODO: just assume they are int64
544+
l64, err := types.ConvertValueToInt64(types.NumberTypeImpl_{}, lv)
545+
if err != nil {
546+
return sqltypes.Value{}, err
547+
}
548+
r64, err := types.ConvertValueToInt64(types.NumberTypeImpl_{}, rv)
549+
if err != nil {
550+
return sqltypes.Value{}, err
551+
}
552+
var rb byte
553+
if l64 > r64 {
554+
rb = 1
555+
}
556+
557+
ret := sqltypes.MakeTrusted(querypb.Type_INT8, []byte{rb})
558+
return ret, nil
559+
}
560+
561+
func (gt *GreaterThan) Type2() sql.Type2 {
562+
return nil
563+
}
564+
565+
func (gt *GreaterThan) IsExpr2() bool {
566+
lExpr, isExpr2 := gt.Left().(sql.Expression2)
567+
if !isExpr2 {
568+
return false
569+
}
570+
if !lExpr.IsExpr2() {
571+
return false
572+
}
573+
rExpr, isExpr2 := gt.Right().(sql.Expression2)
574+
if !isExpr2 {
575+
return false
576+
}
577+
if !rExpr.IsExpr2() {
578+
return false
579+
}
580+
return true
581+
}
582+
521583
// WithChildren implements the Expression interface.
522584
func (gt *GreaterThan) WithChildren(children ...sql.Expression) (sql.Expression, error) {
523585
if len(children) != 2 {

sql/expression/get_field.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"fmt"
1919
"strings"
2020

21+
"github.com/dolthub/vitess/go/sqltypes"
2122
errors "gopkg.in/src-d/go-errors.v1"
2223

2324
"github.com/dolthub/go-mysql-server/sql"
@@ -149,12 +150,15 @@ func (p *GetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
149150
return row[p.fieldIndex], nil
150151
}
151152

152-
func (p *GetField) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) {
153+
func (p *GetField) Eval2(ctx *sql.Context, row sql.Row2) (sqltypes.Value, error) {
153154
if p.fieldIndex < 0 || p.fieldIndex >= row.Len() {
154-
return sql.Value{}, ErrIndexOutOfBounds.New(p.fieldIndex, row.Len())
155+
return sqltypes.Value{}, ErrIndexOutOfBounds.New(p.fieldIndex, row.Len())
155156
}
157+
return row[p.fieldIndex], nil
158+
}
156159

157-
return row.GetField(p.fieldIndex), nil
160+
func (p *GetField) IsExpr2() bool {
161+
return true
158162
}
159163

160164
// WithChildren implements the Expression interface.

0 commit comments

Comments
 (0)