Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ GitHub Inc.
Google Inc.
InfoSum Ltd.
Keybase Inc.
Microsoft Corp.
Multiplay Ltd.
Percona LLC
PingCAP Inc.
Expand Down
12 changes: 11 additions & 1 deletion connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,22 @@ func newConnector(cfg *Config) *connector {
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
var err error

// Invoke BeforeConnect if present, with a copy of the configuration
cfg := c.cfg
if c.cfg.BeforeConnect != nil {
cfg = c.cfg.Clone()
err = c.cfg.BeforeConnect(ctx, cfg)
if err != nil {
return nil, err
}
}

// New mysqlConn
mc := &mysqlConn{
maxAllowedPacket: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
closech: make(chan struct{}),
cfg: c.cfg,
cfg: cfg,
connector: c,
}
mc.parseTime = mc.cfg.ParseTime
Expand Down
34 changes: 34 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1981,6 +1981,40 @@ func TestCustomDial(t *testing.T) {
}
}

func TestBeforeConnect(t *testing.T) {
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
}

// dbname is set in the BeforeConnect handle
cfg, err := ParseDSN(fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, pass, netAddr, "_"))
if err != nil {
t.Fatalf("error parsing DSN: %v", err)
}

cfg.BeforeConnect = func(ctx context.Context, c *Config) error {
c.DBName = dbname
return nil
}

connector, err := NewConnector(cfg)
if err != nil {
t.Fatalf("error creating connector: %v", err)
}

db := sql.OpenDB(connector)
defer db.Close()

var connectedDb string
err = db.QueryRow("SELECT DATABASE();").Scan(&connectedDb)
if err != nil {
t.Fatalf("error executing query: %v", err)
}
if connectedDb != dbname {
t.Fatalf("expected to connect to DB %s, but connected to %s instead", dbname, connectedDb)
}
}

func TestSQLInjection(t *testing.T) {
createTest := func(arg string) func(dbt *DBTest) {
return func(dbt *DBTest) {
Expand Down
38 changes: 20 additions & 18 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package mysql

import (
"bytes"
"context"
"crypto/rsa"
"crypto/tls"
"errors"
Expand All @@ -34,24 +35,25 @@ var (
// If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values.
type Config struct {
User string // Username
Passwd string // Password (requires User)
Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp")
Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix")
DBName string // Database name
Params map[string]string // Connection parameters
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
Logger Logger // Logger
User string // Username
Passwd string // Password (requires User)
Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp")
Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix")
DBName string // Database name
Params map[string]string // Connection parameters
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
Logger Logger // Logger
BeforeConnect func(context.Context, *Config) error // Invoked before a connection is established

AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
Expand Down