@@ -10,9 +10,14 @@ package mysql
1010
1111import (
1212 "bytes"
13+ "crypto/rand"
14+ "crypto/rsa"
15+ "crypto/sha1"
1316 "crypto/tls"
17+ "crypto/x509"
1418 "database/sql/driver"
1519 "encoding/binary"
20+ "encoding/pem"
1621 "errors"
1722 "fmt"
1823 "io"
@@ -154,24 +159,24 @@ func (mc *mysqlConn) writePacket(data []byte) error {
154159
155160// Handshake Initialization Packet
156161// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
157- func (mc * mysqlConn ) readInitPacket () ([]byte , error ) {
162+ func (mc * mysqlConn ) readInitPacket () ([]byte , string , error ) {
158163 data , err := mc .readPacket ()
159164 if err != nil {
160165 // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
161166 // in connection initialization we don't risk retrying non-idempotent actions.
162167 if err == ErrInvalidConn {
163- return nil , driver .ErrBadConn
168+ return nil , "" , driver .ErrBadConn
164169 }
165- return nil , err
170+ return nil , "" , err
166171 }
167172
168173 if data [0 ] == iERR {
169- return nil , mc .handleErrorPacket (data )
174+ return nil , "" , mc .handleErrorPacket (data )
170175 }
171176
172177 // protocol version [1 byte]
173178 if data [0 ] < minProtocolVersion {
174- return nil , fmt .Errorf (
179+ return nil , "" , fmt .Errorf (
175180 "unsupported protocol version %d. Version %d or higher is required" ,
176181 data [0 ],
177182 minProtocolVersion ,
@@ -191,13 +196,14 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
191196 // capability flags (lower 2 bytes) [2 bytes]
192197 mc .flags = clientFlag (binary .LittleEndian .Uint16 (data [pos : pos + 2 ]))
193198 if mc .flags & clientProtocol41 == 0 {
194- return nil , ErrOldProtocol
199+ return nil , "" , ErrOldProtocol
195200 }
196201 if mc .flags & clientSSL == 0 && mc .cfg .tls != nil {
197- return nil , ErrNoTLS
202+ return nil , "" , ErrNoTLS
198203 }
199204 pos += 2
200205
206+ pluginName := ""
201207 if len (data ) > pos {
202208 // character set [1 byte]
203209 // status flags [2 bytes]
@@ -219,6 +225,8 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
219225 // The official Python library uses the fixed length 12
220226 // which seems to work but technically could have a hidden bug.
221227 cipher = append (cipher , data [pos :pos + 12 ]... )
228+ pos += 13
229+ pluginName = string (data [pos : pos + bytes .IndexByte (data [pos :], 0x00 )])
222230
223231 // TODO: Verify string termination
224232 // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
@@ -232,18 +240,22 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
232240 // make a memory safe copy of the cipher slice
233241 var b [20 ]byte
234242 copy (b [:], cipher )
235- return b [:], nil
243+ return b [:], pluginName , nil
236244 }
237245
238246 // make a memory safe copy of the cipher slice
239247 var b [8 ]byte
240248 copy (b [:], cipher )
241- return b [:], nil
249+ return b [:], pluginName , nil
242250}
243251
244252// Client Authentication Packet
245253// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
246- func (mc * mysqlConn ) writeAuthPacket (cipher []byte ) error {
254+ func (mc * mysqlConn ) writeAuthPacket (cipher []byte , pluginName string ) error {
255+ if pluginName != "mysql_native_password" && pluginName != "caching_sha2_password" {
256+ return fmt .Errorf ("unknown authentication plugin name '%s'" , pluginName )
257+ }
258+
247259 // Adjust client flags based on server support
248260 clientFlags := clientProtocol41 |
249261 clientSecureConn |
@@ -268,7 +280,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
268280 }
269281
270282 // User Password
271- scrambleBuff := scramblePassword (cipher , []byte (mc .cfg .Passwd ))
283+ var scrambleBuff []byte
284+ switch pluginName {
285+ case "mysql_native_password" :
286+ scrambleBuff = scramblePassword (cipher , []byte (mc .cfg .Passwd ))
287+ case "caching_sha2_password" :
288+ scrambleBuff = scrambleCachingSha2Password (cipher , []byte (mc .cfg .Passwd ))
289+ }
272290
273291 pktLen := 4 + 4 + 1 + 23 + len (mc .cfg .User ) + 1 + 1 + len (scrambleBuff ) + 21 + 1
274292
@@ -350,7 +368,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
350368 }
351369
352370 // Assume native client during response
353- pos += copy (data [pos :], "mysql_native_password" )
371+ pos += copy (data [pos :], pluginName )
354372 data [pos ] = 0x00
355373
356374 // Send Auth packet
@@ -422,6 +440,38 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {
422440 return mc .writePacket (data )
423441}
424442
443+ // Caching sha2 authentication. Public key request and send encrypted password
444+ // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
445+ func (mc * mysqlConn ) writePublicKeyAuthPacket (cipher []byte ) error {
446+ // request public key
447+ data := mc .buf .takeSmallBuffer (4 + 1 )
448+ data [4 ] = cachingSha2PasswordRequestPublicKey
449+ mc .writePacket (data )
450+
451+ data , err := mc .readPacket ()
452+ if err != nil {
453+ return err
454+ }
455+
456+ block , _ := pem .Decode (data [1 :])
457+ pub , err := x509 .ParsePKIXPublicKey (block .Bytes )
458+ if err != nil {
459+ return err
460+ }
461+
462+ plain := make ([]byte , len (mc .cfg .Passwd )+ 1 )
463+ copy (plain , mc .cfg .Passwd )
464+ for i := range plain {
465+ j := i % len (cipher )
466+ plain [i ] ^= cipher [j ]
467+ }
468+ sha1 := sha1 .New ()
469+ enc , _ := rsa .EncryptOAEP (sha1 , rand .Reader , pub .(* rsa.PublicKey ), plain , nil )
470+ data = mc .buf .takeSmallBuffer (4 + len (enc ))
471+ copy (data [4 :], enc )
472+ return mc .writePacket (data )
473+ }
474+
425475/******************************************************************************
426476* Command Packets *
427477******************************************************************************/
@@ -535,6 +585,16 @@ func (mc *mysqlConn) readResultOK() ([]byte, error) {
535585 return nil , err
536586}
537587
588+ func (mc * mysqlConn ) readCachingSha2PasswordAuthResult () (int , error ) {
589+ data , err := mc .readPacket ()
590+ if err == nil {
591+ if data [0 ] != 1 {
592+ return 0 , ErrMalformPkt
593+ }
594+ }
595+ return int (data [1 ]), err
596+ }
597+
538598// Result Set Header Packet
539599// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
540600func (mc * mysqlConn ) readResultSetHeaderPacket () (int , error ) {
0 commit comments