33package driver
44
55import (
6+ "context"
67 "crypto/tls"
78 "database/sql"
89 sqldriver "database/sql/driver"
@@ -21,6 +22,23 @@ import (
2122 "github.com/pingcap/errors"
2223)
2324
25+ var (
26+ _ sqldriver.Driver = & driver {}
27+ _ sqldriver.DriverContext = & driver {}
28+ _ sqldriver.Connector = & connInfo {}
29+ _ sqldriver.NamedValueChecker = & conn {}
30+ _ sqldriver.Validator = & conn {}
31+ _ sqldriver.Conn = & conn {}
32+ _ sqldriver.Pinger = & conn {}
33+ _ sqldriver.ConnBeginTx = & conn {}
34+ _ sqldriver.ConnPrepareContext = & conn {}
35+ _ sqldriver.ExecerContext = & conn {}
36+ _ sqldriver.QueryerContext = & conn {}
37+ _ sqldriver.Stmt = & stmt {}
38+ _ sqldriver.StmtExecContext = & stmt {}
39+ _ sqldriver.StmtQueryContext = & stmt {}
40+ )
41+
2442var customTLSMutex sync.Mutex
2543
2644// Map of dsn address (makes more sense than full dsn?) to tls Config
@@ -101,16 +119,18 @@ func parseDSN(dsn string) (connInfo, error) {
101119// Open takes a supplied DSN string and opens a connection
102120// See ParseDSN for more information on the form of the DSN
103121func (d driver ) Open (dsn string ) (sqldriver.Conn , error ) {
104- var (
105- c * client.Conn
106- // by default database/sql driver retries will be enabled
107- retries = true
108- )
109-
110122 ci , err := parseDSN (dsn )
111123 if err != nil {
112124 return nil , err
113125 }
126+ return ci .Connect (context .Background ())
127+ }
128+
129+ func (ci connInfo ) Connect (ctx context.Context ) (sqldriver.Conn , error ) {
130+ var c * client.Conn
131+ var err error
132+ // by default database/sql driver retries will be enabled
133+ retries := true
114134
115135 if ci .standardDSN {
116136 var timeout time.Duration
@@ -159,45 +179,86 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) {
159179 }
160180 }
161181
162- if timeout > 0 {
163- c , err = client .ConnectWithTimeout (ci .addr , ci .user , ci .password , ci .db , timeout , configuredOptions ... )
164- } else {
165- c , err = client .Connect (ci .addr , ci .user , ci .password , ci .db , configuredOptions ... )
182+ if timeout <= 0 {
183+ timeout = 10 * time .Second
166184 }
185+ c , err = client .ConnectWithContext (ctx , ci .addr , ci .user , ci .password , ci .db , timeout , configuredOptions ... )
167186 } else {
168187 // No more processing here. Let's only support url parameters with the newer style DSN
169- c , err = client .Connect ( ci .addr , ci .user , ci .password , ci .db )
188+ c , err = client .ConnectWithContext ( ctx , ci .addr , ci .user , ci .password , ci .db , 10 * time . Second )
170189 }
171190 if err != nil {
172191 return nil , err
173192 }
174193
194+ contexts := make (chan context.Context )
195+ go func () {
196+ ctx := context .Background ()
197+ for {
198+ var ok bool
199+ select {
200+ case <- ctx .Done ():
201+ ctx = context .Background ()
202+ _ = c .Conn .Close ()
203+ case ctx , ok = <- contexts :
204+ if ! ok {
205+ return
206+ }
207+ }
208+ }
209+ }()
210+
175211 // if retries are 'on' then return sqldriver.ErrBadConn which will trigger up to 3
176212 // retries by the database/sql package. If retries are 'off' then we'll return
177213 // the native go-mysql-org/go-mysql 'mysql.ErrBadConn' erorr which will prevent a retry.
178214 // In this case the sqldriver.Validator interface is implemented and will return
179215 // false for IsValid() signaling the connection is bad and should be discarded.
180- return & conn {Conn : c , state : & state {valid : true , useStdLibErrors : retries }}, nil
216+ return & conn {
217+ Conn : c ,
218+ state : & state {contexts : contexts , valid : true , useStdLibErrors : retries },
219+ }, nil
181220}
182221
183- type CheckNamedValueFunc func (* sqldriver.NamedValue ) error
222+ func (d driver ) OpenConnector (name string ) (sqldriver.Connector , error ) {
223+ return parseDSN (name )
224+ }
184225
185- var (
186- _ sqldriver.NamedValueChecker = & conn {}
187- _ sqldriver.Validator = & conn {}
188- )
226+ func (ci connInfo ) Driver () sqldriver.Driver {
227+ return driver {}
228+ }
229+
230+ type CheckNamedValueFunc func (* sqldriver.NamedValue ) error
189231
190232type state struct {
191- valid bool
233+ contexts chan context.Context
234+ valid bool
192235 // when true, the driver connection will return ErrBadConn from the golang Standard Library
193236 useStdLibErrors bool
194237}
195238
239+ func (s * state ) watchCtx (ctx context.Context ) func () {
240+ s .contexts <- ctx
241+ return func () {
242+ s .contexts <- context .Background ()
243+ }
244+ }
245+
246+ func (s * state ) Close () {
247+ if s .contexts != nil {
248+ close (s .contexts )
249+ s .contexts = nil
250+ }
251+ }
252+
196253type conn struct {
197254 * client.Conn
198255 state * state
199256}
200257
258+ func (c * conn ) watchCtx (ctx context.Context ) func () {
259+ return c .state .watchCtx (ctx )
260+ }
261+
201262func (c * conn ) CheckNamedValue (nv * sqldriver.NamedValue ) error {
202263 for _ , nvChecker := range namedValueCheckers {
203264 err := nvChecker (nv )
@@ -220,6 +281,17 @@ func (c *conn) IsValid() bool {
220281 return c .state .valid
221282}
222283
284+ func (c * conn ) Ping (ctx context.Context ) error {
285+ defer c .watchCtx (ctx )()
286+ if err := c .Conn .Ping (); err != nil {
287+ if err == context .DeadlineExceeded || err == context .Canceled {
288+ return err
289+ }
290+ return sqldriver .ErrBadConn
291+ }
292+ return nil
293+ }
294+
223295func (c * conn ) Prepare (query string ) (sqldriver.Stmt , error ) {
224296 st , err := c .Conn .Prepare (query )
225297 if err != nil {
@@ -229,7 +301,13 @@ func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
229301 return & stmt {Stmt : st , connectionState : c .state }, nil
230302}
231303
304+ func (c * conn ) PrepareContext (ctx context.Context , query string ) (sqldriver.Stmt , error ) {
305+ defer c .watchCtx (ctx )()
306+ return c .Prepare (query )
307+ }
308+
232309func (c * conn ) Close () error {
310+ c .state .Close ()
233311 return c .Conn .Close ()
234312}
235313
@@ -242,6 +320,29 @@ func (c *conn) Begin() (sqldriver.Tx, error) {
242320 return & tx {c .Conn }, nil
243321}
244322
323+ var isolationLevelTransactionIsolation = map [sql.IsolationLevel ]string {
324+ sql .LevelDefault : "" ,
325+ sql .LevelRepeatableRead : "REPEATABLE READ" ,
326+ sql .LevelReadCommitted : "READ COMMITTED" ,
327+ sql .LevelReadUncommitted : "READ UNCOMMITTED" ,
328+ sql .LevelSerializable : "SERIALIZABLE" ,
329+ }
330+
331+ func (c * conn ) BeginTx (ctx context.Context , opts sqldriver.TxOptions ) (sqldriver.Tx , error ) {
332+ defer c .watchCtx (ctx )()
333+
334+ isolation := sql .IsolationLevel (opts .Isolation )
335+ txIsolation , ok := isolationLevelTransactionIsolation [isolation ]
336+ if ! ok {
337+ return nil , fmt .Errorf ("invalid mysql transaction isolation level %s" , isolation )
338+ }
339+ err := c .Conn .BeginTx (opts .ReadOnly , txIsolation )
340+ if err != nil {
341+ return nil , errors .Trace (err )
342+ }
343+ return & tx {c .Conn }, nil
344+ }
345+
245346func buildArgs (args []sqldriver.Value ) []interface {} {
246347 a := make ([]interface {}, len (args ))
247348
@@ -252,6 +353,16 @@ func buildArgs(args []sqldriver.Value) []interface{} {
252353 return a
253354}
254355
356+ func buildNamedArgs (args []sqldriver.NamedValue ) []interface {} {
357+ a := make ([]interface {}, len (args ))
358+
359+ for i , arg := range args {
360+ a [i ] = arg .Value
361+ }
362+
363+ return a
364+ }
365+
255366func (st * state ) replyError (err error ) error {
256367 isBadConnection := mysql .ErrorEqual (err , mysql .ErrBadConn )
257368
@@ -275,6 +386,16 @@ func (c *conn) Exec(query string, args []sqldriver.Value) (sqldriver.Result, err
275386 return & result {r }, nil
276387}
277388
389+ func (c * conn ) ExecContext (ctx context.Context , query string , args []sqldriver.NamedValue ) (sqldriver.Result , error ) {
390+ defer c .watchCtx (ctx )()
391+ a := buildNamedArgs (args )
392+ r , err := c .Conn .Execute (query , a ... )
393+ if err != nil {
394+ return nil , c .state .replyError (err )
395+ }
396+ return & result {r }, nil
397+ }
398+
278399func (c * conn ) Query (query string , args []sqldriver.Value ) (sqldriver.Rows , error ) {
279400 a := buildArgs (args )
280401 r , err := c .Conn .Execute (query , a ... )
@@ -284,11 +405,25 @@ func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, erro
284405 return newRows (r .Resultset )
285406}
286407
408+ func (c * conn ) QueryContext (ctx context.Context , query string , args []sqldriver.NamedValue ) (sqldriver.Rows , error ) {
409+ defer c .watchCtx (ctx )()
410+ a := buildNamedArgs (args )
411+ r , err := c .Conn .Execute (query , a ... )
412+ if err != nil {
413+ return nil , c .state .replyError (err )
414+ }
415+ return newRows (r .Resultset )
416+ }
417+
287418type stmt struct {
288419 * client.Stmt
289420 connectionState * state
290421}
291422
423+ func (s * stmt ) watchCtx (ctx context.Context ) func () {
424+ return s .connectionState .watchCtx (ctx )
425+ }
426+
292427func (s * stmt ) Close () error {
293428 return s .Stmt .Close ()
294429}
@@ -306,6 +441,17 @@ func (s *stmt) Exec(args []sqldriver.Value) (sqldriver.Result, error) {
306441 return & result {r }, nil
307442}
308443
444+ func (s * stmt ) ExecContext (ctx context.Context , args []sqldriver.NamedValue ) (sqldriver.Result , error ) {
445+ defer s .watchCtx (ctx )()
446+
447+ a := buildNamedArgs (args )
448+ r , err := s .Stmt .Execute (a ... )
449+ if err != nil {
450+ return nil , s .connectionState .replyError (err )
451+ }
452+ return & result {r }, nil
453+ }
454+
309455func (s * stmt ) Query (args []sqldriver.Value ) (sqldriver.Rows , error ) {
310456 a := buildArgs (args )
311457 r , err := s .Stmt .Execute (a ... )
@@ -315,6 +461,17 @@ func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) {
315461 return newRows (r .Resultset )
316462}
317463
464+ func (s * stmt ) QueryContext (ctx context.Context , args []sqldriver.NamedValue ) (sqldriver.Rows , error ) {
465+ defer s .watchCtx (ctx )()
466+
467+ a := buildNamedArgs (args )
468+ r , err := s .Stmt .Execute (a ... )
469+ if err != nil {
470+ return nil , s .connectionState .replyError (err )
471+ }
472+ return newRows (r .Resultset )
473+ }
474+
318475type tx struct {
319476 * client.Conn
320477}
0 commit comments