Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
defer mc.finish()

if err = mc.writeCommandPacket(comPing); err != nil {
return
return mc.markBadConn(err)
}

return mc.readResultOK()
Expand Down
47 changes: 47 additions & 0 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ package mysql
import (
"context"
"database/sql/driver"
"errors"
"net"
"testing"
)

Expand Down Expand Up @@ -108,3 +110,48 @@ func TestCleanCancel(t *testing.T) {
}
}
}

func TestPingMarkBadConnection(t *testing.T) {
nc := badConnection{err: errors.New("boom")}
ms := &mysqlConn{
netConn: nc,
buf: newBuffer(nc),
maxAllowedPacket: defaultMaxAllowedPacket,
}

err := ms.Ping(context.Background())

if err != driver.ErrBadConn {
t.Errorf("expected driver.ErrBadConn, got %#v", err)
}
}

func TestPingErrInvalidConn(t *testing.T) {
nc := badConnection{err: errors.New("failed to write"), n: 10}
ms := &mysqlConn{
netConn: nc,
buf: newBuffer(nc),
maxAllowedPacket: defaultMaxAllowedPacket,
closech: make(chan struct{}),
}

err := ms.Ping(context.Background())

if err != ErrInvalidConn {
t.Errorf("expected ErrInvalidConn, got %#v", err)
}
}

type badConnection struct {
n int
err error
net.Conn
}

func (bc badConnection) Write(b []byte) (n int, err error) {
return bc.n, bc.err
}

func (bc badConnection) Close() error {
return nil
}