@@ -226,6 +226,8 @@ type Greeting struct {
226226
227227// Opts is a way to configure Connection
228228type Opts struct {
229+ // Auth is an authentication method.
230+ Auth Auth
229231 // Timeout for response to a particular request. The timeout is reset when
230232 // push messages are received. If Timeout is zero, any request can be
231233 // blocked infinitely.
@@ -546,19 +548,40 @@ func (conn *Connection) dial() (err error) {
546548
547549 // Auth.
548550 if opts .User != "" {
549- scr , err := scramble (conn .Greeting .auth , opts .Pass )
550- if err != nil {
551- err = errors .New ("auth: scrambling failure " + err .Error ())
551+ auth := opts .Auth
552+ if opts .Auth == AutoAuth {
553+ if conn .serverProtocolInfo .Auth != AutoAuth {
554+ auth = conn .serverProtocolInfo .Auth
555+ } else {
556+ auth = ChapSha1Auth
557+ }
558+ }
559+
560+ var req Request
561+ if auth == ChapSha1Auth {
562+ salt := conn .Greeting .auth
563+ req , err = newChapSha1AuthRequest (conn .opts .User , salt , opts .Pass )
564+ if err != nil {
565+ return fmt .Errorf ("auth: %w" , err )
566+ }
567+ } else if auth == PapSha256Auth {
568+ if opts .Transport != connTransportSsl {
569+ return errors .New ("auth: forbidden to use " + auth .String () +
570+ " unless SSL is enabled for the connection" )
571+ }
572+ req = newPapSha256AuthRequest (conn .opts .User , opts .Pass )
573+ } else {
552574 connection .Close ()
553- return err
575+ return errors . New ( "auth: " + auth . String ())
554576 }
555- if err = conn .writeAuthRequest (w , scr ); err != nil {
577+
578+ if err = conn .writeRequest (w , req ); err != nil {
556579 connection .Close ()
557- return err
580+ return fmt . Errorf ( "auth: %w" , err )
558581 }
559- if err = conn .readAuthResponse (r ); err != nil {
582+ if _ , err = conn .readResponse (r ); err != nil {
560583 connection .Close ()
561- return err
584+ return fmt . Errorf ( "auth: %w" , err )
562585 }
563586 }
564587
@@ -662,28 +685,6 @@ func (conn *Connection) writeRequest(w *bufio.Writer, req Request) error {
662685 return err
663686}
664687
665- func (conn * Connection ) writeAuthRequest (w * bufio.Writer , scramble []byte ) error {
666- req := newAuthRequest (conn .opts .User , string (scramble ))
667-
668- err := conn .writeRequest (w , req )
669- if err != nil {
670- return fmt .Errorf ("auth: %w" , err )
671- }
672-
673- return nil
674- }
675-
676- func (conn * Connection ) writeIdRequest (w * bufio.Writer , protocolInfo ProtocolInfo ) error {
677- req := NewIdRequest (protocolInfo )
678-
679- err := conn .writeRequest (w , req )
680- if err != nil {
681- return fmt .Errorf ("identify: %w" , err )
682- }
683-
684- return nil
685- }
686-
687688func (conn * Connection ) readResponse (r io.Reader ) (Response , error ) {
688689 respBytes , err := conn .read (r )
689690 if err != nil {
@@ -707,24 +708,6 @@ func (conn *Connection) readResponse(r io.Reader) (Response, error) {
707708 return resp , nil
708709}
709710
710- func (conn * Connection ) readAuthResponse (r io.Reader ) error {
711- _ , err := conn .readResponse (r )
712- if err != nil {
713- return fmt .Errorf ("auth: %w" , err )
714- }
715-
716- return nil
717- }
718-
719- func (conn * Connection ) readIdResponse (r io.Reader ) (Response , error ) {
720- resp , err := conn .readResponse (r )
721- if err != nil {
722- return resp , fmt .Errorf ("identify: %w" , err )
723- }
724-
725- return resp , nil
726- }
727-
728711func (conn * Connection ) createConnection (reconnect bool ) (err error ) {
729712 var reconnects uint
730713 for conn .c == nil && conn .state == connDisconnected {
@@ -1625,19 +1608,20 @@ func checkProtocolInfo(expected ProtocolInfo, actual ProtocolInfo) error {
16251608func (conn * Connection ) identify (w * bufio.Writer , r * bufio.Reader ) error {
16261609 var ok bool
16271610
1628- werr := conn .writeIdRequest (w , clientProtocolInfo )
1611+ req := NewIdRequest (clientProtocolInfo )
1612+ werr := conn .writeRequest (w , req )
16291613 if werr != nil {
1630- return werr
1614+ return fmt . Errorf ( "identify: %w" , werr )
16311615 }
16321616
1633- resp , rerr := conn .readIdResponse (r )
1617+ resp , rerr := conn .readResponse (r )
16341618 if rerr != nil {
16351619 if resp .Code == ErrUnknownRequestType {
16361620 // IPROTO_ID requests are not supported by server.
16371621 return nil
16381622 }
16391623
1640- return rerr
1624+ return fmt . Errorf ( "identify: %w" , rerr )
16411625 }
16421626
16431627 if len (resp .Data ) == 0 {
@@ -1664,5 +1648,7 @@ func (conn *Connection) ServerProtocolInfo() ProtocolInfo {
16641648// supported by Go connection client.
16651649// Since 1.10.0
16661650func (conn * Connection ) ClientProtocolInfo () ProtocolInfo {
1667- return clientProtocolInfo .Clone ()
1651+ info := clientProtocolInfo .Clone ()
1652+ info .Auth = conn .opts .Auth
1653+ return info
16681654}
0 commit comments