Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ Soroush Pour <me at soroushjp.com>
Stan Putrya <root.vagner at gmail.com>
Stanley Gunawan <gunawan.stanley at gmail.com>
Thomas Wodarek <wodarekwebpage at gmail.com>
Tom Jenkinson <tom at tjenkinson.me>
Xiangyu Hu <xiangyu.hu at outlook.com>
Xiaobing Jiang <s7v7nislands at gmail.com>
Xiuming Chen <cc at cxm.cc>
Expand Down
4 changes: 4 additions & 0 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr)
}
if err != nil {
if nerr, ok := err.(net.Error); ok && (nerr.Temporary() || nerr.Timeout()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't || nerr.Timeout().

nerr.Temporary() returns True when retry is worthful.
For example, lookup timeout is temporary.

func (e *DNSError) Temporary() bool { return e.IsTimeout || e.IsTemporary }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

the docs say

IsTimeout   bool   // if true, timed out; not all timeouts set this
IsTemporary bool   // if true, error is temporary; not all errors set this; added in Go 1.6

so I thought it would be better to check both given IsTemporary might not be supported

errLog.Print("net.Error from Dial()': ", nerr.Error())
return nil, driver.ErrBadConn
}
return nil, err
}

Expand Down
54 changes: 54 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,23 @@ type DBTest struct {
db *sql.DB
}

type netErrorMock struct {
temporary bool
timeout bool
}

func (e netErrorMock) Temporary() bool {
return e.temporary
}

func (e netErrorMock) Timeout() bool {
return e.timeout
}

func (e netErrorMock) Error() string {
return fmt.Sprintf("mock net error. Temporary: %v, Timeout %v", e.temporary, e.timeout)
}

func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
Expand Down Expand Up @@ -1801,6 +1818,43 @@ func TestConcurrent(t *testing.T) {
})
}

func testDialError(t *testing.T, dialErr error, expectErr error) {
RegisterDial("mydial", func(addr string) (net.Conn, error) {
return nil, dialErr
})

db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
defer db.Close()

_, err = db.Exec("DO 1")
if err != expectErr {
t.Fatalf("was expecting %s. Got: %s", dialErr, err)
}
}

func TestDialUnknownError(t *testing.T) {
testErr := fmt.Errorf("test")
testDialError(t, testErr, testErr)
}

func TestDialNonRetryableNetErr(t *testing.T) {
testErr := netErrorMock{}
testDialError(t, testErr, testErr)
}

func TestDialTimeoutNetErr(t *testing.T) {
testErr := netErrorMock{timeout: true}
testDialError(t, testErr, driver.ErrBadConn)
}

func TestDialTemporaryNetErr(t *testing.T) {
testErr := netErrorMock{temporary: true}
testDialError(t, testErr, driver.ErrBadConn)
}

// Tests custom dial functions
func TestCustomDial(t *testing.T) {
if !available {
Expand Down