1717package mysql
1818
1919import (
20+ "context"
2021 "database/sql"
2122 "database/sql/driver"
2223 "net"
@@ -29,139 +30,54 @@ type MySQLDriver struct{}
2930
3031// DialFunc is a function which can be used to establish the network connection.
3132// Custom dial functions must be registered with RegisterDial
33+ //
34+ // Deprecated: users should register a DialContextFunc instead
3235type DialFunc func (addr string ) (net.Conn , error )
3336
37+ // DialContextFunc is a function which can be used to establish the network connection.
38+ // Custom dial functions must be registered with RegisterDialContext
39+ type DialContextFunc func (ctx context.Context , addr string ) (net.Conn , error )
40+
3441var (
3542 dialsLock sync.RWMutex
36- dials map [string ]DialFunc
43+ dials map [string ]DialContextFunc
3744)
3845
39- // RegisterDial registers a custom dial function. It can then be used by the
46+ // RegisterDialContext registers a custom dial function. It can then be used by the
4047// network address mynet(addr), where mynet is the registered new network.
41- // addr is passed as a parameter to the dial function.
42- func RegisterDial (net string , dial DialFunc ) {
48+ // The current context for the connection and its address is passed to the dial function.
49+ func RegisterDialContext (net string , dial DialContextFunc ) {
4350 dialsLock .Lock ()
4451 defer dialsLock .Unlock ()
4552 if dials == nil {
46- dials = make (map [string ]DialFunc )
53+ dials = make (map [string ]DialContextFunc )
4754 }
4855 dials [net ] = dial
4956}
5057
58+ // RegisterDial registers a custom dial function. It can then be used by the
59+ // network address mynet(addr), where mynet is the registered new network.
60+ // addr is passed as a parameter to the dial function.
61+ //
62+ // Deprecated: users should call RegisterDialContext instead
63+ func RegisterDial (network string , dial DialFunc ) {
64+ RegisterDialContext (network , func (_ context.Context , addr string ) (net.Conn , error ) {
65+ return dial (addr )
66+ })
67+ }
68+
5169// Open new Connection.
5270// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
5371// the DSN string is formatted
5472func (d MySQLDriver ) Open (dsn string ) (driver.Conn , error ) {
55- var err error
56-
57- // New mysqlConn
58- mc := & mysqlConn {
59- maxAllowedPacket : maxPacketSize ,
60- maxWriteSize : maxPacketSize - 1 ,
61- closech : make (chan struct {}),
62- }
63- mc .cfg , err = ParseDSN (dsn )
64- if err != nil {
65- return nil , err
66- }
67- mc .parseTime = mc .cfg .ParseTime
68-
69- // Connect to Server
70- dialsLock .RLock ()
71- dial , ok := dials [mc .cfg .Net ]
72- dialsLock .RUnlock ()
73- if ok {
74- mc .netConn , err = dial (mc .cfg .Addr )
75- } else {
76- nd := net.Dialer {Timeout : mc .cfg .Timeout }
77- mc .netConn , err = nd .Dial (mc .cfg .Net , mc .cfg .Addr )
78- }
79- if err != nil {
80- if nerr , ok := err .(net.Error ); ok && nerr .Temporary () {
81- errLog .Print ("net.Error from Dial()': " , nerr .Error ())
82- return nil , driver .ErrBadConn
83- }
84- return nil , err
85- }
86-
87- // Enable TCP Keepalives on TCP connections
88- if tc , ok := mc .netConn .(* net.TCPConn ); ok {
89- if err := tc .SetKeepAlive (true ); err != nil {
90- // Don't send COM_QUIT before handshake.
91- mc .netConn .Close ()
92- mc .netConn = nil
93- return nil , err
94- }
95- }
96-
97- // Call startWatcher for context support (From Go 1.8)
98- mc .startWatcher ()
99-
100- mc .buf = newBuffer (mc .netConn )
101-
102- // Set I/O timeouts
103- mc .buf .timeout = mc .cfg .ReadTimeout
104- mc .writeTimeout = mc .cfg .WriteTimeout
105-
106- // Reading Handshake Initialization Packet
107- authData , plugin , err := mc .readHandshakePacket ()
73+ cfg , err := ParseDSN (dsn )
10874 if err != nil {
109- mc .cleanup ()
11075 return nil , err
11176 }
112- if plugin == "" {
113- plugin = defaultAuthPlugin
77+ c := & connector {
78+ cfg : cfg ,
11479 }
115-
116- // Send Client Authentication Packet
117- authResp , err := mc .auth (authData , plugin )
118- if err != nil {
119- // try the default auth plugin, if using the requested plugin failed
120- errLog .Print ("could not use requested auth plugin '" + plugin + "': " , err .Error ())
121- plugin = defaultAuthPlugin
122- authResp , err = mc .auth (authData , plugin )
123- if err != nil {
124- mc .cleanup ()
125- return nil , err
126- }
127- }
128- if err = mc .writeHandshakeResponsePacket (authResp , plugin ); err != nil {
129- mc .cleanup ()
130- return nil , err
131- }
132-
133- // Handle response to auth packet, switch methods if possible
134- if err = mc .handleAuthResult (authData , plugin ); err != nil {
135- // Authentication failed and MySQL has already closed the connection
136- // (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
137- // Do not send COM_QUIT, just cleanup and return the error.
138- mc .cleanup ()
139- return nil , err
140- }
141-
142- if mc .cfg .MaxAllowedPacket > 0 {
143- mc .maxAllowedPacket = mc .cfg .MaxAllowedPacket
144- } else {
145- // Get max allowed packet size
146- maxap , err := mc .getSystemVar ("max_allowed_packet" )
147- if err != nil {
148- mc .Close ()
149- return nil , err
150- }
151- mc .maxAllowedPacket = stringToInt (maxap ) - 1
152- }
153- if mc .maxAllowedPacket < maxPacketSize {
154- mc .maxWriteSize = mc .maxAllowedPacket
155- }
156-
157- // Handle DSN Params
158- err = mc .handleParams ()
159- if err != nil {
160- mc .Close ()
161- return nil , err
162- }
163-
164- return mc , nil
80+ return c .Connect (context .Background ())
16581}
16682
16783func init () {
0 commit comments